Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-02-22 22:36:05 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-02-22 22:36:05 +0800
Commit
5a5d029490d8ccc4516566545b7556f4a1254fb5
5a5d0294
1 parent
ef93dcd7
Add build script for Android armv8a (#58)
隐藏空白字符变更
内嵌
并排对比
正在显示
16 个修改的文件
包含
400 行增加
和
74 行删除
build-android-arm64-v8a.sh
cmake/onnxruntime.cmake
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/online-lstm-transducer-model.cc
sherpa-onnx/csrc/online-lstm-transducer-model.h
sherpa-onnx/csrc/online-recognizer.cc
sherpa-onnx/csrc/online-recognizer.h
sherpa-onnx/csrc/online-transducer-model.cc
sherpa-onnx/csrc/online-transducer-model.h
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
sherpa-onnx/csrc/online-zipformer-transducer-model.h
sherpa-onnx/csrc/onnx-utils.cc
sherpa-onnx/csrc/onnx-utils.h
sherpa-onnx/csrc/symbol-table.cc
sherpa-onnx/csrc/symbol-table.h
sherpa-onnx/jni/jni.cc
build-android-arm64-v8a.sh
0 → 100755
查看文件 @
5a5d029
#!/usr/bin/env bash
set
-ex
dir
=
build-android-arm64-v8a
mkdir -p
$dir
cd
$dir
# Note from https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android
# (optional) remove the hardcoded debug flag in Android NDK android-ndk
# issue: https://github.com/android/ndk/issues/243
#
# open $ANDROID_NDK/build/cmake/android.toolchain.cmake for ndk < r23
# or $ANDROID_NDK/build/cmake/android-legacy.toolchain.cmake for ndk >= r23
#
# delete "-g" line
#
# list(APPEND ANDROID_COMPILER_FLAGS
# -g
# -DANDROID
if
[
-z
$ANDROID_NDK
]
;
then
ANDROID_NDK
=
/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669
# or use
# ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk
#
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
# and some other files like the file "build/cmake/android.toolchain.cmake"
if
[
! -d
$ANDROID_NDK
]
;
then
# For macOS, I have installed Android Studio, select the menu
# Tools -> SDK manager -> Android SDK
# and set "Android SDK location" to /Users/fangjun/software/my-android
ANDROID_NDK
=
/Users/fangjun/software/my-android/ndk/22.1.7171670
fi
fi
if
[
! -d
$ANDROID_NDK
]
;
then
echo
Please
set
the environment variable ANDROID_NDK before you run this script
exit
1
fi
echo
"ANDROID_NDK:
$ANDROID_NDK
"
sleep 1
cmake -DCMAKE_TOOLCHAIN_FILE
=
"
$ANDROID_NDK
/build/cmake/android.toolchain.cmake"
\
-DCMAKE_BUILD_TYPE
=
Release
\
-DBUILD_SHARED_LIBS
=
ON
\
-DSHERPA_ONNX_ENABLE_PYTHON
=
OFF
\
-DSHERPA_ONNX_ENABLE_TESTS
=
OFF
\
-DSHERPA_ONNX_ENABLE_CHECK
=
OFF
\
-DSHERPA_ONNX_ENABLE_PORTAUDIO
=
OFF
\
-DSHERPA_ONNX_ENABLE_JNI
=
ON
\
-DCMAKE_INSTALL_PREFIX
=
./install
\
-DANDROID_ABI
=
"arm64-v8a"
\
-DANDROID_PLATFORM
=
android-21 ..
# make VERBOSE=1 -j4
make -j4
make install/strip
...
...
cmake/onnxruntime.cmake
查看文件 @
5a5d029
function
(
download_onnxruntime
)
include
(
FetchContent
)
if
(
CMAKE_SYSTEM_NAME STREQUAL Linux
)
if
(
CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64
)
# For embedded systems
set
(
possible_file_locations
$ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz
${
PROJECT_SOURCE_DIR
}
/onnxruntime-linux-aarch64-1.14.0.tgz
${
PROJECT_BINARY_DIR
}
/onnxruntime-linux-aarch64-1.14.0.tgz
/tmp/onnxruntime-linux-aarch64-1.14.0.tgz
/star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz
)
set
(
onnxruntime_URL
"https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz"
)
set
(
onnxruntime_HASH
"SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86"
)
else
()
# If you don't have access to the Internet,
# please pre-download onnxruntime
set
(
possible_file_locations
$ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz
${
PROJECT_SOURCE_DIR
}
/onnxruntime-linux-x64-1.14.0.tgz
${
PROJECT_BINARY_DIR
}
/onnxruntime-linux-x64-1.14.0.tgz
/tmp/onnxruntime-linux-x64-1.14.0.tgz
/star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz
)
set
(
onnxruntime_URL
"https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz"
)
set
(
onnxruntime_HASH
"SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090"
)
# After downloading, it contains:
# ./lib/libonnxruntime.so.1.14.0
# ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0
#
# ./include
# It contains all the needed header files
endif
()
message
(
STATUS
"CMAKE_SYSTEM_NAME:
${
CMAKE_SYSTEM_NAME
}
"
)
message
(
STATUS
"CMAKE_SYSTEM_PROCESSOR:
${
CMAKE_SYSTEM_PROCESSOR
}
"
)
if
(
CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64
)
# For embedded systems
set
(
possible_file_locations
$ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz
${
PROJECT_SOURCE_DIR
}
/onnxruntime-linux-aarch64-1.14.0.tgz
${
PROJECT_BINARY_DIR
}
/onnxruntime-linux-aarch64-1.14.0.tgz
/tmp/onnxruntime-linux-aarch64-1.14.0.tgz
/star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz
)
set
(
onnxruntime_URL
"https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz"
)
set
(
onnxruntime_HASH
"SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86"
)
elseif
(
CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64
)
# If you don't have access to the Internet,
# please pre-download onnxruntime
set
(
possible_file_locations
$ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz
${
PROJECT_SOURCE_DIR
}
/onnxruntime-linux-x64-1.14.0.tgz
${
PROJECT_BINARY_DIR
}
/onnxruntime-linux-x64-1.14.0.tgz
/tmp/onnxruntime-linux-x64-1.14.0.tgz
/star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz
)
set
(
onnxruntime_URL
"https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz"
)
set
(
onnxruntime_HASH
"SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090"
)
# After downloading, it contains:
# ./lib/libonnxruntime.so.1.14.0
# ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0
#
# ./include
# It contains all the needed header files
elseif
(
APPLE
)
# If you don't have access to the Internet,
# please pre-download onnxruntime
...
...
@@ -69,6 +69,8 @@ function(download_onnxruntime)
# ./include
# It contains all the needed header files
else
()
message
(
STATUS
"CMAKE_SYSTEM_NAME:
${
CMAKE_SYSTEM_NAME
}
"
)
message
(
STATUS
"CMAKE_SYSTEM_PROCESSOR:
${
CMAKE_SYSTEM_PROCESSOR
}
"
)
message
(
FATAL_ERROR
"Only support Linux, macOS, and Windows at present. Will support other OSes later"
)
endif
()
...
...
@@ -91,11 +93,15 @@ function(download_onnxruntime)
endif
()
message
(
STATUS
"onnxruntime is downloaded to
${
onnxruntime_SOURCE_DIR
}
"
)
find_library
(
location_onnxruntime onnxruntime
PATHS
"
${
onnxruntime_SOURCE_DIR
}
/lib"
NO_CMAKE_SYSTEM_PATH
)
if
(
ANDROID
)
set
(
location_onnxruntime
${
onnxruntime_SOURCE_DIR
}
/lib/libonnxruntime.so
)
else
()
find_library
(
location_onnxruntime onnxruntime
PATHS
"
${
onnxruntime_SOURCE_DIR
}
/lib"
NO_CMAKE_SYSTEM_PATH
)
endif
()
message
(
STATUS
"location_onnxruntime:
${
location_onnxruntime
}
"
)
...
...
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
5a5d029
...
...
@@ -26,6 +26,10 @@ endif()
add_library
(
sherpa-onnx-core
${
sources
}
)
if
(
ANDROID_NDK
)
target_link_libraries
(
sherpa-onnx-core android log
)
endif
()
target_link_libraries
(
sherpa-onnx-core
onnxruntime
kaldi-native-fbank-core
...
...
sherpa-onnx/csrc/online-lstm-transducer-model.cc
查看文件 @
5a5d029
...
...
@@ -12,6 +12,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
...
...
@@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
sess_opts_
.
SetIntraOpNumThreads
(
config
.
num_threads
);
sess_opts_
.
SetInterOpNumThreads
(
config
.
num_threads
);
InitEncoder
(
config
.
encoder_filename
);
InitDecoder
(
config
.
decoder_filename
);
InitJoiner
(
config
.
joiner_filename
);
{
auto
buf
=
ReadFile
(
config
.
encoder_filename
);
InitEncoder
(
buf
.
data
(),
buf
.
size
());
}
{
auto
buf
=
ReadFile
(
config
.
decoder_filename
);
InitDecoder
(
buf
.
data
(),
buf
.
size
());
}
{
auto
buf
=
ReadFile
(
config
.
joiner_filename
);
InitJoiner
(
buf
.
data
(),
buf
.
size
());
}
}
#if __ANDROID_API__ >= 9
OnlineLstmTransducerModel
::
OnlineLstmTransducerModel
(
AAssetManager
*
mgr
,
const
OnlineTransducerModelConfig
&
config
)
:
env_
(
ORT_LOGGING_LEVEL_WARNING
),
config_
(
config
),
sess_opts_
{},
allocator_
{}
{
sess_opts_
.
SetIntraOpNumThreads
(
config
.
num_threads
);
sess_opts_
.
SetInterOpNumThreads
(
config
.
num_threads
);
{
auto
buf
=
ReadFile
(
mgr
,
config
.
encoder_filename
);
InitEncoder
(
buf
.
data
(),
buf
.
size
());
}
{
auto
buf
=
ReadFile
(
mgr
,
config
.
decoder_filename
);
InitDecoder
(
buf
.
data
(),
buf
.
size
());
}
{
auto
buf
=
ReadFile
(
mgr
,
config
.
joiner_filename
);
InitJoiner
(
buf
.
data
(),
buf
.
size
());
}
}
#endif
void
OnlineLstmTransducerModel
::
InitEncoder
(
const
std
::
string
&
filename
)
{
encoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
SHERPA_MAYBE_WIDE
(
filename
).
c_str
(),
sess_opts_
);
void
OnlineLstmTransducerModel
::
InitEncoder
(
void
*
model_data
,
size_t
model_data_length
)
{
encoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
encoder_sess_
.
get
(),
&
encoder_input_names_
,
&
encoder_input_names_ptr_
);
...
...
@@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
SHERPA_ONNX_READ_META_DATA
(
d_model_
,
"d_model"
);
}
void
OnlineLstmTransducerModel
::
InitDecoder
(
const
std
::
string
&
filename
)
{
decoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
SHERPA_MAYBE_WIDE
(
filename
).
c_str
(),
sess_opts_
);
void
OnlineLstmTransducerModel
::
InitDecoder
(
void
*
model_data
,
size_t
model_data_length
)
{
decoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
decoder_sess_
.
get
(),
&
decoder_input_names_
,
&
decoder_input_names_ptr_
);
...
...
@@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
SHERPA_ONNX_READ_META_DATA
(
context_size_
,
"context_size"
);
}
void
OnlineLstmTransducerModel
::
InitJoiner
(
const
std
::
string
&
filename
)
{
joiner_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
SHERPA_MAYBE_WIDE
(
filename
).
c_str
(),
sess_opts_
);
void
OnlineLstmTransducerModel
::
InitJoiner
(
void
*
model_data
,
size_t
model_data_length
)
{
joiner_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
joiner_sess_
.
get
(),
&
joiner_input_names_
,
&
joiner_input_names_ptr_
);
...
...
sherpa-onnx/csrc/online-lstm-transducer-model.h
查看文件 @
5a5d029
...
...
@@ -9,6 +9,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
...
...
@@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
public
:
explicit
OnlineLstmTransducerModel
(
const
OnlineTransducerModelConfig
&
config
);
#if __ANDROID_API__ >= 9
OnlineLstmTransducerModel
(
AAssetManager
*
mgr
,
const
OnlineTransducerModelConfig
&
config
);
#endif
std
::
vector
<
Ort
::
Value
>
StackStates
(
const
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
&
states
)
const
override
;
...
...
@@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
OrtAllocator
*
Allocator
()
override
{
return
allocator_
;
}
private
:
void
InitEncoder
(
const
std
::
string
&
encoder_filename
);
void
InitDecoder
(
const
std
::
string
&
decoder_filename
);
void
InitJoiner
(
const
std
::
string
&
joiner_filename
);
void
InitEncoder
(
void
*
model_data
,
size_t
model_data_length
);
void
InitDecoder
(
void
*
model_data
,
size_t
model_data_length
);
void
InitJoiner
(
void
*
model_data
,
size_t
model_data_length
);
private
:
Ort
::
Env
env_
;
...
...
sherpa-onnx/csrc/online-recognizer.cc
查看文件 @
5a5d029
...
...
@@ -55,6 +55,17 @@ class OnlineRecognizer::Impl {
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
());
}
#if __ANDROID_API__ >= 9
explicit
Impl
(
AAssetManager
*
mgr
,
const
OnlineRecognizerConfig
&
config
)
:
config_
(
config
),
model_
(
OnlineTransducerModel
::
Create
(
mgr
,
config
.
model_config
)),
sym_
(
mgr
,
config
.
tokens
),
endpoint_
(
config_
.
endpoint_config
)
{
decoder_
=
std
::
make_unique
<
OnlineTransducerGreedySearchDecoder
>
(
model_
.
get
());
}
#endif
std
::
unique_ptr
<
OnlineStream
>
CreateStream
()
const
{
auto
stream
=
std
::
make_unique
<
OnlineStream
>
(
config_
.
feat_config
);
stream
->
SetResult
(
decoder_
->
GetEmptyResult
());
...
...
@@ -156,6 +167,13 @@ class OnlineRecognizer::Impl {
OnlineRecognizer
::
OnlineRecognizer
(
const
OnlineRecognizerConfig
&
config
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
))
{}
#if __ANDROID_API__ >= 9
OnlineRecognizer
::
OnlineRecognizer
(
AAssetManager
*
mgr
,
const
OnlineRecognizerConfig
&
config
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
mgr
,
config
))
{}
#endif
OnlineRecognizer
::~
OnlineRecognizer
()
=
default
;
std
::
unique_ptr
<
OnlineStream
>
OnlineRecognizer
::
CreateStream
()
const
{
...
...
sherpa-onnx/csrc/online-recognizer.h
查看文件 @
5a5d029
...
...
@@ -8,6 +8,11 @@
#include <memory>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/endpoint.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-stream.h"
...
...
@@ -45,6 +50,11 @@ struct OnlineRecognizerConfig {
class
OnlineRecognizer
{
public
:
explicit
OnlineRecognizer
(
const
OnlineRecognizerConfig
&
config
);
#if __ANDROID_API__ >= 9
OnlineRecognizer
(
AAssetManager
*
mgr
,
const
OnlineRecognizerConfig
&
config
);
#endif
~
OnlineRecognizer
();
/// Create a stream for decoding.
...
...
sherpa-onnx/csrc/online-transducer-model.cc
查看文件 @
5a5d029
...
...
@@ -3,6 +3,11 @@
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-model.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include <memory>
#include <sstream>
#include <string>
...
...
@@ -18,15 +23,16 @@ enum class ModelType {
kUnkown
,
};
static
ModelType
GetModelType
(
const
OnlineTransducerModelConfig
&
config
)
{
static
ModelType
GetModelType
(
char
*
model_data
,
size_t
model_data_length
,
bool
debug
)
{
Ort
::
Env
env
(
ORT_LOGGING_LEVEL_WARNING
);
Ort
::
SessionOptions
sess_opts
;
auto
sess
=
std
::
make_unique
<
Ort
::
Session
>
(
env
,
SHERPA_MAYBE_WIDE
(
config
.
encoder_filename
).
c_str
(),
sess_opts
);
auto
sess
=
std
::
make_unique
<
Ort
::
Session
>
(
env
,
model_data
,
model_data_length
,
sess_opts
);
Ort
::
ModelMetadata
meta_data
=
sess
->
GetModelMetadata
();
if
(
config
.
debug
)
{
if
(
debug
)
{
std
::
ostringstream
os
;
PrintModelMetadata
(
os
,
meta_data
);
fprintf
(
stderr
,
"%s
\n
"
,
os
.
str
().
c_str
());
...
...
@@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
std
::
unique_ptr
<
OnlineTransducerModel
>
OnlineTransducerModel
::
Create
(
const
OnlineTransducerModelConfig
&
config
)
{
auto
model_type
=
GetModelType
(
config
);
auto
buffer
=
ReadFile
(
config
.
encoder_filename
);
auto
model_type
=
GetModelType
(
buffer
.
data
(),
buffer
.
size
(),
config
.
debug
);
switch
(
model_type
)
{
case
ModelType
:
:
kLstm
:
...
...
@@ -67,4 +75,24 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
return
nullptr
;
}
#if __ANDROID_API__ >= 9
std
::
unique_ptr
<
OnlineTransducerModel
>
OnlineTransducerModel
::
Create
(
AAssetManager
*
mgr
,
const
OnlineTransducerModelConfig
&
config
)
{
auto
buffer
=
ReadFile
(
mgr
,
config
.
encoder_filename
);
auto
model_type
=
GetModelType
(
buffer
.
data
(),
buffer
.
size
(),
config
.
debug
);
switch
(
model_type
)
{
case
ModelType
:
:
kLstm
:
return
std
::
make_unique
<
OnlineLstmTransducerModel
>
(
mgr
,
config
);
case
ModelType
:
:
kZipformer
:
return
std
::
make_unique
<
OnlineZipformerTransducerModel
>
(
mgr
,
config
);
case
ModelType
:
:
kUnkown
:
return
nullptr
;
}
// unreachable code
return
nullptr
;
}
#endif
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/online-transducer-model.h
查看文件 @
5a5d029
...
...
@@ -8,6 +8,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
...
...
@@ -22,6 +27,11 @@ class OnlineTransducerModel {
static
std
::
unique_ptr
<
OnlineTransducerModel
>
Create
(
const
OnlineTransducerModelConfig
&
config
);
#if __ANDROID_API__ >= 9
static
std
::
unique_ptr
<
OnlineTransducerModel
>
Create
(
AAssetManager
*
mgr
,
const
OnlineTransducerModelConfig
&
config
);
#endif
/** Stack a list of individual states into a batch.
*
* It is the inverse operation of `UnStackStates`.
...
...
sherpa-onnx/csrc/online-zipformer-transducer-model.cc
查看文件 @
5a5d029
...
...
@@ -13,6 +13,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
...
...
@@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
sess_opts_
.
SetIntraOpNumThreads
(
config
.
num_threads
);
sess_opts_
.
SetInterOpNumThreads
(
config
.
num_threads
);
InitEncoder
(
config
.
encoder_filename
);
InitDecoder
(
config
.
decoder_filename
);
InitJoiner
(
config
.
joiner_filename
);
{
auto
buf
=
ReadFile
(
config
.
encoder_filename
);
InitEncoder
(
buf
.
data
(),
buf
.
size
());
}
{
auto
buf
=
ReadFile
(
config
.
decoder_filename
);
InitDecoder
(
buf
.
data
(),
buf
.
size
());
}
{
auto
buf
=
ReadFile
(
config
.
joiner_filename
);
InitJoiner
(
buf
.
data
(),
buf
.
size
());
}
}
#if __ANDROID_API__ >= 9
OnlineZipformerTransducerModel
::
OnlineZipformerTransducerModel
(
AAssetManager
*
mgr
,
const
OnlineTransducerModelConfig
&
config
)
:
env_
(
ORT_LOGGING_LEVEL_WARNING
),
config_
(
config
),
sess_opts_
{},
allocator_
{}
{
sess_opts_
.
SetIntraOpNumThreads
(
config
.
num_threads
);
sess_opts_
.
SetInterOpNumThreads
(
config
.
num_threads
);
{
auto
buf
=
ReadFile
(
mgr
,
config
.
encoder_filename
);
InitEncoder
(
buf
.
data
(),
buf
.
size
());
}
{
auto
buf
=
ReadFile
(
mgr
,
config
.
decoder_filename
);
InitDecoder
(
buf
.
data
(),
buf
.
size
());
}
{
auto
buf
=
ReadFile
(
mgr
,
config
.
joiner_filename
);
InitJoiner
(
buf
.
data
(),
buf
.
size
());
}
}
#endif
void
OnlineZipformerTransducerModel
::
InitEncoder
(
const
std
::
string
&
filename
)
{
encoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
SHERPA_MAYBE_WIDE
(
filename
).
c_str
(),
sess_opts_
);
void
OnlineZipformerTransducerModel
::
InitEncoder
(
void
*
model_data
,
size_t
model_data_length
)
{
encoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
encoder_sess_
.
get
(),
&
encoder_input_names_
,
&
encoder_input_names_ptr_
);
...
...
@@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
}
}
void
OnlineZipformerTransducerModel
::
InitDecoder
(
const
std
::
string
&
filename
)
{
decoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
SHERPA_MAYBE_WIDE
(
filename
).
c_str
(),
sess_opts_
);
void
OnlineZipformerTransducerModel
::
InitDecoder
(
void
*
model_data
,
size_t
model_data_length
)
{
decoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
decoder_sess_
.
get
(),
&
decoder_input_names_
,
&
decoder_input_names_ptr_
);
...
...
@@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
SHERPA_ONNX_READ_META_DATA
(
context_size_
,
"context_size"
);
}
void
OnlineZipformerTransducerModel
::
InitJoiner
(
const
std
::
string
&
filename
)
{
joiner_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
SHERPA_MAYBE_WIDE
(
filename
).
c_str
(),
sess_opts_
);
void
OnlineZipformerTransducerModel
::
InitJoiner
(
void
*
model_data
,
size_t
model_data_length
)
{
joiner_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
joiner_sess_
.
get
(),
&
joiner_input_names_
,
&
joiner_input_names_ptr_
);
...
...
sherpa-onnx/csrc/online-zipformer-transducer-model.h
查看文件 @
5a5d029
...
...
@@ -9,6 +9,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
...
...
@@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
explicit
OnlineZipformerTransducerModel
(
const
OnlineTransducerModelConfig
&
config
);
#if __ANDROID_API__ >= 9
OnlineZipformerTransducerModel
(
AAssetManager
*
mgr
,
const
OnlineTransducerModelConfig
&
config
);
#endif
std
::
vector
<
Ort
::
Value
>
StackStates
(
const
std
::
vector
<
std
::
vector
<
Ort
::
Value
>>
&
states
)
const
override
;
...
...
@@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
OrtAllocator
*
Allocator
()
override
{
return
allocator_
;
}
private
:
void
InitEncoder
(
const
std
::
string
&
encoder_filename
);
void
InitDecoder
(
const
std
::
string
&
decoder_filename
);
void
InitJoiner
(
const
std
::
string
&
joiner_filename
);
void
InitEncoder
(
void
*
model_data
,
size_t
model_data_length
);
void
InitDecoder
(
void
*
model_data
,
size_t
model_data_length
);
void
InitJoiner
(
void
*
model_data
,
size_t
model_data_length
);
private
:
Ort
::
Env
env_
;
...
...
sherpa-onnx/csrc/onnx-utils.cc
查看文件 @
5a5d029
...
...
@@ -3,9 +3,16 @@
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/onnx-utils.h"
#include <fstream>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#include "android/log.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
namespace
sherpa_onnx
{
...
...
@@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) {
fprintf
(
stderr
,
"
\n
"
);
}
std
::
vector
<
char
>
ReadFile
(
const
std
::
string
&
filename
)
{
std
::
ifstream
input
(
filename
,
std
::
ios
::
binary
);
std
::
vector
<
char
>
buffer
(
std
::
istreambuf_iterator
<
char
>
(
input
),
{});
return
buffer
;
}
#if __ANDROID_API__ >= 9
std
::
vector
<
char
>
ReadFile
(
AAssetManager
*
mgr
,
const
std
::
string
&
filename
)
{
AAsset
*
asset
=
AAssetManager_open
(
mgr
,
filename
.
c_str
(),
AASSET_MODE_BUFFER
);
if
(
!
asset
)
{
__android_log_print
(
ANDROID_LOG_FATAL
,
"sherpa-onnx"
,
"Read binary file: Load %s failed"
,
filename
.
c_str
());
exit
(
-
1
);
}
auto
p
=
reinterpret_cast
<
const
char
*>
(
AAsset_getBuffer
(
asset
));
size_t
asset_length
=
AAsset_getLength
(
asset
);
AAsset_close
(
asset
);
std
::
vector
<
char
>
buffer
(
p
,
p
+
asset_length
);
return
buffer
;
}
#endif
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/onnx-utils.h
查看文件 @
5a5d029
...
...
@@ -14,6 +14,11 @@
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
namespace
sherpa_onnx
{
...
...
@@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) {
std
::
fill
(
p
,
p
+
n
,
value
);
}
std
::
vector
<
char
>
ReadFile
(
const
std
::
string
&
filename
);
#if __ANDROID_API__ >= 9
std
::
vector
<
char
>
ReadFile
(
AAssetManager
*
mgr
,
const
std
::
string
&
filename
);
#endif
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
...
...
sherpa-onnx/csrc/symbol-table.cc
查看文件 @
5a5d029
...
...
@@ -7,11 +7,32 @@
#include <cassert>
#include <fstream>
#include <sstream>
#include <strstream>
#include "sherpa-onnx/csrc/onnx-utils.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
namespace
sherpa_onnx
{
SymbolTable
::
SymbolTable
(
const
std
::
string
&
filename
)
{
std
::
ifstream
is
(
filename
);
Init
(
is
);
}
#if __ANDROID_API__ >= 9
SymbolTable
::
SymbolTable
(
AAssetManager
*
mgr
,
const
std
::
string
&
filename
)
{
auto
buf
=
ReadFile
(
mgr
,
filename
);
std
::
istrstream
is
(
buf
.
data
(),
buf
.
size
());
Init
(
is
);
}
#endif
void
SymbolTable
::
Init
(
std
::
istream
&
is
)
{
std
::
string
sym
;
int32_t
id
;
while
(
is
>>
sym
>>
id
)
{
...
...
sherpa-onnx/csrc/symbol-table.h
查看文件 @
5a5d029
...
...
@@ -8,6 +8,11 @@
#include <string>
#include <unordered_map>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
namespace
sherpa_onnx
{
/// It manages mapping between symbols and integer IDs.
...
...
@@ -22,6 +27,10 @@ class SymbolTable {
/// Fields are separated by space(s).
explicit
SymbolTable
(
const
std
::
string
&
filename
);
#if __ANDROID_API__ >= 9
SymbolTable
(
AAssetManager
*
mgr
,
const
std
::
string
&
filename
);
#endif
/// Return a string representation of this symbol table
std
::
string
ToString
()
const
;
...
...
@@ -37,6 +46,9 @@ class SymbolTable {
bool
contains
(
const
std
::
string
&
sym
)
const
;
private
:
void
Init
(
std
::
istream
&
is
);
private
:
std
::
unordered_map
<
std
::
string
,
int32_t
>
sym2id_
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id2sym_
;
};
...
...
sherpa-onnx/jni/jni.cc
查看文件 @
5a5d029
...
...
@@ -20,7 +20,7 @@
#endif
#if __ANDROID_API__ >= 8
#include
<android/log.h>
#include
"android/log.h"
#define SHERPA_ONNX_LOGE(...) \
do { \
fprintf(stderr, ##__VA_ARGS__); \
...
...
请
注册
或
登录
后发表评论