Fangjun Kuang
Committed by GitHub

Add build script for Android armv8a (#58)

  1 +#!/usr/bin/env bash
  2 +set -ex
  3 +
  4 +dir=build-android-arm64-v8a
  5 +
  6 +mkdir -p $dir
  7 +cd $dir
  8 +
  9 +# Note from https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android
  10 +# (optional) remove the hardcoded debug flag in Android NDK android-ndk
  11 +# issue: https://github.com/android/ndk/issues/243
  12 +#
  13 +# open $ANDROID_NDK/build/cmake/android.toolchain.cmake for ndk < r23
  14 +# or $ANDROID_NDK/build/cmake/android-legacy.toolchain.cmake for ndk >= r23
  15 +#
  16 +# delete "-g" line
  17 +#
  18 +# list(APPEND ANDROID_COMPILER_FLAGS
  19 +# -g
  20 +# -DANDROID
  21 +
  22 +
  23 +if [ -z $ANDROID_NDK ]; then
  24 + ANDROID_NDK=/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669
  25 + # or use
  26 + # ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk
  27 + #
  28 + # Inside the $ANDROID_NDK directory, you can find a binary ndk-build
  29 + # and some other files like the file "build/cmake/android.toolchain.cmake"
  30 +
  31 + if [ ! -d $ANDROID_NDK ]; then
  32 + # For macOS, I have installed Android Studio, select the menu
  33 + # Tools -> SDK manager -> Android SDK
  34 + # and set "Android SDK location" to /Users/fangjun/software/my-android
  35 + ANDROID_NDK=/Users/fangjun/software/my-android/ndk/22.1.7171670
  36 + fi
  37 +fi
  38 +
  39 +if [ ! -d $ANDROID_NDK ]; then
  40 + echo Please set the environment variable ANDROID_NDK before you run this script
  41 + exit 1
  42 +fi
  43 +
  44 +echo "ANDROID_NDK: $ANDROID_NDK"
  45 +sleep 1
  46 +
  47 +cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
  48 + -DCMAKE_BUILD_TYPE=Release \
  49 + -DBUILD_SHARED_LIBS=ON \
  50 + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
  51 + -DSHERPA_ONNX_ENABLE_TESTS=OFF \
  52 + -DSHERPA_ONNX_ENABLE_CHECK=OFF \
  53 + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
  54 + -DSHERPA_ONNX_ENABLE_JNI=ON \
  55 + -DCMAKE_INSTALL_PREFIX=./install \
  56 + -DANDROID_ABI="arm64-v8a" \
  57 + -DANDROID_PLATFORM=android-21 ..
  58 +# make VERBOSE=1 -j4
  59 +make -j4
  60 +make install/strip
  61 +
1 function(download_onnxruntime) 1 function(download_onnxruntime)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - if(CMAKE_SYSTEM_NAME STREQUAL Linux)  
5 - if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)  
6 - # For embedded systems  
7 - set(possible_file_locations  
8 - $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz  
9 - ${PROJECT_SOURCE_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz  
10 - ${PROJECT_BINARY_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz  
11 - /tmp/onnxruntime-linux-aarch64-1.14.0.tgz  
12 - /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz  
13 - )  
14 - set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz")  
15 - set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86")  
16 -  
17 - else()  
18 - # If you don't have access to the Internet,  
19 - # please pre-download onnxruntime  
20 - set(possible_file_locations  
21 - $ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz  
22 - ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz  
23 - ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz  
24 - /tmp/onnxruntime-linux-x64-1.14.0.tgz  
25 - /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz  
26 - )  
27 -  
28 - set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz")  
29 - set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090")  
30 - # After downloading, it contains:  
31 - # ./lib/libonnxruntime.so.1.14.0  
32 - # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0  
33 - #  
34 - # ./include  
35 - # It contains all the needed header files  
36 - endif() 4 + message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
  5 + message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
  6 +
  7 + if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
  8 + # For embedded systems
  9 + set(possible_file_locations
  10 + $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz
  11 + ${PROJECT_SOURCE_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz
  12 + ${PROJECT_BINARY_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz
  13 + /tmp/onnxruntime-linux-aarch64-1.14.0.tgz
  14 + /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz
  15 + )
  16 + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz")
  17 + set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86")
  18 + elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64)
  19 + # If you don't have access to the Internet,
  20 + # please pre-download onnxruntime
  21 + set(possible_file_locations
  22 + $ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz
  23 + ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz
  24 + ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz
  25 + /tmp/onnxruntime-linux-x64-1.14.0.tgz
  26 + /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz
  27 + )
  28 +
  29 + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz")
  30 + set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090")
  31 + # After downloading, it contains:
  32 + # ./lib/libonnxruntime.so.1.14.0
  33 + # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0
  34 + #
  35 + # ./include
  36 + # It contains all the needed header files
37 elseif(APPLE) 37 elseif(APPLE)
38 # If you don't have access to the Internet, 38 # If you don't have access to the Internet,
39 # please pre-download onnxruntime 39 # please pre-download onnxruntime
@@ -69,6 +69,8 @@ function(download_onnxruntime) @@ -69,6 +69,8 @@ function(download_onnxruntime)
69 # ./include 69 # ./include
70 # It contains all the needed header files 70 # It contains all the needed header files
71 else() 71 else()
  72 + message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
  73 + message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
72 message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later") 74 message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later")
73 endif() 75 endif()
74 76
@@ -91,11 +93,15 @@ function(download_onnxruntime) @@ -91,11 +93,15 @@ function(download_onnxruntime)
91 endif() 93 endif()
92 message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") 94 message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
93 95
94 - find_library(location_onnxruntime onnxruntime  
95 - PATHS  
96 - "${onnxruntime_SOURCE_DIR}/lib"  
97 - NO_CMAKE_SYSTEM_PATH  
98 - ) 96 + if(ANDROID)
  97 + set(location_onnxruntime ${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so)
  98 + else()
  99 + find_library(location_onnxruntime onnxruntime
  100 + PATHS
  101 + "${onnxruntime_SOURCE_DIR}/lib"
  102 + NO_CMAKE_SYSTEM_PATH
  103 + )
  104 + endif()
99 105
100 message(STATUS "location_onnxruntime: ${location_onnxruntime}") 106 message(STATUS "location_onnxruntime: ${location_onnxruntime}")
101 107
@@ -26,6 +26,10 @@ endif() @@ -26,6 +26,10 @@ endif()
26 26
27 add_library(sherpa-onnx-core ${sources}) 27 add_library(sherpa-onnx-core ${sources})
28 28
  29 +if(ANDROID_NDK)
  30 + target_link_libraries(sherpa-onnx-core android log)
  31 +endif()
  32 +
29 target_link_libraries(sherpa-onnx-core 33 target_link_libraries(sherpa-onnx-core
30 onnxruntime 34 onnxruntime
31 kaldi-native-fbank-core 35 kaldi-native-fbank-core
@@ -12,6 +12,11 @@ @@ -12,6 +12,11 @@
12 #include <utility> 12 #include <utility>
13 #include <vector> 13 #include <vector>
14 14
  15 +#if __ANDROID_API__ >= 9
  16 +#include "android/asset_manager.h"
  17 +#include "android/asset_manager_jni.h"
  18 +#endif
  19 +
15 #include "onnxruntime_cxx_api.h" // NOLINT 20 #include "onnxruntime_cxx_api.h" // NOLINT
16 #include "sherpa-onnx/csrc/cat.h" 21 #include "sherpa-onnx/csrc/cat.h"
17 #include "sherpa-onnx/csrc/macros.h" 22 #include "sherpa-onnx/csrc/macros.h"
@@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( @@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
30 sess_opts_.SetIntraOpNumThreads(config.num_threads); 35 sess_opts_.SetIntraOpNumThreads(config.num_threads);
31 sess_opts_.SetInterOpNumThreads(config.num_threads); 36 sess_opts_.SetInterOpNumThreads(config.num_threads);
32 37
33 - InitEncoder(config.encoder_filename);  
34 - InitDecoder(config.decoder_filename);  
35 - InitJoiner(config.joiner_filename); 38 + {
  39 + auto buf = ReadFile(config.encoder_filename);
  40 + InitEncoder(buf.data(), buf.size());
  41 + }
  42 +
  43 + {
  44 + auto buf = ReadFile(config.decoder_filename);
  45 + InitDecoder(buf.data(), buf.size());
  46 + }
  47 +
  48 + {
  49 + auto buf = ReadFile(config.joiner_filename);
  50 + InitJoiner(buf.data(), buf.size());
  51 + }
  52 +}
  53 +
  54 +#if __ANDROID_API__ >= 9
  55 +OnlineLstmTransducerModel::OnlineLstmTransducerModel(
  56 + AAssetManager *mgr, const OnlineTransducerModelConfig &config)
  57 + : env_(ORT_LOGGING_LEVEL_WARNING),
  58 + config_(config),
  59 + sess_opts_{},
  60 + allocator_{} {
  61 + sess_opts_.SetIntraOpNumThreads(config.num_threads);
  62 + sess_opts_.SetInterOpNumThreads(config.num_threads);
  63 +
  64 + {
  65 + auto buf = ReadFile(mgr, config.encoder_filename);
  66 + InitEncoder(buf.data(), buf.size());
  67 + }
  68 +
  69 + {
  70 + auto buf = ReadFile(mgr, config.decoder_filename);
  71 + InitDecoder(buf.data(), buf.size());
  72 + }
  73 +
  74 + {
  75 + auto buf = ReadFile(mgr, config.joiner_filename);
  76 + InitJoiner(buf.data(), buf.size());
  77 + }
36 } 78 }
  79 +#endif
37 80
38 -void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {  
39 - encoder_sess_ = std::make_unique<Ort::Session>(  
40 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); 81 +void OnlineLstmTransducerModel::InitEncoder(void *model_data,
  82 + size_t model_data_length) {
  83 + encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
  84 + model_data_length, sess_opts_);
41 85
42 GetInputNames(encoder_sess_.get(), &encoder_input_names_, 86 GetInputNames(encoder_sess_.get(), &encoder_input_names_,
43 &encoder_input_names_ptr_); 87 &encoder_input_names_ptr_);
@@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) { @@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) {
62 SHERPA_ONNX_READ_META_DATA(d_model_, "d_model"); 106 SHERPA_ONNX_READ_META_DATA(d_model_, "d_model");
63 } 107 }
64 108
65 -void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {  
66 - decoder_sess_ = std::make_unique<Ort::Session>(  
67 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); 109 +void OnlineLstmTransducerModel::InitDecoder(void *model_data,
  110 + size_t model_data_length) {
  111 + decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
  112 + model_data_length, sess_opts_);
68 113
69 GetInputNames(decoder_sess_.get(), &decoder_input_names_, 114 GetInputNames(decoder_sess_.get(), &decoder_input_names_,
70 &decoder_input_names_ptr_); 115 &decoder_input_names_ptr_);
@@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) { @@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) {
86 SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); 131 SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
87 } 132 }
88 133
89 -void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {  
90 - joiner_sess_ = std::make_unique<Ort::Session>(  
91 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); 134 +void OnlineLstmTransducerModel::InitJoiner(void *model_data,
  135 + size_t model_data_length) {
  136 + joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
  137 + model_data_length, sess_opts_);
92 138
93 GetInputNames(joiner_sess_.get(), &joiner_input_names_, 139 GetInputNames(joiner_sess_.get(), &joiner_input_names_,
94 &joiner_input_names_ptr_); 140 &joiner_input_names_ptr_);
@@ -9,6 +9,11 @@ @@ -9,6 +9,11 @@
9 #include <utility> 9 #include <utility>
10 #include <vector> 10 #include <vector>
11 11
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
12 #include "onnxruntime_cxx_api.h" // NOLINT 17 #include "onnxruntime_cxx_api.h" // NOLINT
13 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 18 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
14 #include "sherpa-onnx/csrc/online-transducer-model.h" 19 #include "sherpa-onnx/csrc/online-transducer-model.h"
@@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { @@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
19 public: 24 public:
20 explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config); 25 explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
21 26
  27 +#if __ANDROID_API__ >= 9
  28 + OnlineLstmTransducerModel(AAssetManager *mgr,
  29 + const OnlineTransducerModelConfig &config);
  30 +#endif
  31 +
22 std::vector<Ort::Value> StackStates( 32 std::vector<Ort::Value> StackStates(
23 const std::vector<std::vector<Ort::Value>> &states) const override; 33 const std::vector<std::vector<Ort::Value>> &states) const override;
24 34
@@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { @@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
47 OrtAllocator *Allocator() override { return allocator_; } 57 OrtAllocator *Allocator() override { return allocator_; }
48 58
49 private: 59 private:
50 - void InitEncoder(const std::string &encoder_filename);  
51 - void InitDecoder(const std::string &decoder_filename);  
52 - void InitJoiner(const std::string &joiner_filename); 60 + void InitEncoder(void *model_data, size_t model_data_length);
  61 + void InitDecoder(void *model_data, size_t model_data_length);
  62 + void InitJoiner(void *model_data, size_t model_data_length);
53 63
54 private: 64 private:
55 Ort::Env env_; 65 Ort::Env env_;
@@ -55,6 +55,17 @@ class OnlineRecognizer::Impl { @@ -55,6 +55,17 @@ class OnlineRecognizer::Impl {
55 std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); 55 std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
56 } 56 }
57 57
  58 +#if __ANDROID_API__ >= 9
  59 + explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config)
  60 + : config_(config),
  61 + model_(OnlineTransducerModel::Create(mgr, config.model_config)),
  62 + sym_(mgr, config.tokens),
  63 + endpoint_(config_.endpoint_config) {
  64 + decoder_ =
  65 + std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
  66 + }
  67 +#endif
  68 +
58 std::unique_ptr<OnlineStream> CreateStream() const { 69 std::unique_ptr<OnlineStream> CreateStream() const {
59 auto stream = std::make_unique<OnlineStream>(config_.feat_config); 70 auto stream = std::make_unique<OnlineStream>(config_.feat_config);
60 stream->SetResult(decoder_->GetEmptyResult()); 71 stream->SetResult(decoder_->GetEmptyResult());
@@ -156,6 +167,13 @@ class OnlineRecognizer::Impl { @@ -156,6 +167,13 @@ class OnlineRecognizer::Impl {
156 167
157 OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) 168 OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
158 : impl_(std::make_unique<Impl>(config)) {} 169 : impl_(std::make_unique<Impl>(config)) {}
  170 +
  171 +#if __ANDROID_API__ >= 9
  172 +OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr,
  173 + const OnlineRecognizerConfig &config)
  174 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  175 +#endif
  176 +
159 OnlineRecognizer::~OnlineRecognizer() = default; 177 OnlineRecognizer::~OnlineRecognizer() = default;
160 178
161 std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const { 179 std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
@@ -8,6 +8,11 @@ @@ -8,6 +8,11 @@
8 #include <memory> 8 #include <memory>
9 #include <string> 9 #include <string>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
11 #include "sherpa-onnx/csrc/endpoint.h" 16 #include "sherpa-onnx/csrc/endpoint.h"
12 #include "sherpa-onnx/csrc/features.h" 17 #include "sherpa-onnx/csrc/features.h"
13 #include "sherpa-onnx/csrc/online-stream.h" 18 #include "sherpa-onnx/csrc/online-stream.h"
@@ -45,6 +50,11 @@ struct OnlineRecognizerConfig { @@ -45,6 +50,11 @@ struct OnlineRecognizerConfig {
45 class OnlineRecognizer { 50 class OnlineRecognizer {
46 public: 51 public:
47 explicit OnlineRecognizer(const OnlineRecognizerConfig &config); 52 explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
  53 +
  54 +#if __ANDROID_API__ >= 9
  55 + OnlineRecognizer(AAssetManager *mgr, const OnlineRecognizerConfig &config);
  56 +#endif
  57 +
48 ~OnlineRecognizer(); 58 ~OnlineRecognizer();
49 59
50 /// Create a stream for decoding. 60 /// Create a stream for decoding.
@@ -3,6 +3,11 @@ @@ -3,6 +3,11 @@
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #include "sherpa-onnx/csrc/online-transducer-model.h" 4 #include "sherpa-onnx/csrc/online-transducer-model.h"
5 5
  6 +#if __ANDROID_API__ >= 9
  7 +#include "android/asset_manager.h"
  8 +#include "android/asset_manager_jni.h"
  9 +#endif
  10 +
6 #include <memory> 11 #include <memory>
7 #include <sstream> 12 #include <sstream>
8 #include <string> 13 #include <string>
@@ -18,15 +23,16 @@ enum class ModelType { @@ -18,15 +23,16 @@ enum class ModelType {
18 kUnkown, 23 kUnkown,
19 }; 24 };
20 25
21 -static ModelType GetModelType(const OnlineTransducerModelConfig &config) { 26 +static ModelType GetModelType(char *model_data, size_t model_data_length,
  27 + bool debug) {
22 Ort::Env env(ORT_LOGGING_LEVEL_WARNING); 28 Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
23 Ort::SessionOptions sess_opts; 29 Ort::SessionOptions sess_opts;
24 30
25 - auto sess = std::make_unique<Ort::Session>(  
26 - env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts); 31 + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
  32 + sess_opts);
27 33
28 Ort::ModelMetadata meta_data = sess->GetModelMetadata(); 34 Ort::ModelMetadata meta_data = sess->GetModelMetadata();
29 - if (config.debug) { 35 + if (debug) {
30 std::ostringstream os; 36 std::ostringstream os;
31 PrintModelMetadata(os, meta_data); 37 PrintModelMetadata(os, meta_data);
32 fprintf(stderr, "%s\n", os.str().c_str()); 38 fprintf(stderr, "%s\n", os.str().c_str());
@@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) { @@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) {
52 58
53 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( 59 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
54 const OnlineTransducerModelConfig &config) { 60 const OnlineTransducerModelConfig &config) {
55 - auto model_type = GetModelType(config); 61 + auto buffer = ReadFile(config.encoder_filename);
  62 +
  63 + auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
56 64
57 switch (model_type) { 65 switch (model_type) {
58 case ModelType::kLstm: 66 case ModelType::kLstm:
@@ -67,4 +75,24 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -67,4 +75,24 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
67 return nullptr; 75 return nullptr;
68 } 76 }
69 77
  78 +#if __ANDROID_API__ >= 9
  79 +std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
  80 + AAssetManager *mgr, const OnlineTransducerModelConfig &config) {
  81 + auto buffer = ReadFile(mgr, config.encoder_filename);
  82 + auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  83 +
  84 + switch (model_type) {
  85 + case ModelType::kLstm:
  86 + return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
  87 + case ModelType::kZipformer:
  88 + return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
  89 + case ModelType::kUnkown:
  90 + return nullptr;
  91 + }
  92 +
  93 + // unreachable code
  94 + return nullptr;
  95 +}
  96 +#endif
  97 +
70 } // namespace sherpa_onnx 98 } // namespace sherpa_onnx
@@ -8,6 +8,11 @@ @@ -8,6 +8,11 @@
8 #include <utility> 8 #include <utility>
9 #include <vector> 9 #include <vector>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
11 #include "onnxruntime_cxx_api.h" // NOLINT 16 #include "onnxruntime_cxx_api.h" // NOLINT
12 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 17 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
13 18
@@ -22,6 +27,11 @@ class OnlineTransducerModel { @@ -22,6 +27,11 @@ class OnlineTransducerModel {
22 static std::unique_ptr<OnlineTransducerModel> Create( 27 static std::unique_ptr<OnlineTransducerModel> Create(
23 const OnlineTransducerModelConfig &config); 28 const OnlineTransducerModelConfig &config);
24 29
  30 +#if __ANDROID_API__ >= 9
  31 + static std::unique_ptr<OnlineTransducerModel> Create(
  32 + AAssetManager *mgr, const OnlineTransducerModelConfig &config);
  33 +#endif
  34 +
25 /** Stack a list of individual states into a batch. 35 /** Stack a list of individual states into a batch.
26 * 36 *
27 * It is the inverse operation of `UnStackStates`. 37 * It is the inverse operation of `UnStackStates`.
@@ -13,6 +13,11 @@ @@ -13,6 +13,11 @@
13 #include <utility> 13 #include <utility>
14 #include <vector> 14 #include <vector>
15 15
  16 +#if __ANDROID_API__ >= 9
  17 +#include "android/asset_manager.h"
  18 +#include "android/asset_manager_jni.h"
  19 +#endif
  20 +
16 #include "onnxruntime_cxx_api.h" // NOLINT 21 #include "onnxruntime_cxx_api.h" // NOLINT
17 #include "sherpa-onnx/csrc/cat.h" 22 #include "sherpa-onnx/csrc/cat.h"
18 #include "sherpa-onnx/csrc/macros.h" 23 #include "sherpa-onnx/csrc/macros.h"
@@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( @@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
32 sess_opts_.SetIntraOpNumThreads(config.num_threads); 37 sess_opts_.SetIntraOpNumThreads(config.num_threads);
33 sess_opts_.SetInterOpNumThreads(config.num_threads); 38 sess_opts_.SetInterOpNumThreads(config.num_threads);
34 39
35 - InitEncoder(config.encoder_filename);  
36 - InitDecoder(config.decoder_filename);  
37 - InitJoiner(config.joiner_filename); 40 + {
  41 + auto buf = ReadFile(config.encoder_filename);
  42 + InitEncoder(buf.data(), buf.size());
  43 + }
  44 +
  45 + {
  46 + auto buf = ReadFile(config.decoder_filename);
  47 + InitDecoder(buf.data(), buf.size());
  48 + }
  49 +
  50 + {
  51 + auto buf = ReadFile(config.joiner_filename);
  52 + InitJoiner(buf.data(), buf.size());
  53 + }
  54 +}
  55 +
  56 +#if __ANDROID_API__ >= 9
  57 +OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
  58 + AAssetManager *mgr, const OnlineTransducerModelConfig &config)
  59 + : env_(ORT_LOGGING_LEVEL_WARNING),
  60 + config_(config),
  61 + sess_opts_{},
  62 + allocator_{} {
  63 + sess_opts_.SetIntraOpNumThreads(config.num_threads);
  64 + sess_opts_.SetInterOpNumThreads(config.num_threads);
  65 +
  66 + {
  67 + auto buf = ReadFile(mgr, config.encoder_filename);
  68 + InitEncoder(buf.data(), buf.size());
  69 + }
  70 +
  71 + {
  72 + auto buf = ReadFile(mgr, config.decoder_filename);
  73 + InitDecoder(buf.data(), buf.size());
  74 + }
  75 +
  76 + {
  77 + auto buf = ReadFile(mgr, config.joiner_filename);
  78 + InitJoiner(buf.data(), buf.size());
  79 + }
38 } 80 }
  81 +#endif
39 82
40 -void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {  
41 - encoder_sess_ = std::make_unique<Ort::Session>(  
42 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); 83 +void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
  84 + size_t model_data_length) {
  85 + encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
  86 + model_data_length, sess_opts_);
43 87
44 GetInputNames(encoder_sess_.get(), &encoder_input_names_, 88 GetInputNames(encoder_sess_.get(), &encoder_input_names_,
45 &encoder_input_names_ptr_); 89 &encoder_input_names_ptr_);
@@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) { @@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) {
84 } 128 }
85 } 129 }
86 130
87 -void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {  
88 - decoder_sess_ = std::make_unique<Ort::Session>(  
89 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); 131 +void OnlineZipformerTransducerModel::InitDecoder(void *model_data,
  132 + size_t model_data_length) {
  133 + decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
  134 + model_data_length, sess_opts_);
90 135
91 GetInputNames(decoder_sess_.get(), &decoder_input_names_, 136 GetInputNames(decoder_sess_.get(), &decoder_input_names_,
92 &decoder_input_names_ptr_); 137 &decoder_input_names_ptr_);
@@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) { @@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) {
108 SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); 153 SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
109 } 154 }
110 155
111 -void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) {  
112 - joiner_sess_ = std::make_unique<Ort::Session>(  
113 - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); 156 +void OnlineZipformerTransducerModel::InitJoiner(void *model_data,
  157 + size_t model_data_length) {
  158 + joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
  159 + model_data_length, sess_opts_);
114 160
115 GetInputNames(joiner_sess_.get(), &joiner_input_names_, 161 GetInputNames(joiner_sess_.get(), &joiner_input_names_,
116 &joiner_input_names_ptr_); 162 &joiner_input_names_ptr_);
@@ -9,6 +9,11 @@ @@ -9,6 +9,11 @@
9 #include <utility> 9 #include <utility>
10 #include <vector> 10 #include <vector>
11 11
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
12 #include "onnxruntime_cxx_api.h" // NOLINT 17 #include "onnxruntime_cxx_api.h" // NOLINT
13 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 18 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
14 #include "sherpa-onnx/csrc/online-transducer-model.h" 19 #include "sherpa-onnx/csrc/online-transducer-model.h"
@@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { @@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
20 explicit OnlineZipformerTransducerModel( 25 explicit OnlineZipformerTransducerModel(
21 const OnlineTransducerModelConfig &config); 26 const OnlineTransducerModelConfig &config);
22 27
  28 +#if __ANDROID_API__ >= 9
  29 + OnlineZipformerTransducerModel(AAssetManager *mgr,
  30 + const OnlineTransducerModelConfig &config);
  31 +#endif
  32 +
23 std::vector<Ort::Value> StackStates( 33 std::vector<Ort::Value> StackStates(
24 const std::vector<std::vector<Ort::Value>> &states) const override; 34 const std::vector<std::vector<Ort::Value>> &states) const override;
25 35
@@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { @@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
48 OrtAllocator *Allocator() override { return allocator_; } 58 OrtAllocator *Allocator() override { return allocator_; }
49 59
50 private: 60 private:
51 - void InitEncoder(const std::string &encoder_filename);  
52 - void InitDecoder(const std::string &decoder_filename);  
53 - void InitJoiner(const std::string &joiner_filename); 61 + void InitEncoder(void *model_data, size_t model_data_length);
  62 + void InitDecoder(void *model_data, size_t model_data_length);
  63 + void InitJoiner(void *model_data, size_t model_data_length);
54 64
55 private: 65 private:
56 Ort::Env env_; 66 Ort::Env env_;
@@ -3,9 +3,16 @@ @@ -3,9 +3,16 @@
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 #include "sherpa-onnx/csrc/onnx-utils.h" 4 #include "sherpa-onnx/csrc/onnx-utils.h"
5 5
  6 +#include <fstream>
6 #include <string> 7 #include <string>
7 #include <vector> 8 #include <vector>
8 9
  10 +#if __ANDROID_API__ >= 9
  11 +#include "android/asset_manager.h"
  12 +#include "android/asset_manager_jni.h"
  13 +#include "android/log.h"
  14 +#endif
  15 +
9 #include "onnxruntime_cxx_api.h" // NOLINT 16 #include "onnxruntime_cxx_api.h" // NOLINT
10 17
11 namespace sherpa_onnx { 18 namespace sherpa_onnx {
@@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) { @@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) {
116 fprintf(stderr, "\n"); 123 fprintf(stderr, "\n");
117 } 124 }
118 125
  126 +std::vector<char> ReadFile(const std::string &filename) {
  127 + std::ifstream input(filename, std::ios::binary);
  128 + std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
  129 + return buffer;
  130 +}
  131 +
  132 +#if __ANDROID_API__ >= 9
  133 +std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
  134 + AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER);
  135 + if (!asset) {
  136 + __android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx",
  137 + "Read binary file: Load %s failed", filename.c_str());
  138 + exit(-1);
  139 + }
  140 +
  141 + auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
  142 + size_t asset_length = AAsset_getLength(asset);
  143 +
  144 + AAsset_close(asset);
  145 +
  146 + std::vector<char> buffer(p, p + asset_length);
  147 +
  148 + return buffer;
  149 +}
  150 +#endif
  151 +
119 } // namespace sherpa_onnx 152 } // namespace sherpa_onnx
@@ -14,6 +14,11 @@ @@ -14,6 +14,11 @@
14 #include <string> 14 #include <string>
15 #include <vector> 15 #include <vector>
16 16
  17 +#if __ANDROID_API__ >= 9
  18 +#include "android/asset_manager.h"
  19 +#include "android/asset_manager_jni.h"
  20 +#endif
  21 +
17 #include "onnxruntime_cxx_api.h" // NOLINT 22 #include "onnxruntime_cxx_api.h" // NOLINT
18 23
19 namespace sherpa_onnx { 24 namespace sherpa_onnx {
@@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) { @@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) {
74 std::fill(p, p + n, value); 79 std::fill(p, p + n, value);
75 } 80 }
76 81
  82 +std::vector<char> ReadFile(const std::string &filename);
  83 +
  84 +#if __ANDROID_API__ >= 9
  85 +std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
  86 +#endif
  87 +
77 } // namespace sherpa_onnx 88 } // namespace sherpa_onnx
78 89
79 #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ 90 #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
@@ -7,11 +7,32 @@ @@ -7,11 +7,32 @@
7 #include <cassert> 7 #include <cassert>
8 #include <fstream> 8 #include <fstream>
9 #include <sstream> 9 #include <sstream>
  10 +#include <strstream>
  11 +
  12 +#include "sherpa-onnx/csrc/onnx-utils.h"
  13 +
  14 +#if __ANDROID_API__ >= 9
  15 +#include "android/asset_manager.h"
  16 +#include "android/asset_manager_jni.h"
  17 +#endif
10 18
11 namespace sherpa_onnx { 19 namespace sherpa_onnx {
12 20
13 SymbolTable::SymbolTable(const std::string &filename) { 21 SymbolTable::SymbolTable(const std::string &filename) {
14 std::ifstream is(filename); 22 std::ifstream is(filename);
  23 + Init(is);
  24 +}
  25 +
  26 +#if __ANDROID_API__ >= 9
  27 +SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
  28 + auto buf = ReadFile(mgr, filename);
  29 +
  30 + std::istrstream is(buf.data(), buf.size());
  31 + Init(is);
  32 +}
  33 +#endif
  34 +
  35 +void SymbolTable::Init(std::istream &is) {
15 std::string sym; 36 std::string sym;
16 int32_t id; 37 int32_t id;
17 while (is >> sym >> id) { 38 while (is >> sym >> id) {
@@ -8,6 +8,11 @@ @@ -8,6 +8,11 @@
8 #include <string> 8 #include <string>
9 #include <unordered_map> 9 #include <unordered_map>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
11 namespace sherpa_onnx { 16 namespace sherpa_onnx {
12 17
13 /// It manages mapping between symbols and integer IDs. 18 /// It manages mapping between symbols and integer IDs.
@@ -22,6 +27,10 @@ class SymbolTable { @@ -22,6 +27,10 @@ class SymbolTable {
22 /// Fields are separated by space(s). 27 /// Fields are separated by space(s).
23 explicit SymbolTable(const std::string &filename); 28 explicit SymbolTable(const std::string &filename);
24 29
  30 +#if __ANDROID_API__ >= 9
  31 + SymbolTable(AAssetManager *mgr, const std::string &filename);
  32 +#endif
  33 +
25 /// Return a string representation of this symbol table 34 /// Return a string representation of this symbol table
26 std::string ToString() const; 35 std::string ToString() const;
27 36
@@ -37,6 +46,9 @@ class SymbolTable { @@ -37,6 +46,9 @@ class SymbolTable {
37 bool contains(const std::string &sym) const; 46 bool contains(const std::string &sym) const;
38 47
39 private: 48 private:
  49 + void Init(std::istream &is);
  50 +
  51 + private:
40 std::unordered_map<std::string, int32_t> sym2id_; 52 std::unordered_map<std::string, int32_t> sym2id_;
41 std::unordered_map<int32_t, std::string> id2sym_; 53 std::unordered_map<int32_t, std::string> id2sym_;
42 }; 54 };
@@ -20,7 +20,7 @@ @@ -20,7 +20,7 @@
20 #endif 20 #endif
21 21
22 #if __ANDROID_API__ >= 8 22 #if __ANDROID_API__ >= 8
23 -#include <android/log.h> 23 +#include "android/log.h"
24 #define SHERPA_ONNX_LOGE(...) \ 24 #define SHERPA_ONNX_LOGE(...) \
25 do { \ 25 do { \
26 fprintf(stderr, ##__VA_ARGS__); \ 26 fprintf(stderr, ##__VA_ARGS__); \