Fangjun Kuang
Committed by GitHub

Add examples for Kotlin API (#124)

1 -package android.content.res  
2 -  
3 -// a dummy class for testing only  
4 -class AssetManager  
1 -../../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt  
1 -../../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt  
@@ -8,9 +8,9 @@ on: @@ -8,9 +8,9 @@ on:
8 - '.github/workflows/jni.yaml' 8 - '.github/workflows/jni.yaml'
9 - 'CMakeLists.txt' 9 - 'CMakeLists.txt'
10 - 'cmake/**' 10 - 'cmake/**'
  11 + - 'kotlin-api-examples/**'
11 - 'sherpa-onnx/csrc/*' 12 - 'sherpa-onnx/csrc/*'
12 - 'sherpa-onnx/jni/*' 13 - 'sherpa-onnx/jni/*'
13 - - '.github/scripts/test-jni.sh'  
14 pull_request: 14 pull_request:
15 branches: 15 branches:
16 - master 16 - master
@@ -18,9 +18,9 @@ on: @@ -18,9 +18,9 @@ on:
18 - '.github/workflows/jni.yaml' 18 - '.github/workflows/jni.yaml'
19 - 'CMakeLists.txt' 19 - 'CMakeLists.txt'
20 - 'cmake/**' 20 - 'cmake/**'
  21 + - 'kotlin-api-examples/**'
21 - 'sherpa-onnx/csrc/*' 22 - 'sherpa-onnx/csrc/*'
22 - 'sherpa-onnx/jni/*' 23 - 'sherpa-onnx/jni/*'
23 - - '.github/scripts/test-jni.sh'  
24 24
25 concurrency: 25 concurrency:
26 group: jni-${{ github.ref }} 26 group: jni-${{ github.ref }}
@@ -56,4 +56,5 @@ jobs: @@ -56,4 +56,5 @@ jobs:
56 - name: Run JNI test 56 - name: Run JNI test
57 shell: bash 57 shell: bash
58 run: | 58 run: |
59 - .github/scripts/test-jni.sh 59 + cd ./kotlin-api-examples
  60 + ./run.sh
@@ -55,3 +55,4 @@ sherpa-onnx-zipformer-en-2023-04-01 @@ -55,3 +55,4 @@ sherpa-onnx-zipformer-en-2023-04-01
55 run-offline-decode-files.sh 55 run-offline-decode-files.sh
56 sherpa-onnx-nemo-ctc-en-citrinet-512 56 sherpa-onnx-nemo-ctc-en-citrinet-512
57 run-offline-decode-files-nemo-ctc.sh 57 run-offline-decode-files-nemo-ctc.sh
  58 +*.jar
@@ -51,6 +51,11 @@ if(DEFINED ANDROID_ABI) @@ -51,6 +51,11 @@ if(DEFINED ANDROID_ABI)
51 set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE) 51 set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE)
52 endif() 52 endif()
53 53
  54 +if(SHERPA_ONNX_ENABLE_JNI AND NOT BUILD_SHARED_LIBS)
  55 + message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_JNI is ON")
  56 + set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
  57 +endif()
  58 +
54 message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") 59 message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
55 message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") 60 message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
56 message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}") 61 message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
@@ -38,19 +38,23 @@ data class OnlineRecognizerConfig( @@ -38,19 +38,23 @@ data class OnlineRecognizerConfig(
38 ) 38 )
39 39
40 class SherpaOnnx( 40 class SherpaOnnx(
41 - assetManager: AssetManager, var config: OnlineRecognizerConfig 41 + assetManager: AssetManager? = null,
  42 + var config: OnlineRecognizerConfig,
42 ) { 43 ) {
43 private val ptr: Long 44 private val ptr: Long
44 45
45 init { 46 init {
46 - ptr = new(assetManager, config) 47 + if (assetManager != null) {
  48 + ptr = new(assetManager, config)
  49 + } else {
  50 + ptr = newFromFile(config)
  51 + }
47 } 52 }
48 53
49 protected fun finalize() { 54 protected fun finalize() {
50 delete(ptr) 55 delete(ptr)
51 } 56 }
52 57
53 -  
54 fun acceptWaveform(samples: FloatArray, sampleRate: Int) = 58 fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
55 acceptWaveform(ptr, samples, sampleRate) 59 acceptWaveform(ptr, samples, sampleRate)
56 60
@@ -70,6 +74,10 @@ class SherpaOnnx( @@ -70,6 +74,10 @@ class SherpaOnnx(
70 config: OnlineRecognizerConfig, 74 config: OnlineRecognizerConfig,
71 ): Long 75 ): Long
72 76
  77 + private external fun newFromFile(
  78 + config: OnlineRecognizerConfig,
  79 + ): Long
  80 +
73 private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) 81 private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
74 private external fun inputFinished(ptr: Long) 82 private external fun inputFinished(ptr: Long)
75 private external fun getText(ptr: Long): String 83 private external fun getText(ptr: Long): String
@@ -86,7 +94,7 @@ class SherpaOnnx( @@ -86,7 +94,7 @@ class SherpaOnnx(
86 } 94 }
87 95
88 fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig { 96 fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
89 - return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim) 97 + return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
90 } 98 }
91 99
92 /* 100 /*
@@ -4,10 +4,21 @@ import android.content.res.AssetManager @@ -4,10 +4,21 @@ import android.content.res.AssetManager
4 4
5 class WaveReader { 5 class WaveReader {
6 companion object { 6 companion object {
7 - // Read a mono wave file.  
8 - // No resampling is made.  
9 - external fun readWave(  
10 - assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f 7 + // Read a mono wave file asset
  8 + // The returned array has two entries:
  10 + // - the first entry contains a 1-D float array
  10 + // - the second entry is the sample rate
  11 + external fun readWaveFromAsset(
  12 + assetManager: AssetManager,
  13 + filename: String,
  14 + ): Array<Any>
  15 +
  16 + // Read a mono wave file from disk
  17 + // The returned array has two entries:
  18 + // - the first entry contains a 1-D float array
  19 + // - the second entry is the sample rate
  20 + external fun readWaveFromFile(
  21 + filename: String,
11 ): Array<Any> 22 ): Array<Any>
12 23
13 init { 24 init {
@@ -8,6 +8,9 @@ fun main() { @@ -8,6 +8,9 @@ fun main() {
8 featureDim = 80, 8 featureDim = 80,
9 ) 9 )
10 10
  11 + // please refer to
  12 + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  13 + // to download pre-trained models
11 var modelConfig = OnlineTransducerModelConfig( 14 var modelConfig = OnlineTransducerModelConfig(
12 encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", 15 encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
13 decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", 16 decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
@@ -29,12 +32,10 @@ fun main() { @@ -29,12 +32,10 @@ fun main() {
29 ) 32 )
30 33
31 var model = SherpaOnnx( 34 var model = SherpaOnnx(
32 - assetManager = AssetManager(),  
33 config = config, 35 config = config,
34 ) 36 )
35 37
36 - var objArray = WaveReader.readWave(  
37 - assetManager = AssetManager(), 38 + var objArray = WaveReader.readWaveFromFile(
38 filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav", 39 filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav",
39 ) 40 )
40 var samples : FloatArray = objArray[0] as FloatArray 41 var samples : FloatArray = objArray[0] as FloatArray
@@ -45,8 +46,8 @@ fun main() { @@ -45,8 +46,8 @@ fun main() {
45 model.decode() 46 model.decode()
46 } 47 }
47 48
48 - var tail_paddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds  
49 - model.acceptWaveform(tail_paddings, sampleRate=sampleRate) 49 + var tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
  50 + model.acceptWaveform(tailPaddings, sampleRate=sampleRate)
50 model.inputFinished() 51 model.inputFinished()
51 while (model.isReady()) { 52 while (model.isReady()) {
52 model.decode() 53 model.decode()
  1 +../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt
  1 +../android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/WaveReader.kt
  1 +package android.content.res
  2 +
  3 +class AssetManager {}
1 #!/usr/bin/env bash 1 #!/usr/bin/env bash
  2 +#
  3 +# This script shows how to build JNI libs for sherpa-onnx
  4 +# Note: This script runs only on Linux and macOS, though sherpa-onnx
  5 +# supports building JNI libs for Windows.
2 6
3 set -e 7 set -e
4 8
  9 +cd ..
5 mkdir -p build 10 mkdir -p build
6 cd build 11 cd build
7 12
@@ -17,17 +22,17 @@ cmake \ @@ -17,17 +22,17 @@ cmake \
17 make -j4 22 make -j4
18 ls -lh lib 23 ls -lh lib
19 24
20 -cd ..  
21 -  
22 export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH 25 export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH
23 26
24 -cd .github/scripts/ 27 +cd ../kotlin-api-examples
25 28
26 -git lfs install  
27 -git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 29 +if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
  30 + git lfs install
  31 + git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
  32 +fi
28 33
29 -kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt AssetManager.kt 34 +kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt
30 35
31 ls -lh main.jar 36 ls -lh main.jar
32 37
33 -java -Djava.library.path=../../build/lib -jar main.jar 38 +java -Djava.library.path=../build/lib -jar main.jar
@@ -31,18 +31,13 @@ namespace sherpa_onnx { @@ -31,18 +31,13 @@ namespace sherpa_onnx {
31 31
32 class SherpaOnnx { 32 class SherpaOnnx {
33 public: 33 public:
34 - SherpaOnnx(  
35 #if __ANDROID_API__ >= 9 34 #if __ANDROID_API__ >= 9
36 - AAssetManager *mgr, 35 + SherpaOnnx(AAssetManager *mgr, const OnlineRecognizerConfig &config)
  36 + : recognizer_(mgr, config), stream_(recognizer_.CreateStream()) {}
37 #endif 37 #endif
38 - const sherpa_onnx::OnlineRecognizerConfig &config)  
39 - : recognizer_(  
40 -#if __ANDROID_API__ >= 9  
41 - mgr,  
42 -#endif  
43 - config),  
44 - stream_(recognizer_.CreateStream()) {  
45 - } 38 +
  39 + explicit SherpaOnnx(const OnlineRecognizerConfig &config)
  40 + : recognizer_(config), stream_(recognizer_.CreateStream()) {}
46 41
47 void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) { 42 void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
48 if (input_sample_rate_ == -1) { 43 if (input_sample_rate_ == -1) {
@@ -73,8 +68,8 @@ class SherpaOnnx { @@ -73,8 +68,8 @@ class SherpaOnnx {
73 void Decode() const { recognizer_.DecodeStream(stream_.get()); } 68 void Decode() const { recognizer_.DecodeStream(stream_.get()); }
74 69
75 private: 70 private:
76 - sherpa_onnx::OnlineRecognizer recognizer_;  
77 - std::unique_ptr<sherpa_onnx::OnlineStream> stream_; 71 + OnlineRecognizer recognizer_;
  72 + std::unique_ptr<OnlineStream> stream_;
78 int32_t input_sample_rate_ = -1; 73 int32_t input_sample_rate_ = -1;
79 }; 74 };
80 75
@@ -219,6 +214,16 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new( @@ -219,6 +214,16 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
219 } 214 }
220 215
221 SHERPA_ONNX_EXTERN_C 216 SHERPA_ONNX_EXTERN_C
  217 +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_newFromFile(
  218 + JNIEnv *env, jobject /*obj*/, jobject _config) {
  219 + auto config = sherpa_onnx::GetConfig(env, _config);
  220 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  221 + auto model = new sherpa_onnx::SherpaOnnx(config);
  222 +
  223 + return (jlong)model;
  224 +}
  225 +
  226 +SHERPA_ONNX_EXTERN_C
222 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete( 227 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
223 JNIEnv *env, jobject /*obj*/, jlong ptr) { 228 JNIEnv *env, jobject /*obj*/, jlong ptr) {
224 delete reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr); 229 delete reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
@@ -289,9 +294,47 @@ static jobject NewInteger(JNIEnv *env, int32_t value) { @@ -289,9 +294,47 @@ static jobject NewInteger(JNIEnv *env, int32_t value) {
289 return env->NewObject(cls, constructor, value); 294 return env->NewObject(cls, constructor, value);
290 } 295 }
291 296
  297 +static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is,
  298 + const char *p_filename) {
  299 + bool is_ok = false;
  300 + int32_t sampling_rate = -1;
  301 + std::vector<float> samples =
  302 + sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok);
  303 +
  304 + if (!is_ok) {
  305 + SHERPA_ONNX_LOGE("Failed to read %s", p_filename);
  306 + exit(-1);
  307 + }
  308 +
  309 + jfloatArray samples_arr = env->NewFloatArray(samples.size());
  310 + env->SetFloatArrayRegion(samples_arr, 0, samples.size(), samples.data());
  311 +
  312 + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
  313 + 2, env->FindClass("java/lang/Object"), nullptr);
  314 +
  315 + env->SetObjectArrayElement(obj_arr, 0, samples_arr);
  316 + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate));
  317 +
  318 + return obj_arr;
  319 +}
  320 +
  321 +SHERPA_ONNX_EXTERN_C
  322 +JNIEXPORT jobjectArray JNICALL
  323 +Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromFile(
  324 + JNIEnv *env, jclass /*cls*/, jstring filename) {
  325 + const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  326 + std::ifstream is(p_filename, std::ios::binary);
  327 +
  328 + auto obj_arr = ReadWaveImpl(env, is, p_filename);
  329 +
  330 + env->ReleaseStringUTFChars(filename, p_filename);
  331 +
  332 + return obj_arr;
  333 +}
  334 +
292 SHERPA_ONNX_EXTERN_C 335 SHERPA_ONNX_EXTERN_C
293 JNIEXPORT jobjectArray JNICALL 336 JNIEXPORT jobjectArray JNICALL
294 -Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( 337 +Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset(
295 JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) { 338 JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) {
296 const char *p_filename = env->GetStringUTFChars(filename, nullptr); 339 const char *p_filename = env->GetStringUTFChars(filename, nullptr);
297 #if __ANDROID_API__ >= 9 340 #if __ANDROID_API__ >= 9
@@ -308,27 +351,10 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( @@ -308,27 +351,10 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
308 std::ifstream is(p_filename, std::ios::binary); 351 std::ifstream is(p_filename, std::ios::binary);
309 #endif 352 #endif
310 353
311 - bool is_ok = false;  
312 - int32_t sampling_rate = -1;  
313 - std::vector<float> samples =  
314 - sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok); 354 + auto obj_arr = ReadWaveImpl(env, is, p_filename);
315 355
316 env->ReleaseStringUTFChars(filename, p_filename); 356 env->ReleaseStringUTFChars(filename, p_filename);
317 357
318 - if (!is_ok) {  
319 - SHERPA_ONNX_LOGE("Failed to read %s", p_filename);  
320 - exit(-1);  
321 - }  
322 -  
323 - jfloatArray ans = env->NewFloatArray(samples.size());  
324 - env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data());  
325 -  
326 - jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(  
327 - 2, env->FindClass("java/lang/Object"), nullptr);  
328 -  
329 - env->SetObjectArrayElement(obj_arr, 0, ans);  
330 - env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate));  
331 -  
332 return obj_arr; 358 return obj_arr;
333 } 359 }
334 360
@@ -340,8 +366,9 @@ JNIEXPORT jobjectArray JNICALL @@ -340,8 +366,9 @@ JNIEXPORT jobjectArray JNICALL
340 Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_readWave(JNIEnv *env, 366 Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_readWave(JNIEnv *env,
341 jclass /*cls*/, 367 jclass /*cls*/,
342 jstring filename) { 368 jstring filename) {
343 - auto data = Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(  
344 - env, nullptr, nullptr, filename); 369 + auto data =
  370 + Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWaveFromAsset(
  371 + env, nullptr, nullptr, filename);
345 return data; 372 return data;
346 } 373 }
347 374