Fangjun Kuang
Committed by GitHub

Add CTC HLG decoding for JNI (#810)

@@ -37,7 +37,7 @@ jobs: @@ -37,7 +37,7 @@ jobs:
37 strategy: 37 strategy:
38 fail-fast: false 38 fail-fast: false
39 matrix: 39 matrix:
40 - os: [ubuntu-latest, macos-latest] 40 + os: [ubuntu-latest, macos-latest, macos-14]
41 41
42 steps: 42 steps:
43 - uses: actions/checkout@v4 43 - uses: actions/checkout@v4
@@ -49,6 +49,11 @@ jobs: @@ -49,6 +49,11 @@ jobs:
49 with: 49 with:
50 key: ${{ matrix.os }} 50 key: ${{ matrix.os }}
51 51
  52 + - name: OS info
  53 + shell: bash
  54 + run: |
  55 + uname -a
  56 +
52 - name: Display kotlin version 57 - name: Display kotlin version
53 shell: bash 58 shell: bash
54 run: | 59 run: |
@@ -58,6 +63,7 @@ jobs: @@ -58,6 +63,7 @@ jobs:
58 shell: bash 63 shell: bash
59 run: | 64 run: |
60 java -version 65 java -version
  66 + javac -help
61 echo "JAVA_HOME is: ${JAVA_HOME}" 67 echo "JAVA_HOME is: ${JAVA_HOME}"
62 68
63 - name: Run JNI test 69 - name: Run JNI test
@@ -38,7 +38,7 @@ jobs: @@ -38,7 +38,7 @@ jobs:
38 strategy: 38 strategy:
39 fail-fast: false 39 fail-fast: false
40 matrix: 40 matrix:
41 - os: [ubuntu-latest] 41 + os: [ubuntu-latest, macos-latest, macos-14]
42 42
43 steps: 43 steps:
44 - uses: actions/checkout@v4 44 - uses: actions/checkout@v4
@@ -50,10 +50,24 @@ jobs: @@ -50,10 +50,24 @@ jobs:
50 with: 50 with:
51 key: ${{ matrix.os }}-java 51 key: ${{ matrix.os }}-java
52 52
  53 + - name: OS info
  54 + shell: bash
  55 + run: |
  56 + uname -a
  57 +
  58 + - uses: actions/setup-java@v4
  59 + with:
  60 + distribution: 'temurin' # See 'Supported distributions' for available options
  61 + java-version: '21'
  62 +
53 - name: Display java version 63 - name: Display java version
54 shell: bash 64 shell: bash
55 run: | 65 run: |
56 java -version 66 java -version
  67 + java -help
  68 + echo "----"
  69 + javac -version
  70 + javac -help
57 echo "JAVA_HOME is: ${JAVA_HOME}" 71 echo "JAVA_HOME is: ${JAVA_HOME}"
58 72
59 cmake --version 73 cmake --version
@@ -100,6 +114,9 @@ jobs: @@ -100,6 +114,9 @@ jobs:
100 # Delete model files to save space 114 # Delete model files to save space
101 rm -rf sherpa-onnx-streaming-* 115 rm -rf sherpa-onnx-streaming-*
102 116
  117 + ./run-streaming-decode-file-ctc-hlg.sh
  118 + rm -rf sherpa-onnx-streaming-*
  119 +
103 ./run-streaming-decode-file-paraformer.sh 120 ./run-streaming-decode-file-paraformer.sh
104 rm -rf sherpa-onnx-streaming-* 121 rm -rf sherpa-onnx-streaming-*
105 122
@@ -118,3 +135,6 @@ jobs: @@ -118,3 +135,6 @@ jobs:
118 135
119 ./run-non-streaming-decode-file-whisper.sh 136 ./run-non-streaming-decode-file-whisper.sh
120 rm -rf sherpa-onnx-whisper-* 137 rm -rf sherpa-onnx-whisper-*
  138 +
  139 + ./run-non-streaming-decode-file-nemo.sh
  140 + rm -rf sherpa-onnx-nemo-*
@@ -8,10 +8,6 @@ @@ -8,10 +8,6 @@
8 <Nullable>enable</Nullable> 8 <Nullable>enable</Nullable>
9 </PropertyGroup> 9 </PropertyGroup>
10 10
11 - <PropertyGroup>  
12 - <RestoreSources>/tmp/packages;$(RestoreSources);https://api.nuget.org/v3/index.json</RestoreSources>  
13 - </PropertyGroup>  
14 -  
15 <ItemGroup> 11 <ItemGroup>
16 <PackageReference Include="CommandLineParser" Version="2.9.1" /> 12 <PackageReference Include="CommandLineParser" Version="2.9.1" />
17 <PackageReference Include="org.k2fsa.sherpa.onnx" Version="*" /> 13 <PackageReference Include="org.k2fsa.sherpa.onnx" Version="*" />
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +// This file shows how to use an offline NeMo CTC model, i.e., a non-streaming NeMo CTC model,
  4 +// to decode files.
  5 +import com.k2fsa.sherpa.onnx.*;
  6 +
  7 +public class NonStreamingDecodeFileNemo {
  8 + public static void main(String[] args) {
  9 + // please refer to
  10 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
  11 + // to download model files
  12 + String model = "./sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx";
  13 + String tokens = "./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt";
  14 +
  15 + String waveFilename = "./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav";
  16 +
  17 + WaveReader reader = new WaveReader(waveFilename);
  18 +
  19 + OfflineNemoEncDecCtcModelConfig nemo =
  20 + OfflineNemoEncDecCtcModelConfig.builder().setModel(model).build();
  21 +
  22 + OfflineModelConfig modelConfig =
  23 + OfflineModelConfig.builder()
  24 + .setNemo(nemo)
  25 + .setTokens(tokens)
  26 + .setNumThreads(1)
  27 + .setDebug(true)
  28 + .build();
  29 +
  30 + OfflineRecognizerConfig config =
  31 + OfflineRecognizerConfig.builder()
  32 + .setOfflineModelConfig(modelConfig)
  33 + .setDecodingMethod("greedy_search")
  34 + .build();
  35 +
  36 + OfflineRecognizer recognizer = new OfflineRecognizer(config);
  37 + OfflineStream stream = recognizer.createStream();
  38 + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate());
  39 +
  40 + recognizer.decode(stream);
  41 +
  42 + String text = recognizer.getResult(stream).getText();
  43 +
  44 + System.out.printf("filename:%s\nresult:%s\n", waveFilename, text);
  45 +
  46 + stream.release();
  47 + recognizer.release();
  48 + }
  49 +}
@@ -8,6 +8,7 @@ This directory contains examples for the JAVA API of sherpa-onnx. @@ -8,6 +8,7 @@ This directory contains examples for the JAVA API of sherpa-onnx.
8 8
9 ``` 9 ```
10 ./run-streaming-decode-file-ctc.sh 10 ./run-streaming-decode-file-ctc.sh
  11 +./run-streaming-decode-file-ctc-hlg.sh
11 ./run-streaming-decode-file-paraformer.sh 12 ./run-streaming-decode-file-paraformer.sh
12 ./run-streaming-decode-file-transducer.sh 13 ./run-streaming-decode-file-transducer.sh
13 ``` 14 ```
@@ -18,4 +19,5 @@ This directory contains examples for the JAVA API of sherpa-onnx. @@ -18,4 +19,5 @@ This directory contains examples for the JAVA API of sherpa-onnx.
18 ./run-non-streaming-decode-file-paraformer.sh 19 ./run-non-streaming-decode-file-paraformer.sh
19 ./run-non-streaming-decode-file-transducer.sh 20 ./run-non-streaming-decode-file-transducer.sh
20 ./run-non-streaming-decode-file-whisper.sh 21 ./run-non-streaming-decode-file-whisper.sh
  22 +./run-non-streaming-decode-file-nemo.sh
21 ``` 23 ```
  1 +// Copyright 2024 Xiaomi Corporation
  2 +
  3 +// This file shows how to use an online CTC model, i.e., a streaming CTC model,
  4 +// together with an HLG decoding graph to decode files.
  5 +import com.k2fsa.sherpa.onnx.*;
  6 +
  7 +public class StreamingDecodeFileCtcHLG {
  8 + public static void main(String[] args) {
  9 + // please refer to
  10 + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  11 + // to download model files
  12 + String model =
  13 + "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx";
  14 + String tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt";
  15 + String hlg = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst";
  16 + String waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/8k.wav";
  17 +
  18 + WaveReader reader = new WaveReader(waveFilename);
  19 +
  20 + OnlineZipformer2CtcModelConfig ctc =
  21 + OnlineZipformer2CtcModelConfig.builder().setModel(model).build();
  22 +
  23 + OnlineModelConfig modelConfig =
  24 + OnlineModelConfig.builder()
  25 + .setZipformer2Ctc(ctc)
  26 + .setTokens(tokens)
  27 + .setNumThreads(1)
  28 + .setDebug(true)
  29 + .build();
  30 +
  31 + OnlineCtcFstDecoderConfig ctcFstDecoderConfig =
  32 + OnlineCtcFstDecoderConfig.builder().setGraph("hlg").build();
  33 +
  34 + OnlineRecognizerConfig config =
  35 + OnlineRecognizerConfig.builder()
  36 + .setOnlineModelConfig(modelConfig)
  37 + .setCtcFstDecoderConfig(ctcFstDecoderConfig)
  38 + .build();
  39 +
  40 + OnlineRecognizer recognizer = new OnlineRecognizer(config);
  41 + OnlineStream stream = recognizer.createStream();
  42 + stream.acceptWaveform(reader.getSamples(), reader.getSampleRate());
  43 +
  44 + float[] tailPaddings = new float[(int) (0.3 * reader.getSampleRate())];
  45 + stream.acceptWaveform(tailPaddings, reader.getSampleRate());
  46 +
  47 + while (recognizer.isReady(stream)) {
  48 + recognizer.decode(stream);
  49 + }
  50 +
  51 + String text = recognizer.getResult(stream).getText();
  52 +
  53 + System.out.printf("filename:%s\nresult:%s\n", waveFilename, text);
  54 +
  55 + stream.release();
  56 + recognizer.release();
  57 + }
  58 +}
#!/usr/bin/env bash
#
# Runs NonStreamingDecodeFileNemo.java.
#
# Steps, each skipped when its output already exists:
#   1. build the JNI shared library under ../build
#   2. build the Java API jar under ../sherpa-onnx/java-api/build
#   3. download and unpack the NeMo CTC citrinet-512 model into the current directory

set -ex

# Build the JNI library only when neither the .dylib nor the .so is present.
# (An identical second copy of this block, which lacked the pushd into ../build
# and therefore pointed cmake at the wrong source directory, has been removed.)
if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
  mkdir -p ../build
  pushd ../build
  cmake \
    -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
    -DSHERPA_ONNX_ENABLE_TESTS=OFF \
    -DSHERPA_ONNX_ENABLE_CHECK=OFF \
    -DBUILD_SHARED_LIBS=ON \
    -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
    -DSHERPA_ONNX_ENABLE_JNI=ON \
    ..

  make -j4
  ls -lh lib
  popd
fi

# Build the Java API jar if it has not been built yet.
if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
  pushd ../sherpa-onnx/java-api
  make
  popd
fi

# Fetch the model on first use; the archive is deleted after extraction.
if [ ! -f ./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
  tar xvf sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
  rm sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
fi

# $PWD is quoted so that a checkout path containing spaces is not word-split.
java \
  -Djava.library.path="$PWD/../build/lib" \
  -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
  NonStreamingDecodeFileNemo.java
#!/usr/bin/env bash
#
# Runs StreamingDecodeFileCtcHLG.java.
#
# Steps, each skipped when its output already exists:
#   1. build the JNI shared library under ../build
#   2. build the Java API jar under ../sherpa-onnx/java-api/build
#   3. download and unpack the streaming zipformer CTC model into the current directory

set -ex

# Build the JNI library only when neither the .dylib nor the .so is present.
if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then
  mkdir -p ../build
  pushd ../build
  cmake \
    -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
    -DSHERPA_ONNX_ENABLE_TESTS=OFF \
    -DSHERPA_ONNX_ENABLE_CHECK=OFF \
    -DBUILD_SHARED_LIBS=ON \
    -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
    -DSHERPA_ONNX_ENABLE_JNI=ON \
    ..

  make -j4
  ls -lh lib
  popd
fi

# Build the Java API jar if it has not been built yet.
if [ ! -f ../sherpa-onnx/java-api/build/sherpa-onnx.jar ]; then
  pushd ../sherpa-onnx/java-api
  make
  popd
fi

# Fetch the model on first use; the archive is deleted after extraction.
if [ ! -f ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
fi

# $PWD is quoted so that a checkout path containing spaces is not word-split.
java \
  -Djava.library.path="$PWD/../build/lib" \
  -cp ../sherpa-onnx/java-api/build/sherpa-onnx.jar \
  StreamingDecodeFileCtcHLG.java
@@ -69,6 +69,12 @@ function testOnlineAsr() { @@ -69,6 +69,12 @@ function testOnlineAsr() {
69 rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 69 rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
70 fi 70 fi
71 71
  72 + if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 ]; then
  73 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  74 + tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  75 + rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  76 + fi
  77 +
72 out_filename=test_online_asr.jar 78 out_filename=test_online_asr.jar
73 kotlinc-jvm -include-runtime -d $out_filename \ 79 kotlinc-jvm -include-runtime -d $out_filename \
74 test_online_asr.kt \ 80 test_online_asr.kt \
@@ -160,6 +166,24 @@ function testOfflineAsr() { @@ -160,6 +166,24 @@ function testOfflineAsr() {
160 rm sherpa-onnx-whisper-tiny.en.tar.bz2 166 rm sherpa-onnx-whisper-tiny.en.tar.bz2
161 fi 167 fi
162 168
  169 + if [ ! -f ./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt ]; then
  170 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
  171 + tar xvf sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
  172 + rm sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2
  173 + fi
  174 +
  175 + if [ ! -f ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt ]; then
  176 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
  177 + tar xvf sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
  178 + rm sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
  179 + fi
  180 +
  181 + if [ ! -f ./sherpa-onnx-zipformer-multi-zh-hans-2023-9-2/tokens.txt ]; then
  182 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2
  183 + tar xvf sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2
  184 + rm sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2
  185 + fi
  186 +
163 out_filename=test_offline_asr.jar 187 out_filename=test_offline_asr.jar
164 kotlinc-jvm -include-runtime -d $out_filename \ 188 kotlinc-jvm -include-runtime -d $out_filename \
165 test_offline_asr.kt \ 189 test_offline_asr.kt \
1 package com.k2fsa.sherpa.onnx 1 package com.k2fsa.sherpa.onnx
2 2
3 fun main() { 3 fun main() {
4 - val recognizer = createOfflineRecognizer() 4 + val types = arrayOf(0, 2, 5, 6)
  5 + for (type in types) {
  6 + test(type)
  7 + }
  8 +}
  9 +
  10 +fun test(type: Int) {
  11 + val recognizer = createOfflineRecognizer(type)
5 12
6 - val waveFilename = "./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav" 13 + val waveFilename = when (type) {
  14 + 0 -> "./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav"
  15 + 2 -> "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"
  16 + 5 -> "./sherpa-onnx-zipformer-multi-zh-hans-2023-9-2/test_wavs/1.wav"
  17 + 6 -> "./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav"
  18 + else -> null
  19 + }
7 20
8 val objArray = WaveReader.readWaveFromFile( 21 val objArray = WaveReader.readWaveFromFile(
9 - filename = waveFilename, 22 + filename = waveFilename!!,
10 ) 23 )
11 val samples: FloatArray = objArray[0] as FloatArray 24 val samples: FloatArray = objArray[0] as FloatArray
12 val sampleRate: Int = objArray[1] as Int 25 val sampleRate: Int = objArray[1] as Int
@@ -22,10 +35,10 @@ fun main() { @@ -22,10 +35,10 @@ fun main() {
22 recognizer.release() 35 recognizer.release()
23 } 36 }
24 37
25 -fun createOfflineRecognizer(): OfflineRecognizer { 38 +fun createOfflineRecognizer(type: Int): OfflineRecognizer {
26 val config = OfflineRecognizerConfig( 39 val config = OfflineRecognizerConfig(
27 featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80), 40 featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
28 - modelConfig = getOfflineModelConfig(type = 2)!!, 41 + modelConfig = getOfflineModelConfig(type = type)!!,
29 ) 42 )
30 43
31 return OfflineRecognizer(config = config) 44 return OfflineRecognizer(config = config)
@@ -3,6 +3,7 @@ package com.k2fsa.sherpa.onnx @@ -3,6 +3,7 @@ package com.k2fsa.sherpa.onnx
3 fun main() { 3 fun main() {
4 testOnlineAsr("transducer") 4 testOnlineAsr("transducer")
5 testOnlineAsr("zipformer2-ctc") 5 testOnlineAsr("zipformer2-ctc")
  6 + testOnlineAsr("ctc-hlg")
6 } 7 }
7 8
8 fun testOnlineAsr(type: String) { 9 fun testOnlineAsr(type: String) {
@@ -11,6 +12,7 @@ fun testOnlineAsr(type: String) { @@ -11,6 +12,7 @@ fun testOnlineAsr(type: String) {
11 featureDim = 80, 12 featureDim = 80,
12 ) 13 )
13 14
  15 + var ctcFstDecoderConfig = OnlineCtcFstDecoderConfig()
14 val waveFilename: String 16 val waveFilename: String
15 val modelConfig: OnlineModelConfig = when (type) { 17 val modelConfig: OnlineModelConfig = when (type) {
16 "transducer" -> { 18 "transducer" -> {
@@ -40,6 +42,18 @@ fun testOnlineAsr(type: String) { @@ -40,6 +42,18 @@ fun testOnlineAsr(type: String) {
40 debug = false, 42 debug = false,
41 ) 43 )
42 } 44 }
  45 + "ctc-hlg" -> {
  46 + waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/1.wav"
  47 + ctcFstDecoderConfig.graph = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst"
  48 + OnlineModelConfig(
  49 + zipformer2Ctc = OnlineZipformer2CtcModelConfig(
  50 + model = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx",
  51 + ),
  52 + tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt",
  53 + numThreads = 1,
  54 + debug = false,
  55 + )
  56 + }
43 else -> throw IllegalArgumentException(type) 57 else -> throw IllegalArgumentException(type)
44 } 58 }
45 59
@@ -51,6 +65,7 @@ fun testOnlineAsr(type: String) { @@ -51,6 +65,7 @@ fun testOnlineAsr(type: String) {
51 modelConfig = modelConfig, 65 modelConfig = modelConfig,
52 lmConfig = lmConfig, 66 lmConfig = lmConfig,
53 featConfig = featConfig, 67 featConfig = featConfig,
  68 + ctcFstDecoderConfig=ctcFstDecoderConfig,
54 endpointConfig = endpointConfig, 69 endpointConfig = endpointConfig,
55 enableEndpoint = true, 70 enableEndpoint = true,
56 decodingMethod = "greedy_search", 71 decodingMethod = "greedy_search",
@@ -14,6 +14,7 @@ java_files += OnlineParaformerModelConfig.java @@ -14,6 +14,7 @@ java_files += OnlineParaformerModelConfig.java
14 java_files += OnlineZipformer2CtcModelConfig.java 14 java_files += OnlineZipformer2CtcModelConfig.java
15 java_files += OnlineTransducerModelConfig.java 15 java_files += OnlineTransducerModelConfig.java
16 java_files += OnlineModelConfig.java 16 java_files += OnlineModelConfig.java
  17 +java_files += OnlineCtcFstDecoderConfig.java
17 java_files += OnlineStream.java 18 java_files += OnlineStream.java
18 java_files += OnlineRecognizerConfig.java 19 java_files += OnlineRecognizerConfig.java
19 java_files += OnlineRecognizerResult.java 20 java_files += OnlineRecognizerResult.java
@@ -22,6 +23,7 @@ java_files += OnlineRecognizer.java @@ -22,6 +23,7 @@ java_files += OnlineRecognizer.java
22 java_files += OfflineTransducerModelConfig.java 23 java_files += OfflineTransducerModelConfig.java
23 java_files += OfflineParaformerModelConfig.java 24 java_files += OfflineParaformerModelConfig.java
24 java_files += OfflineWhisperModelConfig.java 25 java_files += OfflineWhisperModelConfig.java
  26 +java_files += OfflineNemoEncDecCtcModelConfig.java
25 java_files += OfflineModelConfig.java 27 java_files += OfflineModelConfig.java
26 java_files += OfflineRecognizerConfig.java 28 java_files += OfflineRecognizerConfig.java
27 java_files += OfflineRecognizerResult.java 29 java_files += OfflineRecognizerResult.java
@@ -42,10 +44,12 @@ $(info -- class files $(class_files)) @@ -42,10 +44,12 @@ $(info -- class files $(class_files))
42 all: $(out_jar) 44 all: $(out_jar)
43 45
44 $(out_jar): $(class_files) 46 $(out_jar): $(class_files)
45 - jar --create --verbose --file $(out_jar) -C $(out_dir) . 47 + # jar --create --verbose --file $(out_jar) -C $(out_dir) ./
  48 + jar cvf $(out_jar) -C $(out_dir) ./
46 49
47 clean: 50 clean:
48 $(RM) -rfv $(out_dir) 51 $(RM) -rfv $(out_dir)
49 52
50 $(class_files): $(out_dir)/$(package_dir)/%.class: src/$(package_dir)/%.java 53 $(class_files): $(out_dir)/$(package_dir)/%.class: src/$(package_dir)/%.java
51 - javac -d $(out_dir) --class-path $(out_dir) $< 54 + mkdir -p build
  55 + javac -d $(out_dir) -cp $(out_dir) $<
@@ -5,6 +5,7 @@ public class OfflineModelConfig { @@ -5,6 +5,7 @@ public class OfflineModelConfig {
5 private final OfflineTransducerModelConfig transducer; 5 private final OfflineTransducerModelConfig transducer;
6 private final OfflineParaformerModelConfig paraformer; 6 private final OfflineParaformerModelConfig paraformer;
7 private final OfflineWhisperModelConfig whisper; 7 private final OfflineWhisperModelConfig whisper;
  8 + private final OfflineNemoEncDecCtcModelConfig nemo;
8 private final String tokens; 9 private final String tokens;
9 private final int numThreads; 10 private final int numThreads;
10 private final boolean debug; 11 private final boolean debug;
@@ -16,6 +17,7 @@ public class OfflineModelConfig { @@ -16,6 +17,7 @@ public class OfflineModelConfig {
16 this.transducer = builder.transducer; 17 this.transducer = builder.transducer;
17 this.paraformer = builder.paraformer; 18 this.paraformer = builder.paraformer;
18 this.whisper = builder.whisper; 19 this.whisper = builder.whisper;
  20 + this.nemo = builder.nemo;
19 this.tokens = builder.tokens; 21 this.tokens = builder.tokens;
20 this.numThreads = builder.numThreads; 22 this.numThreads = builder.numThreads;
21 this.debug = builder.debug; 23 this.debug = builder.debug;
@@ -64,6 +66,7 @@ public class OfflineModelConfig { @@ -64,6 +66,7 @@ public class OfflineModelConfig {
64 private OfflineParaformerModelConfig paraformer = OfflineParaformerModelConfig.builder().build(); 66 private OfflineParaformerModelConfig paraformer = OfflineParaformerModelConfig.builder().build();
65 private OfflineTransducerModelConfig transducer = OfflineTransducerModelConfig.builder().build(); 67 private OfflineTransducerModelConfig transducer = OfflineTransducerModelConfig.builder().build();
66 private OfflineWhisperModelConfig whisper = OfflineWhisperModelConfig.builder().build(); 68 private OfflineWhisperModelConfig whisper = OfflineWhisperModelConfig.builder().build();
  69 + private OfflineNemoEncDecCtcModelConfig nemo = OfflineNemoEncDecCtcModelConfig.builder().build();
67 private String tokens = ""; 70 private String tokens = "";
68 private int numThreads = 1; 71 private int numThreads = 1;
69 private boolean debug = true; 72 private boolean debug = true;
@@ -84,6 +87,11 @@ public class OfflineModelConfig { @@ -84,6 +87,11 @@ public class OfflineModelConfig {
84 return this; 87 return this;
85 } 88 }
86 89
  90 + public Builder setNemo(OfflineNemoEncDecCtcModelConfig nemo) {
  91 + this.nemo = nemo;
  92 + return this;
  93 + }
  94 +
87 public Builder setWhisper(OfflineWhisperModelConfig whisper) { 95 public Builder setWhisper(OfflineWhisperModelConfig whisper) {
88 this.whisper = whisper; 96 this.whisper = whisper;
89 return this; 97 return this;
  1 +// Copyright 2024 Xiaomi Corporation
  2 +package com.k2fsa.sherpa.onnx;
  3 +
/**
 * Immutable configuration for a non-streaming NeMo EncDecCTC model.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class OfflineNemoEncDecCtcModelConfig {
  /** Builder for {@link OfflineNemoEncDecCtcModelConfig}; the model defaults to "". */
  public static class Builder {
    private String model = "";

    public Builder setModel(String model) {
      this.model = model;
      return this;
    }

    public OfflineNemoEncDecCtcModelConfig build() {
      return new OfflineNemoEncDecCtcModelConfig(this);
    }
  }

  private final String model;

  /** Copies the value collected by the builder. */
  private OfflineNemoEncDecCtcModelConfig(Builder source) {
    model = source.model;
  }

  /** Returns a new builder with default values. */
  public static Builder builder() {
    return new Builder();
  }

  /** Returns the model value given to the builder, or "" if none was set. */
  public String getModel() {
    return model;
  }
}
  1 +// Copyright 2024 Xiaomi Corporation
  2 +package com.k2fsa.sherpa.onnx;
  3 +
/**
 * Immutable configuration for the online CTC FST decoder.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class OnlineCtcFstDecoderConfig {
  private final String graph;
  private final int maxActive;

  /** Copies the values collected by the builder. */
  private OnlineCtcFstDecoderConfig(Builder builder) {
    this.graph = builder.graph;
    this.maxActive = builder.maxActive;
  }

  /** Returns a new builder with default values. */
  public static Builder builder() {
    return new Builder();
  }

  /** Returns the graph value given to the builder, or "" if none was set. */
  public String getGraph() {
    return graph;
  }

  /**
   * Returns the configured maxActive value.
   *
   * <p>NOTE(review): the backing field is an {@code int}; the {@code float} return type is kept
   * so that the existing method signature does not change.
   */
  public float getMaxActive() {
    return maxActive;
  }

  /** Builder for {@link OnlineCtcFstDecoderConfig}; defaults are graph "" and maxActive 3000. */
  public static class Builder {
    private String graph = "";
    private int maxActive = 3000;

    public OnlineCtcFstDecoderConfig build() {
      return new OnlineCtcFstDecoderConfig(this);
    }

    /**
     * Sets the decoding graph.
     *
     * <p>The parameter used to be named {@code model}, which turned {@code this.graph = graph}
     * into a self-assignment that silently discarded the argument.
     */
    public Builder setGraph(String graph) {
      this.graph = graph;
      return this;
    }

    public Builder setMaxActive(int maxActive) {
      this.maxActive = maxActive;
      return this;
    }
  }
}
@@ -6,6 +6,8 @@ public class OnlineRecognizerConfig { @@ -6,6 +6,8 @@ public class OnlineRecognizerConfig {
6 private final FeatureConfig featConfig; 6 private final FeatureConfig featConfig;
7 private final OnlineModelConfig modelConfig; 7 private final OnlineModelConfig modelConfig;
8 private final OnlineLMConfig lmConfig; 8 private final OnlineLMConfig lmConfig;
  9 +
  10 + private final OnlineCtcFstDecoderConfig ctcFstDecoderConfig;
9 private final EndpointConfig endpointConfig; 11 private final EndpointConfig endpointConfig;
10 private final boolean enableEndpoint; 12 private final boolean enableEndpoint;
11 private final String decodingMethod; 13 private final String decodingMethod;
@@ -17,6 +19,7 @@ public class OnlineRecognizerConfig { @@ -17,6 +19,7 @@ public class OnlineRecognizerConfig {
17 this.featConfig = builder.featConfig; 19 this.featConfig = builder.featConfig;
18 this.modelConfig = builder.modelConfig; 20 this.modelConfig = builder.modelConfig;
19 this.lmConfig = builder.lmConfig; 21 this.lmConfig = builder.lmConfig;
  22 + this.ctcFstDecoderConfig = builder.ctcFstDecoderConfig;
20 this.endpointConfig = builder.endpointConfig; 23 this.endpointConfig = builder.endpointConfig;
21 this.enableEndpoint = builder.enableEndpoint; 24 this.enableEndpoint = builder.enableEndpoint;
22 this.decodingMethod = builder.decodingMethod; 25 this.decodingMethod = builder.decodingMethod;
@@ -37,6 +40,7 @@ public class OnlineRecognizerConfig { @@ -37,6 +40,7 @@ public class OnlineRecognizerConfig {
37 private FeatureConfig featConfig = FeatureConfig.builder().build(); 40 private FeatureConfig featConfig = FeatureConfig.builder().build();
38 private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build(); 41 private OnlineModelConfig modelConfig = OnlineModelConfig.builder().build();
39 private OnlineLMConfig lmConfig = OnlineLMConfig.builder().build(); 42 private OnlineLMConfig lmConfig = OnlineLMConfig.builder().build();
  43 + private OnlineCtcFstDecoderConfig ctcFstDecoderConfig = OnlineCtcFstDecoderConfig.builder().build();
40 private EndpointConfig endpointConfig = EndpointConfig.builder().build(); 44 private EndpointConfig endpointConfig = EndpointConfig.builder().build();
41 private boolean enableEndpoint = true; 45 private boolean enableEndpoint = true;
42 private String decodingMethod = "greedy_search"; 46 private String decodingMethod = "greedy_search";
@@ -63,6 +67,11 @@ public class OnlineRecognizerConfig { @@ -63,6 +67,11 @@ public class OnlineRecognizerConfig {
63 return this; 67 return this;
64 } 68 }
65 69
  70 + public Builder setCtcFstDecoderConfig(OnlineCtcFstDecoderConfig ctcFstDecoderConfig) {
  71 + this.ctcFstDecoderConfig = ctcFstDecoderConfig;
  72 + return this;
  73 + }
  74 +
66 public Builder setEndpointConfig(EndpointConfig endpointConfig) { 75 public Builder setEndpointConfig(EndpointConfig endpointConfig) {
67 this.endpointConfig = endpointConfig; 76 this.endpointConfig = endpointConfig;
68 return this; 77 return this;
@@ -147,6 +147,19 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { @@ -147,6 +147,19 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
147 ans.model_config.whisper.tail_paddings = 147 ans.model_config.whisper.tail_paddings =
148 env->GetIntField(whisper_config, fid); 148 env->GetIntField(whisper_config, fid);
149 149
  150 + fid = env->GetFieldID(
  151 + model_config_cls, "nemo",
  152 + "Lcom/k2fsa/sherpa/onnx/OfflineNemoEncDecCtcModelConfig;");
  153 + jobject nemo_config = env->GetObjectField(model_config, fid);
  154 + jclass nemo_config_cls = env->GetObjectClass(nemo_config);
  155 +
  156 + fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;");
  157 +
  158 + s = (jstring)env->GetObjectField(nemo_config, fid);
  159 + p = env->GetStringUTFChars(s, nullptr);
  160 + ans.model_config.nemo_ctc.model = p;
  161 + env->ReleaseStringUTFChars(s, p);
  162 +
150 return ans; 163 return ans;
151 } 164 }
152 165
@@ -198,6 +198,22 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -198,6 +198,22 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
198 fid = env->GetFieldID(lm_model_config_cls, "scale", "F"); 198 fid = env->GetFieldID(lm_model_config_cls, "scale", "F");
199 ans.lm_config.scale = env->GetFloatField(lm_model_config, fid); 199 ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);
200 200
  201 + fid = env->GetFieldID(cls, "ctcFstDecoderConfig",
  202 + "Lcom/k2fsa/sherpa/onnx/OnlineCtcFstDecoderConfig;");
  203 +
  204 + jobject fst_decoder_config = env->GetObjectField(config, fid);
  205 + jclass fst_decoder_config_cls = env->GetObjectClass(fst_decoder_config);
  206 +
  207 + fid = env->GetFieldID(fst_decoder_config_cls, "graph", "Ljava/lang/String;");
  208 + s = (jstring)env->GetObjectField(fst_decoder_config, fid);
  209 + p = env->GetStringUTFChars(s, nullptr);
  210 + ans.ctc_fst_decoder_config.graph = p;
  211 + env->ReleaseStringUTFChars(s, p);
  212 +
  213 + fid = env->GetFieldID(fst_decoder_config_cls, "maxActive", "I");
  214 + ans.ctc_fst_decoder_config.max_active =
  215 + env->GetIntField(fst_decoder_config, fid);
  216 +
201 return ans; 217 return ans;
202 } 218 }
203 } // namespace sherpa_onnx 219 } // namespace sherpa_onnx
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 #include <fstream> 6 #include <fstream>
7 7
8 #include "sherpa-onnx/csrc/macros.h" 8 #include "sherpa-onnx/csrc/macros.h"
  9 +#include "sherpa-onnx/csrc/onnx-utils.h"
9 #include "sherpa-onnx/jni/common.h" 10 #include "sherpa-onnx/jni/common.h"
10 11
11 static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is, 12 static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is,
@@ -18,6 +18,10 @@ data class OfflineParaformerModelConfig( @@ -18,6 +18,10 @@ data class OfflineParaformerModelConfig(
18 var model: String = "", 18 var model: String = "",
19 ) 19 )
20 20
/**
 * Configuration for a non-streaming NeMo EncDecCTC model.
 *
 * @property model Path to the model file; empty by default.
 */
data class OfflineNemoEncDecCtcModelConfig(var model: String = "")
  24 +
21 data class OfflineWhisperModelConfig( 25 data class OfflineWhisperModelConfig(
22 var encoder: String = "", 26 var encoder: String = "",
23 var decoder: String = "", 27 var decoder: String = "",
@@ -30,6 +34,7 @@ data class OfflineModelConfig( @@ -30,6 +34,7 @@ data class OfflineModelConfig(
30 var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(), 34 var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
31 var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(), 35 var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
32 var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(), 36 var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
  37 + var nemo: OfflineNemoEncDecCtcModelConfig = OfflineNemoEncDecCtcModelConfig(),
33 var numThreads: Int = 1, 38 var numThreads: Int = 1,
34 var debug: Boolean = false, 39 var debug: Boolean = false,
35 var provider: String = "cpu", 40 var provider: String = "cpu",
@@ -216,6 +221,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? { @@ -216,6 +221,16 @@ fun getOfflineModelConfig(type: Int): OfflineModelConfig? {
216 ) 221 )
217 } 222 }
218 223
  224 + 6 -> {
  225 + val modelDir = "sherpa-onnx-nemo-ctc-en-citrinet-512"
  226 + return OfflineModelConfig(
  227 + nemo = OfflineNemoEncDecCtcModelConfig(
  228 + model = "$modelDir/model.int8.onnx",
  229 + ),
  230 + tokens = "$modelDir/tokens.txt",
  231 + )
  232 + }
  233 +
219 } 234 }
220 return null 235 return null
221 } 236 }
@@ -45,11 +45,17 @@ data class OnlineLMConfig( @@ -45,11 +45,17 @@ data class OnlineLMConfig(
45 var scale: Float = 0.5f, 45 var scale: Float = 0.5f,
46 ) 46 )
47 47
/**
 * Configuration for the online CTC FST decoder.
 *
 * @property graph Path to the decoding graph file (e.g. an HLG.fst); empty by default.
 * @property maxActive Defaults to 3000.
 */
data class OnlineCtcFstDecoderConfig(var graph: String = "", var maxActive: Int = 3000)
  52 +
48 53
49 data class OnlineRecognizerConfig( 54 data class OnlineRecognizerConfig(
50 var featConfig: FeatureConfig = FeatureConfig(), 55 var featConfig: FeatureConfig = FeatureConfig(),
51 var modelConfig: OnlineModelConfig, 56 var modelConfig: OnlineModelConfig,
52 var lmConfig: OnlineLMConfig = OnlineLMConfig(), 57 var lmConfig: OnlineLMConfig = OnlineLMConfig(),
  58 + var ctcFstDecoderConfig : OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(),
53 var endpointConfig: EndpointConfig = EndpointConfig(), 59 var endpointConfig: EndpointConfig = EndpointConfig(),
54 var enableEndpoint: Boolean = true, 60 var enableEndpoint: Boolean = true,
55 var decodingMethod: String = "greedy_search", 61 var decodingMethod: String = "greedy_search",