Fangjun Kuang
Committed by GitHub

Add JNI (#57)

  1 +Makefile
  2 +*.jar
  1 +package android.content.res
  2 +
  3 +// a dummy class for testing only
  4 +class AssetManager
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import android.content.res.AssetManager
  4 +
  5 +fun main() {
  6 + var featConfig = FeatureConfig(
  7 + sampleRate=16000.0f,
  8 + featureDim=80,
  9 + )
  10 +
  11 + var modelConfig = OnlineTransducerModelConfig(
  12 + encoder="./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
  13 + decoder="./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
  14 + joiner="./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
  15 + numThreads=4,
  16 + debug=false,
  17 + )
  18 +
  19 + var endpointConfig = EndpointConfig()
  20 +
  21 + var config = OnlineRecognizerConfig(
  22 + modelConfig=modelConfig,
  23 + featConfig=featConfig,
  24 + endpointConfig=endpointConfig,
  25 + tokens="./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
  26 + enableEndpoint=true,
  27 + )
  28 +
  29 + var model = SherpaOnnx(
  30 + assetManager = AssetManager(),
  31 + config = config,
  32 + )
  33 + var samples = WaveReader.readWave(
  34 + assetManager = AssetManager(),
  35 + filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
  36 + )
  37 +
  38 + model.decodeSamples(samples!!)
  39 +
  40 + var tail_paddings = FloatArray(8000) // 0.5 seconds
  41 + model.decodeSamples(tail_paddings)
  42 +
  43 + model.inputFinished()
  44 + println(model.text)
  45 +}
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import android.content.res.AssetManager
  4 +
  5 +data class EndpointRule(
  6 + var mustContainNonSilence: Boolean,
  7 + var minTrailingSilence: Float,
  8 + var minUtteranceLength: Float,
  9 +)
  10 +
  11 +data class EndpointConfig(
  12 + var rule1: EndpointRule = EndpointRule(false, 2.4f, 0.0f),
  13 + var rule2: EndpointRule = EndpointRule(true, 1.4f, 0.0f),
  14 + var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f)
  15 +)
  16 +
  17 +data class OnlineTransducerModelConfig(
  18 + var encoder: String,
  19 + var decoder: String,
  20 + var joiner: String,
  21 + var numThreads: Int = 4,
  22 + var debug: Boolean = false,
  23 +)
  24 +
  25 +data class FeatureConfig(
  26 + var sampleRate: Float = 16000.0f,
  27 + var featureDim: Int = 80,
  28 +)
  29 +
  30 +data class OnlineRecognizerConfig(
  31 + var featConfig: FeatureConfig = FeatureConfig(),
  32 + var modelConfig: OnlineTransducerModelConfig,
  33 + var tokens: String,
  34 + var endpointConfig: EndpointConfig = EndpointConfig(),
  35 + var enableEndpoint: Boolean,
  36 +)
  37 +
  38 +class SherpaOnnx(
  39 + assetManager: AssetManager,
  40 + var config: OnlineRecognizerConfig
  41 +) {
  42 + private val ptr: Long
  43 +
  44 + init {
  45 + ptr = new(assetManager, config)
  46 + }
  47 +
  48 + protected fun finalize() {
  49 + delete(ptr)
  50 + }
  51 +
  52 +
  53 + fun decodeSamples(samples: FloatArray) =
  54 + decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate)
  55 +
  56 + fun inputFinished() = inputFinished(ptr)
  57 + fun reset() = reset(ptr)
  58 + fun isEndpoint(): Boolean = isEndpoint(ptr)
  59 +
  60 + val text: String
  61 + get() = getText(ptr)
  62 +
  63 + private external fun delete(ptr: Long)
  64 +
  65 + private external fun new(
  66 + assetManager: AssetManager,
  67 + config: OnlineRecognizerConfig,
  68 + ): Long
  69 +
  70 + private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float)
  71 + private external fun inputFinished(ptr: Long)
  72 + private external fun getText(ptr: Long): String
  73 + private external fun reset(ptr: Long)
  74 + private external fun isEndpoint(ptr: Long): Boolean
  75 +
  76 + companion object {
  77 + init {
  78 + System.loadLibrary("sherpa-onnx-jni")
  79 + }
  80 + }
  81 +}
  82 +
  83 +fun getFeatureConfig(): FeatureConfig {
  84 + val featConfig = FeatureConfig()
  85 + featConfig.sampleRate = 16000.0f
  86 + featConfig.featureDim = 80
  87 +
  88 + return featConfig
  89 +}
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import android.content.res.AssetManager
  4 +
  5 +class WaveReader {
  6 + companion object {
  7 + // Read a mono wave file.
  8 + // No resampling is made.
  9 + external fun readWave(
  10 + assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f
  11 + ): FloatArray?
  12 +
  13 + init {
  14 + System.loadLibrary("sherpa-onnx-jni")
  15 + }
  16 + }
  17 +}
  1 +#!/usr/bin/env bash
  2 +
  3 +set -e
  4 +
  5 +mkdir -p build
  6 +cd build
  7 +
  8 +cmake \
  9 + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
  10 + -DSHERPA_ONNX_ENABLE_TESTS=OFF \
  11 + -DSHERPA_ONNX_ENABLE_CHECK=OFF \
  12 + -DBUILD_SHARED_LIBS=ON \
  13 + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
  14 + -DSHERPA_ONNX_ENABLE_JNI=ON \
  15 + ..
  16 +
  17 +make -j4
  18 +ls -lh lib
  19 +
  20 +cd ..
  21 +
  22 +export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH
  23 +
  24 +cd .github/scripts/
  25 +
  26 +git lfs install
  27 +git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
  28 +
  29 +kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt AssetManager.kt
  30 +
  31 +ls -lh main.jar
  32 +
  33 +java -Djava.library.path=../../build/lib -jar main.jar
  1 +name: jni
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - master
  7 + paths:
  8 + - '.github/workflows/jni.yaml'
  9 + - 'CMakeLists.txt'
  10 + - 'cmake/**'
  11 + - 'sherpa-onnx/csrc/*'
  12 + - 'sherpa-onnx/jni/*'
  13 + - '.github/scripts/test-jni.sh'
  14 + pull_request:
  15 + branches:
  16 + - master
  17 + paths:
  18 + - '.github/workflows/jni.yaml'
  19 + - 'CMakeLists.txt'
  20 + - 'cmake/**'
  21 + - 'sherpa-onnx/csrc/*'
  22 + - 'sherpa-onnx/jni/*'
  23 + - '.github/scripts/test-jni.sh'
  24 +
  25 +concurrency:
  26 + group: jni-${{ github.ref }}
  27 + cancel-in-progress: true
  28 +
  29 +permissions:
  30 + contents: read
  31 +
  32 +jobs:
  33 + jni:
  34 + runs-on: ${{ matrix.os }}
  35 + strategy:
  36 + fail-fast: false
  37 + matrix:
  38 + os: [ubuntu-latest, macos-latest]
  39 +
  40 + steps:
  41 + - uses: actions/checkout@v2
  42 + with:
  43 + fetch-depth: 0
  44 +
  45 + - name: Display kotlin version
  46 + shell: bash
  47 + run: |
  48 + kotlinc -version
  49 +
  50 + - name: Display java version
  51 + shell: bash
  52 + run: |
  53 + java -version
  54 + echo "JAVA_HOME is: ${JAVA_HOME}"
  55 +
  56 + - name: Run JNI test
  57 + shell: bash
  58 + run: |
  59 + .github/scripts/test-jni.sh
@@ -16,6 +16,7 @@ option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) @@ -16,6 +16,7 @@ option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
16 option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON) 16 option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
17 option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) 17 option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
18 option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) 18 option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
  19 +option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
19 20
20 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 21 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
21 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 22 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
@@ -44,6 +45,11 @@ if(NOT CMAKE_BUILD_TYPE) @@ -44,6 +45,11 @@ if(NOT CMAKE_BUILD_TYPE)
44 set(CMAKE_BUILD_TYPE Release) 45 set(CMAKE_BUILD_TYPE Release)
45 endif() 46 endif()
46 47
  48 +if(DEFINED ANDROID_ABI)
  49 + message(STATUS "Set SHERPA_ONNX_ENABLE_JNI to ON for Android")
  50 + set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE)
  51 +endif()
  52 +
47 message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") 53 message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
48 message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") 54 message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
49 message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}") 55 message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
@@ -51,6 +57,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}") @@ -51,6 +57,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}")
51 message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}") 57 message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
52 message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}") 58 message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
53 message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}") 59 message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}")
  60 +message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}")
54 61
55 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") 62 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
56 set(CMAKE_CXX_EXTENSIONS OFF) 63 set(CMAKE_CXX_EXTENSIONS OFF)
@@ -2,3 +2,7 @@ add_subdirectory(csrc) @@ -2,3 +2,7 @@ add_subdirectory(csrc)
2 if(SHERPA_ONNX_ENABLE_PYTHON) 2 if(SHERPA_ONNX_ENABLE_PYTHON)
3 add_subdirectory(python) 3 add_subdirectory(python)
4 endif() 4 endif()
  5 +
  6 +if(SHERPA_ONNX_ENABLE_JNI)
  7 + add_subdirectory(jni)
  8 +endif()
  1 +include_directories(${CMAKE_SOURCE_DIR})
  2 +
  3 +if(NOT DEFINED ANDROID_ABI)
  4 + if(NOT DEFINED ENV{JAVA_HOME})
  5 + message(FATAL_ERROR "Please set the environment variable JAVA_HOME")
  6 + endif()
  7 + include_directories($ENV{JAVA_HOME}/include)
  8 + include_directories($ENV{JAVA_HOME}/include/linux)
  9 + include_directories($ENV{JAVA_HOME}/include/darwin)
  10 +endif()
  11 +
  12 +add_library(sherpa-onnx-jni jni.cc)
  13 +target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
  14 +install(TARGETS sherpa-onnx-jni DESTINATION lib)
  1 +// sherpa-onnx/jni/jni.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +// 2022 Pingfeng Luo
  5 +
  6 +// TODO(fangjun): Add documentation to functions/methods in this file
  7 +// and also show how to use them with kotlin, possibly with java.
  8 +
  9 +// If you use ndk, you can find "jni.h" inside
  10 +// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
  11 +#include "jni.h" // NOLINT
  12 +
  13 +#include <strstream>
  14 +
  15 +#if __ANDROID_API__ >= 9
  16 +#include "android/asset_manager.h"
  17 +#include "android/asset_manager_jni.h"
  18 +#else
  19 +#include <fstream>
  20 +#endif
  21 +
  22 +#if __ANDROID_API__ >= 8
  23 +#include <android/log.h>
  24 +#define SHERPA_ONNX_LOGE(...) \
  25 + do { \
  26 + fprintf(stderr, ##__VA_ARGS__); \
  27 + fprintf(stderr, "\n"); \
  28 + __android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
  29 + } while (0)
  30 +#else
  31 +#define SHERPA_ONNX_LOGE(...) \
  32 + do { \
  33 + fprintf(stderr, ##__VA_ARGS__); \
  34 + fprintf(stderr, "\n"); \
  35 + } while (0)
  36 +#endif
  37 +
  38 +#include "sherpa-onnx/csrc/online-recognizer.h"
  39 +#include "sherpa-onnx/csrc/wave-reader.h"
  40 +
  41 +#define SHERPA_ONNX_EXTERN_C extern "C"
  42 +
  43 +namespace sherpa_onnx {
  44 +
  45 +class SherpaOnnx {
  46 + public:
  47 + SherpaOnnx(
  48 +#if __ANDROID_API__ >= 9
  49 + AAssetManager *mgr,
  50 +#endif
  51 + const sherpa_onnx::OnlineRecognizerConfig &config)
  52 + : recognizer_(
  53 +#if __ANDROID_API__ >= 9
  54 + mgr,
  55 +#endif
  56 + config),
  57 + stream_(recognizer_.CreateStream()),
  58 + tail_padding_(16000 * 0.32, 0) {
  59 + }
  60 +
  61 + void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
  62 + stream_->AcceptWaveform(sample_rate, samples, n);
  63 + Decode();
  64 + }
  65 +
  66 + void InputFinished() const {
  67 + stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size());
  68 + stream_->InputFinished();
  69 + Decode();
  70 + }
  71 +
  72 + const std::string GetText() const {
  73 + auto result = recognizer_.GetResult(stream_.get());
  74 + return result.text;
  75 + }
  76 +
  77 + bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
  78 +
  79 + void Reset() const { return recognizer_.Reset(stream_.get()); }
  80 +
  81 + private:
  82 + void Decode() const {
  83 + while (recognizer_.IsReady(stream_.get())) {
  84 + recognizer_.DecodeStream(stream_.get());
  85 + }
  86 + }
  87 +
  88 + private:
  89 + sherpa_onnx::OnlineRecognizer recognizer_;
  90 + std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
  91 + std::vector<float> tail_padding_;
  92 +};
  93 +
  94 +static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
  95 + OnlineRecognizerConfig ans;
  96 +
  97 + jclass cls = env->GetObjectClass(config);
  98 + jfieldID fid;
  99 +
  100 + // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
  101 + // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
  102 +
  103 + //---------- feat config ----------
  104 + fid = env->GetFieldID(cls, "featConfig",
  105 + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
  106 + jobject feat_config = env->GetObjectField(config, fid);
  107 + jclass feat_config_cls = env->GetObjectClass(feat_config);
  108 +
  109 + fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
  110 + ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
  111 +
  112 + fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
  113 + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
  114 +
  115 + //---------- enable endpoint ----------
  116 + fid = env->GetFieldID(cls, "enableEndpoint", "Z");
  117 + ans.enable_endpoint = env->GetBooleanField(config, fid);
  118 +
  119 + //---------- endpoint_config ----------
  120 +
  121 + fid = env->GetFieldID(cls, "endpointConfig",
  122 + "Lcom/k2fsa/sherpa/onnx/EndpointConfig;");
  123 + jobject endpoint_config = env->GetObjectField(config, fid);
  124 + jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
  125 +
  126 + fid = env->GetFieldID(endpoint_config_cls, "rule1",
  127 + "Lcom/k2fsa/sherpa/onnx/EndpointRule;");
  128 + jobject rule1 = env->GetObjectField(endpoint_config, fid);
  129 + jclass rule_class = env->GetObjectClass(rule1);
  130 +
  131 + fid = env->GetFieldID(endpoint_config_cls, "rule2",
  132 + "Lcom/k2fsa/sherpa/onnx/EndpointRule;");
  133 + jobject rule2 = env->GetObjectField(endpoint_config, fid);
  134 +
  135 + fid = env->GetFieldID(endpoint_config_cls, "rule3",
  136 + "Lcom/k2fsa/sherpa/onnx/EndpointRule;");
  137 + jobject rule3 = env->GetObjectField(endpoint_config, fid);
  138 +
  139 + fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
  140 + ans.endpoint_config.rule1.must_contain_nonsilence =
  141 + env->GetBooleanField(rule1, fid);
  142 + ans.endpoint_config.rule2.must_contain_nonsilence =
  143 + env->GetBooleanField(rule2, fid);
  144 + ans.endpoint_config.rule3.must_contain_nonsilence =
  145 + env->GetBooleanField(rule3, fid);
  146 +
  147 + fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
  148 + ans.endpoint_config.rule1.min_trailing_silence =
  149 + env->GetFloatField(rule1, fid);
  150 + ans.endpoint_config.rule2.min_trailing_silence =
  151 + env->GetFloatField(rule2, fid);
  152 + ans.endpoint_config.rule3.min_trailing_silence =
  153 + env->GetFloatField(rule3, fid);
  154 +
  155 + fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
  156 + ans.endpoint_config.rule1.min_utterance_length =
  157 + env->GetFloatField(rule1, fid);
  158 + ans.endpoint_config.rule2.min_utterance_length =
  159 + env->GetFloatField(rule2, fid);
  160 + ans.endpoint_config.rule3.min_utterance_length =
  161 + env->GetFloatField(rule3, fid);
  162 +
  163 + //---------- tokens ----------
  164 +
  165 + fid = env->GetFieldID(cls, "tokens", "Ljava/lang/String;");
  166 + jstring s = (jstring)env->GetObjectField(config, fid);
  167 + const char *p = env->GetStringUTFChars(s, nullptr);
  168 + ans.tokens = p;
  169 + env->ReleaseStringUTFChars(s, p);
  170 +
  171 + //---------- model config ----------
  172 + fid = env->GetFieldID(cls, "modelConfig",
  173 + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
  174 + jobject model_config = env->GetObjectField(config, fid);
  175 + jclass model_config_cls = env->GetObjectClass(model_config);
  176 +
  177 + fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
  178 + s = (jstring)env->GetObjectField(model_config, fid);
  179 + p = env->GetStringUTFChars(s, nullptr);
  180 + ans.model_config.encoder_filename = p;
  181 + env->ReleaseStringUTFChars(s, p);
  182 +
  183 + fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;");
  184 + s = (jstring)env->GetObjectField(model_config, fid);
  185 + p = env->GetStringUTFChars(s, nullptr);
  186 + ans.model_config.decoder_filename = p;
  187 + env->ReleaseStringUTFChars(s, p);
  188 +
  189 + fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;");
  190 + s = (jstring)env->GetObjectField(model_config, fid);
  191 + p = env->GetStringUTFChars(s, nullptr);
  192 + ans.model_config.joiner_filename = p;
  193 + env->ReleaseStringUTFChars(s, p);
  194 +
  195 + fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  196 + ans.model_config.num_threads = env->GetIntField(model_config, fid);
  197 +
  198 + fid = env->GetFieldID(model_config_cls, "debug", "Z");
  199 + ans.model_config.debug = env->GetBooleanField(model_config, fid);
  200 +
  201 + return ans;
  202 +}
  203 +
  204 +} // namespace sherpa_onnx
  205 +
  206 +SHERPA_ONNX_EXTERN_C
  207 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
  208 + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  209 +#if __ANDROID_API__ >= 9
  210 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  211 + if (!mgr) {
  212 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  213 + }
  214 +#endif
  215 +
  216 + auto config = sherpa_onnx::GetConfig(env, _config);
  217 + auto model = new sherpa_onnx::SherpaOnnx(
  218 +#if __ANDROID_API__ >= 9
  219 + mgr,
  220 +#endif
  221 + config);
  222 +
  223 + return (jlong)model;
  224 +}
  225 +
  226 +SHERPA_ONNX_EXTERN_C
  227 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
  228 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  229 + SHERPA_ONNX_LOGE("freed!");
  230 + delete reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  231 +}
  232 +
  233 +SHERPA_ONNX_EXTERN_C
  234 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
  235 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  236 + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  237 + model->Reset();
  238 +}
  239 +
  240 +SHERPA_ONNX_EXTERN_C
  241 +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
  242 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  243 + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  244 + return model->IsEndpoint();
  245 +}
  246 +
  247 +SHERPA_ONNX_EXTERN_C
  248 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
  249 + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
  250 + jfloat sample_rate) {
  251 + auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  252 +
  253 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  254 + jsize n = env->GetArrayLength(samples);
  255 +
  256 + model->DecodeSamples(sample_rate, p, n);
  257 +
  258 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  259 +}
  260 +
  261 +SHERPA_ONNX_EXTERN_C
  262 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished(
  263 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  264 + reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->InputFinished();
  265 +}
  266 +
  267 +SHERPA_ONNX_EXTERN_C
  268 +JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
  269 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  270 + // see
  271 + // https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
  272 + auto text = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetText();
  273 + return env->NewStringUTF(text.c_str());
  274 +}
  275 +
  276 +SHERPA_ONNX_EXTERN_C
  277 +JNIEXPORT jfloatArray JNICALL
  278 +Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
  279 + JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename,
  280 + jfloat expected_sample_rate) {
  281 + const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  282 +#if __ANDROID_API__ >= 9
  283 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  284 + if (!mgr) {
  285 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  286 + return nullptr;
  287 + }
  288 +
  289 + AAsset *asset = AAssetManager_open(mgr, p_filename, AASSET_MODE_BUFFER);
  290 + size_t asset_length = AAsset_getLength(asset);
  291 + std::vector<char> buffer(asset_length);
  292 + AAsset_read(asset, buffer.data(), asset_length);
  293 +
  294 + std::istrstream is(buffer.data(), asset_length);
  295 +#else
  296 + std::ifstream is(p_filename, std::ios::binary);
  297 +#endif
  298 +
  299 + bool is_ok = false;
  300 + std::vector<float> samples =
  301 + sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);
  302 +
  303 +#if __ANDROID_API__ >= 9
  304 + AAsset_close(asset);
  305 +#endif
  306 + env->ReleaseStringUTFChars(filename, p_filename);
  307 +
  308 + if (!is_ok) {
  309 + return nullptr;
  310 + }
  311 +
  312 + jfloatArray ans = env->NewFloatArray(samples.size());
  313 + env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data());
  314 + return ans;
  315 +}