Committed by
GitHub
Add build script for Android armv8a (#58)
正在显示
16 个修改的文件
包含
400 行增加
和
74 行删除
build-android-arm64-v8a.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | +set -ex | ||
| 3 | + | ||
| 4 | +dir=build-android-arm64-v8a | ||
| 5 | + | ||
| 6 | +mkdir -p $dir | ||
| 7 | +cd $dir | ||
| 8 | + | ||
| 9 | +# Note from https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android | ||
| 10 | +# (optional) remove the hardcoded debug flag in Android NDK android-ndk | ||
| 11 | +# issue: https://github.com/android/ndk/issues/243 | ||
| 12 | +# | ||
| 13 | +# open $ANDROID_NDK/build/cmake/android.toolchain.cmake for ndk < r23 | ||
| 14 | +# or $ANDROID_NDK/build/cmake/android-legacy.toolchain.cmake for ndk >= r23 | ||
| 15 | +# | ||
| 16 | +# delete "-g" line | ||
| 17 | +# | ||
| 18 | +# list(APPEND ANDROID_COMPILER_FLAGS | ||
| 19 | +# -g | ||
| 20 | +# -DANDROID | ||
| 21 | + | ||
| 22 | + | ||
| 23 | +if [ -z $ANDROID_NDK ]; then | ||
| 24 | + ANDROID_NDK=/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669 | ||
| 25 | + # or use | ||
| 26 | + # ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk | ||
| 27 | + # | ||
| 28 | + # Inside the $ANDROID_NDK directory, you can find a binary ndk-build | ||
| 29 | + # and some other files like the file "build/cmake/android.toolchain.cmake" | ||
| 30 | + | ||
| 31 | + if [ ! -d $ANDROID_NDK ]; then | ||
| 32 | + # For macOS, I have installed Android Studio, select the menu | ||
| 33 | + # Tools -> SDK manager -> Android SDK | ||
| 34 | + # and set "Android SDK location" to /Users/fangjun/software/my-android | ||
| 35 | + ANDROID_NDK=/Users/fangjun/software/my-android/ndk/22.1.7171670 | ||
| 36 | + fi | ||
| 37 | +fi | ||
| 38 | + | ||
| 39 | +if [ ! -d $ANDROID_NDK ]; then | ||
| 40 | + echo Please set the environment variable ANDROID_NDK before you run this script | ||
| 41 | + exit 1 | ||
| 42 | +fi | ||
| 43 | + | ||
| 44 | +echo "ANDROID_NDK: $ANDROID_NDK" | ||
| 45 | +sleep 1 | ||
| 46 | + | ||
| 47 | +cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ | ||
| 48 | + -DCMAKE_BUILD_TYPE=Release \ | ||
| 49 | + -DBUILD_SHARED_LIBS=ON \ | ||
| 50 | + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ | ||
| 51 | + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ | ||
| 52 | + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ | ||
| 53 | + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ | ||
| 54 | + -DSHERPA_ONNX_ENABLE_JNI=ON \ | ||
| 55 | + -DCMAKE_INSTALL_PREFIX=./install \ | ||
| 56 | + -DANDROID_ABI="arm64-v8a" \ | ||
| 57 | + -DANDROID_PLATFORM=android-21 .. | ||
| 58 | +# make VERBOSE=1 -j4 | ||
| 59 | +make -j4 | ||
| 60 | +make install/strip | ||
| 61 | + |
| 1 | function(download_onnxruntime) | 1 | function(download_onnxruntime) |
| 2 | include(FetchContent) | 2 | include(FetchContent) |
| 3 | 3 | ||
| 4 | - if(CMAKE_SYSTEM_NAME STREQUAL Linux) | ||
| 5 | - if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) | ||
| 6 | - # For embedded systems | ||
| 7 | - set(possible_file_locations | ||
| 8 | - $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 9 | - ${PROJECT_SOURCE_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 10 | - ${PROJECT_BINARY_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 11 | - /tmp/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 12 | - /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 13 | - ) | ||
| 14 | - set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz") | ||
| 15 | - set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86") | ||
| 16 | - | ||
| 17 | - else() | ||
| 18 | - # If you don't have access to the Internet, | ||
| 19 | - # please pre-download onnxruntime | ||
| 20 | - set(possible_file_locations | ||
| 21 | - $ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz | ||
| 22 | - ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz | ||
| 23 | - ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz | ||
| 24 | - /tmp/onnxruntime-linux-x64-1.14.0.tgz | ||
| 25 | - /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz | ||
| 26 | - ) | ||
| 27 | - | ||
| 28 | - set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") | ||
| 29 | - set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090") | ||
| 30 | - # After downloading, it contains: | ||
| 31 | - # ./lib/libonnxruntime.so.1.14.0 | ||
| 32 | - # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0 | ||
| 33 | - # | ||
| 34 | - # ./include | ||
| 35 | - # It contains all the needed header files | ||
| 36 | - endif() | 4 | + message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") |
| 5 | + message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") | ||
| 6 | + | ||
| 7 | + if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) | ||
| 8 | + # For embedded systems | ||
| 9 | + set(possible_file_locations | ||
| 10 | + $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 11 | + ${PROJECT_SOURCE_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 12 | + ${PROJECT_BINARY_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 13 | + /tmp/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 14 | + /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz | ||
| 15 | + ) | ||
| 16 | + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz") | ||
| 17 | + set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86") | ||
| 18 | + elseif(CMAKE_SYSTEM_NAME STREQUAL Linux AND CMAKE_SYSTEM_PROCESSOR STREQUAL x86_64) | ||
| 19 | + # If you don't have access to the Internet, | ||
| 20 | + # please pre-download onnxruntime | ||
| 21 | + set(possible_file_locations | ||
| 22 | + $ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz | ||
| 23 | + ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz | ||
| 24 | + ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz | ||
| 25 | + /tmp/onnxruntime-linux-x64-1.14.0.tgz | ||
| 26 | + /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz | ||
| 27 | + ) | ||
| 28 | + | ||
| 29 | + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") | ||
| 30 | + set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090") | ||
| 31 | + # After downloading, it contains: | ||
| 32 | + # ./lib/libonnxruntime.so.1.14.0 | ||
| 33 | + # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0 | ||
| 34 | + # | ||
| 35 | + # ./include | ||
| 36 | + # It contains all the needed header files | ||
| 37 | elseif(APPLE) | 37 | elseif(APPLE) |
| 38 | # If you don't have access to the Internet, | 38 | # If you don't have access to the Internet, |
| 39 | # please pre-download onnxruntime | 39 | # please pre-download onnxruntime |
| @@ -69,6 +69,8 @@ function(download_onnxruntime) | @@ -69,6 +69,8 @@ function(download_onnxruntime) | ||
| 69 | # ./include | 69 | # ./include |
| 70 | # It contains all the needed header files | 70 | # It contains all the needed header files |
| 71 | else() | 71 | else() |
| 72 | + message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") | ||
| 73 | + message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") | ||
| 72 | message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later") | 74 | message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later") |
| 73 | endif() | 75 | endif() |
| 74 | 76 | ||
| @@ -91,11 +93,15 @@ function(download_onnxruntime) | @@ -91,11 +93,15 @@ function(download_onnxruntime) | ||
| 91 | endif() | 93 | endif() |
| 92 | message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") | 94 | message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") |
| 93 | 95 | ||
| 94 | - find_library(location_onnxruntime onnxruntime | ||
| 95 | - PATHS | ||
| 96 | - "${onnxruntime_SOURCE_DIR}/lib" | ||
| 97 | - NO_CMAKE_SYSTEM_PATH | ||
| 98 | - ) | 96 | + if(ANDROID) |
| 97 | + set(location_onnxruntime ${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so) | ||
| 98 | + else() | ||
| 99 | + find_library(location_onnxruntime onnxruntime | ||
| 100 | + PATHS | ||
| 101 | + "${onnxruntime_SOURCE_DIR}/lib" | ||
| 102 | + NO_CMAKE_SYSTEM_PATH | ||
| 103 | + ) | ||
| 104 | + endif() | ||
| 99 | 105 | ||
| 100 | message(STATUS "location_onnxruntime: ${location_onnxruntime}") | 106 | message(STATUS "location_onnxruntime: ${location_onnxruntime}") |
| 101 | 107 |
| @@ -26,6 +26,10 @@ endif() | @@ -26,6 +26,10 @@ endif() | ||
| 26 | 26 | ||
| 27 | add_library(sherpa-onnx-core ${sources}) | 27 | add_library(sherpa-onnx-core ${sources}) |
| 28 | 28 | ||
| 29 | +if(ANDROID_NDK) | ||
| 30 | + target_link_libraries(sherpa-onnx-core android log) | ||
| 31 | +endif() | ||
| 32 | + | ||
| 29 | target_link_libraries(sherpa-onnx-core | 33 | target_link_libraries(sherpa-onnx-core |
| 30 | onnxruntime | 34 | onnxruntime |
| 31 | kaldi-native-fbank-core | 35 | kaldi-native-fbank-core |
| @@ -12,6 +12,11 @@ | @@ -12,6 +12,11 @@ | ||
| 12 | #include <utility> | 12 | #include <utility> |
| 13 | #include <vector> | 13 | #include <vector> |
| 14 | 14 | ||
| 15 | +#if __ANDROID_API__ >= 9 | ||
| 16 | +#include "android/asset_manager.h" | ||
| 17 | +#include "android/asset_manager_jni.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 15 | #include "onnxruntime_cxx_api.h" // NOLINT | 20 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 16 | #include "sherpa-onnx/csrc/cat.h" | 21 | #include "sherpa-onnx/csrc/cat.h" |
| 17 | #include "sherpa-onnx/csrc/macros.h" | 22 | #include "sherpa-onnx/csrc/macros.h" |
| @@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( | @@ -30,14 +35,53 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( | ||
| 30 | sess_opts_.SetIntraOpNumThreads(config.num_threads); | 35 | sess_opts_.SetIntraOpNumThreads(config.num_threads); |
| 31 | sess_opts_.SetInterOpNumThreads(config.num_threads); | 36 | sess_opts_.SetInterOpNumThreads(config.num_threads); |
| 32 | 37 | ||
| 33 | - InitEncoder(config.encoder_filename); | ||
| 34 | - InitDecoder(config.decoder_filename); | ||
| 35 | - InitJoiner(config.joiner_filename); | 38 | + { |
| 39 | + auto buf = ReadFile(config.encoder_filename); | ||
| 40 | + InitEncoder(buf.data(), buf.size()); | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + { | ||
| 44 | + auto buf = ReadFile(config.decoder_filename); | ||
| 45 | + InitDecoder(buf.data(), buf.size()); | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + { | ||
| 49 | + auto buf = ReadFile(config.joiner_filename); | ||
| 50 | + InitJoiner(buf.data(), buf.size()); | ||
| 51 | + } | ||
| 52 | +} | ||
| 53 | + | ||
| 54 | +#if __ANDROID_API__ >= 9 | ||
| 55 | +OnlineLstmTransducerModel::OnlineLstmTransducerModel( | ||
| 56 | + AAssetManager *mgr, const OnlineTransducerModelConfig &config) | ||
| 57 | + : env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 58 | + config_(config), | ||
| 59 | + sess_opts_{}, | ||
| 60 | + allocator_{} { | ||
| 61 | + sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 62 | + sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 63 | + | ||
| 64 | + { | ||
| 65 | + auto buf = ReadFile(mgr, config.encoder_filename); | ||
| 66 | + InitEncoder(buf.data(), buf.size()); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + { | ||
| 70 | + auto buf = ReadFile(mgr, config.decoder_filename); | ||
| 71 | + InitDecoder(buf.data(), buf.size()); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + { | ||
| 75 | + auto buf = ReadFile(mgr, config.joiner_filename); | ||
| 76 | + InitJoiner(buf.data(), buf.size()); | ||
| 77 | + } | ||
| 36 | } | 78 | } |
| 79 | +#endif | ||
| 37 | 80 | ||
| 38 | -void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) { | ||
| 39 | - encoder_sess_ = std::make_unique<Ort::Session>( | ||
| 40 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | 81 | +void OnlineLstmTransducerModel::InitEncoder(void *model_data, |
| 82 | + size_t model_data_length) { | ||
| 83 | + encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, | ||
| 84 | + model_data_length, sess_opts_); | ||
| 41 | 85 | ||
| 42 | GetInputNames(encoder_sess_.get(), &encoder_input_names_, | 86 | GetInputNames(encoder_sess_.get(), &encoder_input_names_, |
| 43 | &encoder_input_names_ptr_); | 87 | &encoder_input_names_ptr_); |
| @@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) { | @@ -62,9 +106,10 @@ void OnlineLstmTransducerModel::InitEncoder(const std::string &filename) { | ||
| 62 | SHERPA_ONNX_READ_META_DATA(d_model_, "d_model"); | 106 | SHERPA_ONNX_READ_META_DATA(d_model_, "d_model"); |
| 63 | } | 107 | } |
| 64 | 108 | ||
| 65 | -void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) { | ||
| 66 | - decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 67 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | 109 | +void OnlineLstmTransducerModel::InitDecoder(void *model_data, |
| 110 | + size_t model_data_length) { | ||
| 111 | + decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, | ||
| 112 | + model_data_length, sess_opts_); | ||
| 68 | 113 | ||
| 69 | GetInputNames(decoder_sess_.get(), &decoder_input_names_, | 114 | GetInputNames(decoder_sess_.get(), &decoder_input_names_, |
| 70 | &decoder_input_names_ptr_); | 115 | &decoder_input_names_ptr_); |
| @@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) { | @@ -86,9 +131,10 @@ void OnlineLstmTransducerModel::InitDecoder(const std::string &filename) { | ||
| 86 | SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); | 131 | SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); |
| 87 | } | 132 | } |
| 88 | 133 | ||
| 89 | -void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) { | ||
| 90 | - joiner_sess_ = std::make_unique<Ort::Session>( | ||
| 91 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | 134 | +void OnlineLstmTransducerModel::InitJoiner(void *model_data, |
| 135 | + size_t model_data_length) { | ||
| 136 | + joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data, | ||
| 137 | + model_data_length, sess_opts_); | ||
| 92 | 138 | ||
| 93 | GetInputNames(joiner_sess_.get(), &joiner_input_names_, | 139 | GetInputNames(joiner_sess_.get(), &joiner_input_names_, |
| 94 | &joiner_input_names_ptr_); | 140 | &joiner_input_names_ptr_); |
| @@ -9,6 +9,11 @@ | @@ -9,6 +9,11 @@ | ||
| 9 | #include <utility> | 9 | #include <utility> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 12 | #include "onnxruntime_cxx_api.h" // NOLINT | 17 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 13 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 18 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 14 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 19 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| @@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { | @@ -19,6 +24,11 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { | ||
| 19 | public: | 24 | public: |
| 20 | explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config); | 25 | explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config); |
| 21 | 26 | ||
| 27 | +#if __ANDROID_API__ >= 9 | ||
| 28 | + OnlineLstmTransducerModel(AAssetManager *mgr, | ||
| 29 | + const OnlineTransducerModelConfig &config); | ||
| 30 | +#endif | ||
| 31 | + | ||
| 22 | std::vector<Ort::Value> StackStates( | 32 | std::vector<Ort::Value> StackStates( |
| 23 | const std::vector<std::vector<Ort::Value>> &states) const override; | 33 | const std::vector<std::vector<Ort::Value>> &states) const override; |
| 24 | 34 | ||
| @@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { | @@ -47,9 +57,9 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { | ||
| 47 | OrtAllocator *Allocator() override { return allocator_; } | 57 | OrtAllocator *Allocator() override { return allocator_; } |
| 48 | 58 | ||
| 49 | private: | 59 | private: |
| 50 | - void InitEncoder(const std::string &encoder_filename); | ||
| 51 | - void InitDecoder(const std::string &decoder_filename); | ||
| 52 | - void InitJoiner(const std::string &joiner_filename); | 60 | + void InitEncoder(void *model_data, size_t model_data_length); |
| 61 | + void InitDecoder(void *model_data, size_t model_data_length); | ||
| 62 | + void InitJoiner(void *model_data, size_t model_data_length); | ||
| 53 | 63 | ||
| 54 | private: | 64 | private: |
| 55 | Ort::Env env_; | 65 | Ort::Env env_; |
| @@ -55,6 +55,17 @@ class OnlineRecognizer::Impl { | @@ -55,6 +55,17 @@ class OnlineRecognizer::Impl { | ||
| 55 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); | 55 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); |
| 56 | } | 56 | } |
| 57 | 57 | ||
| 58 | +#if __ANDROID_API__ >= 9 | ||
| 59 | + explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config) | ||
| 60 | + : config_(config), | ||
| 61 | + model_(OnlineTransducerModel::Create(mgr, config.model_config)), | ||
| 62 | + sym_(mgr, config.tokens), | ||
| 63 | + endpoint_(config_.endpoint_config) { | ||
| 64 | + decoder_ = | ||
| 65 | + std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); | ||
| 66 | + } | ||
| 67 | +#endif | ||
| 68 | + | ||
| 58 | std::unique_ptr<OnlineStream> CreateStream() const { | 69 | std::unique_ptr<OnlineStream> CreateStream() const { |
| 59 | auto stream = std::make_unique<OnlineStream>(config_.feat_config); | 70 | auto stream = std::make_unique<OnlineStream>(config_.feat_config); |
| 60 | stream->SetResult(decoder_->GetEmptyResult()); | 71 | stream->SetResult(decoder_->GetEmptyResult()); |
| @@ -156,6 +167,13 @@ class OnlineRecognizer::Impl { | @@ -156,6 +167,13 @@ class OnlineRecognizer::Impl { | ||
| 156 | 167 | ||
| 157 | OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) | 168 | OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) |
| 158 | : impl_(std::make_unique<Impl>(config)) {} | 169 | : impl_(std::make_unique<Impl>(config)) {} |
| 170 | + | ||
| 171 | +#if __ANDROID_API__ >= 9 | ||
| 172 | +OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr, | ||
| 173 | + const OnlineRecognizerConfig &config) | ||
| 174 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 175 | +#endif | ||
| 176 | + | ||
| 159 | OnlineRecognizer::~OnlineRecognizer() = default; | 177 | OnlineRecognizer::~OnlineRecognizer() = default; |
| 160 | 178 | ||
| 161 | std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const { | 179 | std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const { |
| @@ -8,6 +8,11 @@ | @@ -8,6 +8,11 @@ | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | #include <string> | 9 | #include <string> |
| 10 | 10 | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 11 | #include "sherpa-onnx/csrc/endpoint.h" | 16 | #include "sherpa-onnx/csrc/endpoint.h" |
| 12 | #include "sherpa-onnx/csrc/features.h" | 17 | #include "sherpa-onnx/csrc/features.h" |
| 13 | #include "sherpa-onnx/csrc/online-stream.h" | 18 | #include "sherpa-onnx/csrc/online-stream.h" |
| @@ -45,6 +50,11 @@ struct OnlineRecognizerConfig { | @@ -45,6 +50,11 @@ struct OnlineRecognizerConfig { | ||
| 45 | class OnlineRecognizer { | 50 | class OnlineRecognizer { |
| 46 | public: | 51 | public: |
| 47 | explicit OnlineRecognizer(const OnlineRecognizerConfig &config); | 52 | explicit OnlineRecognizer(const OnlineRecognizerConfig &config); |
| 53 | + | ||
| 54 | +#if __ANDROID_API__ >= 9 | ||
| 55 | + OnlineRecognizer(AAssetManager *mgr, const OnlineRecognizerConfig &config); | ||
| 56 | +#endif | ||
| 57 | + | ||
| 48 | ~OnlineRecognizer(); | 58 | ~OnlineRecognizer(); |
| 49 | 59 | ||
| 50 | /// Create a stream for decoding. | 60 | /// Create a stream for decoding. |
| @@ -3,6 +3,11 @@ | @@ -3,6 +3,11 @@ | ||
| 3 | // Copyright (c) 2023 Xiaomi Corporation | 3 | // Copyright (c) 2023 Xiaomi Corporation |
| 4 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 4 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| 5 | 5 | ||
| 6 | +#if __ANDROID_API__ >= 9 | ||
| 7 | +#include "android/asset_manager.h" | ||
| 8 | +#include "android/asset_manager_jni.h" | ||
| 9 | +#endif | ||
| 10 | + | ||
| 6 | #include <memory> | 11 | #include <memory> |
| 7 | #include <sstream> | 12 | #include <sstream> |
| 8 | #include <string> | 13 | #include <string> |
| @@ -18,15 +23,16 @@ enum class ModelType { | @@ -18,15 +23,16 @@ enum class ModelType { | ||
| 18 | kUnkown, | 23 | kUnkown, |
| 19 | }; | 24 | }; |
| 20 | 25 | ||
| 21 | -static ModelType GetModelType(const OnlineTransducerModelConfig &config) { | 26 | +static ModelType GetModelType(char *model_data, size_t model_data_length, |
| 27 | + bool debug) { | ||
| 22 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | 28 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING); |
| 23 | Ort::SessionOptions sess_opts; | 29 | Ort::SessionOptions sess_opts; |
| 24 | 30 | ||
| 25 | - auto sess = std::make_unique<Ort::Session>( | ||
| 26 | - env, SHERPA_MAYBE_WIDE(config.encoder_filename).c_str(), sess_opts); | 31 | + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length, |
| 32 | + sess_opts); | ||
| 27 | 33 | ||
| 28 | Ort::ModelMetadata meta_data = sess->GetModelMetadata(); | 34 | Ort::ModelMetadata meta_data = sess->GetModelMetadata(); |
| 29 | - if (config.debug) { | 35 | + if (debug) { |
| 30 | std::ostringstream os; | 36 | std::ostringstream os; |
| 31 | PrintModelMetadata(os, meta_data); | 37 | PrintModelMetadata(os, meta_data); |
| 32 | fprintf(stderr, "%s\n", os.str().c_str()); | 38 | fprintf(stderr, "%s\n", os.str().c_str()); |
| @@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) { | @@ -52,7 +58,9 @@ static ModelType GetModelType(const OnlineTransducerModelConfig &config) { | ||
| 52 | 58 | ||
| 53 | std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | 59 | std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( |
| 54 | const OnlineTransducerModelConfig &config) { | 60 | const OnlineTransducerModelConfig &config) { |
| 55 | - auto model_type = GetModelType(config); | 61 | + auto buffer = ReadFile(config.encoder_filename); |
| 62 | + | ||
| 63 | + auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | ||
| 56 | 64 | ||
| 57 | switch (model_type) { | 65 | switch (model_type) { |
| 58 | case ModelType::kLstm: | 66 | case ModelType::kLstm: |
| @@ -67,4 +75,24 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | @@ -67,4 +75,24 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 67 | return nullptr; | 75 | return nullptr; |
| 68 | } | 76 | } |
| 69 | 77 | ||
| 78 | +#if __ANDROID_API__ >= 9 | ||
| 79 | +std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( | ||
| 80 | + AAssetManager *mgr, const OnlineTransducerModelConfig &config) { | ||
| 81 | + auto buffer = ReadFile(mgr, config.encoder_filename); | ||
| 82 | + auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | ||
| 83 | + | ||
| 84 | + switch (model_type) { | ||
| 85 | + case ModelType::kLstm: | ||
| 86 | + return std::make_unique<OnlineLstmTransducerModel>(mgr, config); | ||
| 87 | + case ModelType::kZipformer: | ||
| 88 | + return std::make_unique<OnlineZipformerTransducerModel>(mgr, config); | ||
| 89 | + case ModelType::kUnkown: | ||
| 90 | + return nullptr; | ||
| 91 | + } | ||
| 92 | + | ||
| 93 | + // unreachable code | ||
| 94 | + return nullptr; | ||
| 95 | +} | ||
| 96 | +#endif | ||
| 97 | + | ||
| 70 | } // namespace sherpa_onnx | 98 | } // namespace sherpa_onnx |
| @@ -8,6 +8,11 @@ | @@ -8,6 +8,11 @@ | ||
| 8 | #include <utility> | 8 | #include <utility> |
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 16 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 17 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 13 | 18 | ||
| @@ -22,6 +27,11 @@ class OnlineTransducerModel { | @@ -22,6 +27,11 @@ class OnlineTransducerModel { | ||
| 22 | static std::unique_ptr<OnlineTransducerModel> Create( | 27 | static std::unique_ptr<OnlineTransducerModel> Create( |
| 23 | const OnlineTransducerModelConfig &config); | 28 | const OnlineTransducerModelConfig &config); |
| 24 | 29 | ||
| 30 | +#if __ANDROID_API__ >= 9 | ||
| 31 | + static std::unique_ptr<OnlineTransducerModel> Create( | ||
| 32 | + AAssetManager *mgr, const OnlineTransducerModelConfig &config); | ||
| 33 | +#endif | ||
| 34 | + | ||
| 25 | /** Stack a list of individual states into a batch. | 35 | /** Stack a list of individual states into a batch. |
| 26 | * | 36 | * |
| 27 | * It is the inverse operation of `UnStackStates`. | 37 | * It is the inverse operation of `UnStackStates`. |
| @@ -13,6 +13,11 @@ | @@ -13,6 +13,11 @@ | ||
| 13 | #include <utility> | 13 | #include <utility> |
| 14 | #include <vector> | 14 | #include <vector> |
| 15 | 15 | ||
| 16 | +#if __ANDROID_API__ >= 9 | ||
| 17 | +#include "android/asset_manager.h" | ||
| 18 | +#include "android/asset_manager_jni.h" | ||
| 19 | +#endif | ||
| 20 | + | ||
| 16 | #include "onnxruntime_cxx_api.h" // NOLINT | 21 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 17 | #include "sherpa-onnx/csrc/cat.h" | 22 | #include "sherpa-onnx/csrc/cat.h" |
| 18 | #include "sherpa-onnx/csrc/macros.h" | 23 | #include "sherpa-onnx/csrc/macros.h" |
| @@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( | @@ -32,14 +37,53 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( | ||
| 32 | sess_opts_.SetIntraOpNumThreads(config.num_threads); | 37 | sess_opts_.SetIntraOpNumThreads(config.num_threads); |
| 33 | sess_opts_.SetInterOpNumThreads(config.num_threads); | 38 | sess_opts_.SetInterOpNumThreads(config.num_threads); |
| 34 | 39 | ||
| 35 | - InitEncoder(config.encoder_filename); | ||
| 36 | - InitDecoder(config.decoder_filename); | ||
| 37 | - InitJoiner(config.joiner_filename); | 40 | + { |
| 41 | + auto buf = ReadFile(config.encoder_filename); | ||
| 42 | + InitEncoder(buf.data(), buf.size()); | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + { | ||
| 46 | + auto buf = ReadFile(config.decoder_filename); | ||
| 47 | + InitDecoder(buf.data(), buf.size()); | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + { | ||
| 51 | + auto buf = ReadFile(config.joiner_filename); | ||
| 52 | + InitJoiner(buf.data(), buf.size()); | ||
| 53 | + } | ||
| 54 | +} | ||
| 55 | + | ||
| 56 | +#if __ANDROID_API__ >= 9 | ||
| 57 | +OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( | ||
| 58 | + AAssetManager *mgr, const OnlineTransducerModelConfig &config) | ||
| 59 | + : env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 60 | + config_(config), | ||
| 61 | + sess_opts_{}, | ||
| 62 | + allocator_{} { | ||
| 63 | + sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 64 | + sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 65 | + | ||
| 66 | + { | ||
| 67 | + auto buf = ReadFile(mgr, config.encoder_filename); | ||
| 68 | + InitEncoder(buf.data(), buf.size()); | ||
| 69 | + } | ||
| 70 | + | ||
| 71 | + { | ||
| 72 | + auto buf = ReadFile(mgr, config.decoder_filename); | ||
| 73 | + InitDecoder(buf.data(), buf.size()); | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + { | ||
| 77 | + auto buf = ReadFile(mgr, config.joiner_filename); | ||
| 78 | + InitJoiner(buf.data(), buf.size()); | ||
| 79 | + } | ||
| 38 | } | 80 | } |
| 81 | +#endif | ||
| 39 | 82 | ||
| 40 | -void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) { | ||
| 41 | - encoder_sess_ = std::make_unique<Ort::Session>( | ||
| 42 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | 83 | +void OnlineZipformerTransducerModel::InitEncoder(void *model_data, |
| 84 | + size_t model_data_length) { | ||
| 85 | + encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, | ||
| 86 | + model_data_length, sess_opts_); | ||
| 43 | 87 | ||
| 44 | GetInputNames(encoder_sess_.get(), &encoder_input_names_, | 88 | GetInputNames(encoder_sess_.get(), &encoder_input_names_, |
| 45 | &encoder_input_names_ptr_); | 89 | &encoder_input_names_ptr_); |
| @@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) { | @@ -84,9 +128,10 @@ void OnlineZipformerTransducerModel::InitEncoder(const std::string &filename) { | ||
| 84 | } | 128 | } |
| 85 | } | 129 | } |
| 86 | 130 | ||
| 87 | -void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) { | ||
| 88 | - decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 89 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | 131 | +void OnlineZipformerTransducerModel::InitDecoder(void *model_data, |
| 132 | + size_t model_data_length) { | ||
| 133 | + decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, | ||
| 134 | + model_data_length, sess_opts_); | ||
| 90 | 135 | ||
| 91 | GetInputNames(decoder_sess_.get(), &decoder_input_names_, | 136 | GetInputNames(decoder_sess_.get(), &decoder_input_names_, |
| 92 | &decoder_input_names_ptr_); | 137 | &decoder_input_names_ptr_); |
| @@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) { | @@ -108,9 +153,10 @@ void OnlineZipformerTransducerModel::InitDecoder(const std::string &filename) { | ||
| 108 | SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); | 153 | SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); |
| 109 | } | 154 | } |
| 110 | 155 | ||
| 111 | -void OnlineZipformerTransducerModel::InitJoiner(const std::string &filename) { | ||
| 112 | - joiner_sess_ = std::make_unique<Ort::Session>( | ||
| 113 | - env_, SHERPA_MAYBE_WIDE(filename).c_str(), sess_opts_); | 156 | +void OnlineZipformerTransducerModel::InitJoiner(void *model_data, |
| 157 | + size_t model_data_length) { | ||
| 158 | + joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data, | ||
| 159 | + model_data_length, sess_opts_); | ||
| 114 | 160 | ||
| 115 | GetInputNames(joiner_sess_.get(), &joiner_input_names_, | 161 | GetInputNames(joiner_sess_.get(), &joiner_input_names_, |
| 116 | &joiner_input_names_ptr_); | 162 | &joiner_input_names_ptr_); |
| @@ -9,6 +9,11 @@ | @@ -9,6 +9,11 @@ | ||
| 9 | #include <utility> | 9 | #include <utility> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 12 | #include "onnxruntime_cxx_api.h" // NOLINT | 17 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 13 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 18 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 14 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 19 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| @@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { | @@ -20,6 +25,11 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { | ||
| 20 | explicit OnlineZipformerTransducerModel( | 25 | explicit OnlineZipformerTransducerModel( |
| 21 | const OnlineTransducerModelConfig &config); | 26 | const OnlineTransducerModelConfig &config); |
| 22 | 27 | ||
| 28 | +#if __ANDROID_API__ >= 9 | ||
| 29 | + OnlineZipformerTransducerModel(AAssetManager *mgr, | ||
| 30 | + const OnlineTransducerModelConfig &config); | ||
| 31 | +#endif | ||
| 32 | + | ||
| 23 | std::vector<Ort::Value> StackStates( | 33 | std::vector<Ort::Value> StackStates( |
| 24 | const std::vector<std::vector<Ort::Value>> &states) const override; | 34 | const std::vector<std::vector<Ort::Value>> &states) const override; |
| 25 | 35 | ||
| @@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { | @@ -48,9 +58,9 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { | ||
| 48 | OrtAllocator *Allocator() override { return allocator_; } | 58 | OrtAllocator *Allocator() override { return allocator_; } |
| 49 | 59 | ||
| 50 | private: | 60 | private: |
| 51 | - void InitEncoder(const std::string &encoder_filename); | ||
| 52 | - void InitDecoder(const std::string &decoder_filename); | ||
| 53 | - void InitJoiner(const std::string &joiner_filename); | 61 | + void InitEncoder(void *model_data, size_t model_data_length); |
| 62 | + void InitDecoder(void *model_data, size_t model_data_length); | ||
| 63 | + void InitJoiner(void *model_data, size_t model_data_length); | ||
| 54 | 64 | ||
| 55 | private: | 65 | private: |
| 56 | Ort::Env env_; | 66 | Ort::Env env_; |
| @@ -3,9 +3,16 @@ | @@ -3,9 +3,16 @@ | ||
| 3 | // Copyright (c) 2023 Xiaomi Corporation | 3 | // Copyright (c) 2023 Xiaomi Corporation |
| 4 | #include "sherpa-onnx/csrc/onnx-utils.h" | 4 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 5 | 5 | ||
| 6 | +#include <fstream> | ||
| 6 | #include <string> | 7 | #include <string> |
| 7 | #include <vector> | 8 | #include <vector> |
| 8 | 9 | ||
| 10 | +#if __ANDROID_API__ >= 9 | ||
| 11 | +#include "android/asset_manager.h" | ||
| 12 | +#include "android/asset_manager_jni.h" | ||
| 13 | +#include "android/log.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 9 | #include "onnxruntime_cxx_api.h" // NOLINT | 16 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 10 | 17 | ||
| 11 | namespace sherpa_onnx { | 18 | namespace sherpa_onnx { |
| @@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) { | @@ -116,4 +123,30 @@ void Print3D(Ort::Value *v) { | ||
| 116 | fprintf(stderr, "\n"); | 123 | fprintf(stderr, "\n"); |
| 117 | } | 124 | } |
| 118 | 125 | ||
| 126 | +std::vector<char> ReadFile(const std::string &filename) { | ||
| 127 | + std::ifstream input(filename, std::ios::binary); | ||
| 128 | + std::vector<char> buffer(std::istreambuf_iterator<char>(input), {}); | ||
| 129 | + return buffer; | ||
| 130 | +} | ||
| 131 | + | ||
| 132 | +#if __ANDROID_API__ >= 9 | ||
| 133 | +std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) { | ||
| 134 | + AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER); | ||
| 135 | + if (!asset) { | ||
| 136 | + __android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx", | ||
| 137 | + "Read binary file: Load %s failed", filename.c_str()); | ||
| 138 | + exit(-1); | ||
| 139 | + } | ||
| 140 | + | ||
| 141 | + auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset)); | ||
| 142 | + size_t asset_length = AAsset_getLength(asset); | ||
| 143 | + | ||
| 144 | + AAsset_close(asset); | ||
| 145 | + | ||
| 146 | + std::vector<char> buffer(p, p + asset_length); | ||
| 147 | + | ||
| 148 | + return buffer; | ||
| 149 | +} | ||
| 150 | +#endif | ||
| 151 | + | ||
| 119 | } // namespace sherpa_onnx | 152 | } // namespace sherpa_onnx |
| @@ -14,6 +14,11 @@ | @@ -14,6 +14,11 @@ | ||
| 14 | #include <string> | 14 | #include <string> |
| 15 | #include <vector> | 15 | #include <vector> |
| 16 | 16 | ||
| 17 | +#if __ANDROID_API__ >= 9 | ||
| 18 | +#include "android/asset_manager.h" | ||
| 19 | +#include "android/asset_manager_jni.h" | ||
| 20 | +#endif | ||
| 21 | + | ||
| 17 | #include "onnxruntime_cxx_api.h" // NOLINT | 22 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 18 | 23 | ||
| 19 | namespace sherpa_onnx { | 24 | namespace sherpa_onnx { |
| @@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) { | @@ -74,6 +79,12 @@ void Fill(Ort::Value *tensor, T value) { | ||
| 74 | std::fill(p, p + n, value); | 79 | std::fill(p, p + n, value); |
| 75 | } | 80 | } |
| 76 | 81 | ||
| 82 | +std::vector<char> ReadFile(const std::string &filename); | ||
| 83 | + | ||
| 84 | +#if __ANDROID_API__ >= 9 | ||
| 85 | +std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename); | ||
| 86 | +#endif | ||
| 87 | + | ||
| 77 | } // namespace sherpa_onnx | 88 | } // namespace sherpa_onnx |
| 78 | 89 | ||
| 79 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ | 90 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ |
| @@ -7,11 +7,32 @@ | @@ -7,11 +7,32 @@ | ||
| 7 | #include <cassert> | 7 | #include <cassert> |
| 8 | #include <fstream> | 8 | #include <fstream> |
| 9 | #include <sstream> | 9 | #include <sstream> |
| 10 | +#include <strstream> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 13 | + | ||
| 14 | +#if __ANDROID_API__ >= 9 | ||
| 15 | +#include "android/asset_manager.h" | ||
| 16 | +#include "android/asset_manager_jni.h" | ||
| 17 | +#endif | ||
| 10 | 18 | ||
| 11 | namespace sherpa_onnx { | 19 | namespace sherpa_onnx { |
| 12 | 20 | ||
| 13 | SymbolTable::SymbolTable(const std::string &filename) { | 21 | SymbolTable::SymbolTable(const std::string &filename) { |
| 14 | std::ifstream is(filename); | 22 | std::ifstream is(filename); |
| 23 | + Init(is); | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +#if __ANDROID_API__ >= 9 | ||
| 27 | +SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) { | ||
| 28 | + auto buf = ReadFile(mgr, filename); | ||
| 29 | + | ||
| 30 | + std::istrstream is(buf.data(), buf.size()); | ||
| 31 | + Init(is); | ||
| 32 | +} | ||
| 33 | +#endif | ||
| 34 | + | ||
| 35 | +void SymbolTable::Init(std::istream &is) { | ||
| 15 | std::string sym; | 36 | std::string sym; |
| 16 | int32_t id; | 37 | int32_t id; |
| 17 | while (is >> sym >> id) { | 38 | while (is >> sym >> id) { |
| @@ -8,6 +8,11 @@ | @@ -8,6 +8,11 @@ | ||
| 8 | #include <string> | 8 | #include <string> |
| 9 | #include <unordered_map> | 9 | #include <unordered_map> |
| 10 | 10 | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 11 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 12 | 17 | ||
| 13 | /// It manages mapping between symbols and integer IDs. | 18 | /// It manages mapping between symbols and integer IDs. |
| @@ -22,6 +27,10 @@ class SymbolTable { | @@ -22,6 +27,10 @@ class SymbolTable { | ||
| 22 | /// Fields are separated by space(s). | 27 | /// Fields are separated by space(s). |
| 23 | explicit SymbolTable(const std::string &filename); | 28 | explicit SymbolTable(const std::string &filename); |
| 24 | 29 | ||
| 30 | +#if __ANDROID_API__ >= 9 | ||
| 31 | + SymbolTable(AAssetManager *mgr, const std::string &filename); | ||
| 32 | +#endif | ||
| 33 | + | ||
| 25 | /// Return a string representation of this symbol table | 34 | /// Return a string representation of this symbol table |
| 26 | std::string ToString() const; | 35 | std::string ToString() const; |
| 27 | 36 | ||
| @@ -37,6 +46,9 @@ class SymbolTable { | @@ -37,6 +46,9 @@ class SymbolTable { | ||
| 37 | bool contains(const std::string &sym) const; | 46 | bool contains(const std::string &sym) const; |
| 38 | 47 | ||
| 39 | private: | 48 | private: |
| 49 | + void Init(std::istream &is); | ||
| 50 | + | ||
| 51 | + private: | ||
| 40 | std::unordered_map<std::string, int32_t> sym2id_; | 52 | std::unordered_map<std::string, int32_t> sym2id_; |
| 41 | std::unordered_map<int32_t, std::string> id2sym_; | 53 | std::unordered_map<int32_t, std::string> id2sym_; |
| 42 | }; | 54 | }; |
| @@ -20,7 +20,7 @@ | @@ -20,7 +20,7 @@ | ||
| 20 | #endif | 20 | #endif |
| 21 | 21 | ||
| 22 | #if __ANDROID_API__ >= 8 | 22 | #if __ANDROID_API__ >= 8 |
| 23 | -#include <android/log.h> | 23 | +#include "android/log.h" |
| 24 | #define SHERPA_ONNX_LOGE(...) \ | 24 | #define SHERPA_ONNX_LOGE(...) \ |
| 25 | do { \ | 25 | do { \ |
| 26 | fprintf(stderr, ##__VA_ARGS__); \ | 26 | fprintf(stderr, ##__VA_ARGS__); \ |
-
请 注册 或 登录 后发表评论