Fangjun Kuang
Committed by GitHub

Add non-streaming ASR (#92)

Showing 48 changed files with 1,526 additions and 150 deletions.
@@ -33,18 +33,20 @@ fun main() {
         config = config,
     )

-    var samples = WaveReader.readWave(
+    var objArray = WaveReader.readWave(
         assetManager = AssetManager(),
         filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
     )
+    var samples : FloatArray = objArray[0] as FloatArray
+    var sampleRate : Int = objArray[1] as Int

-    model.acceptWaveform(samples!!, sampleRate=16000)
+    model.acceptWaveform(samples, sampleRate=sampleRate)
     while (model.isReady()) {
         model.decode()
     }

-    var tail_paddings = FloatArray(8000) // 0.5 seconds
-    model.acceptWaveform(tail_paddings, sampleRate=16000)
+    var tail_paddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
+    model.acceptWaveform(tail_paddings, sampleRate=sampleRate)
     model.inputFinished()
     while (model.isReady()) {
         model.decode()
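
The Kotlin API change above is worth spelling out: WaveReader.readWave() no longer returns a bare FloatArray? but a two-element Array<Any>, with the samples (a FloatArray) at index 0 and the wave file's sample rate (an Int) at index 1 (see the WaveReader.kt hunk further down). That lets callers stop hard-coding 16000: the tail padding is now computed from the actual rate, and 0.5 s of padding is sampleRate * 0.5 samples, i.e. 16000 × 0.5 = 8000 for a 16 kHz file, exactly the value that was hard-coded before.
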
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+echo "EXE is $EXE"
+echo "PATH: $PATH"
+
+which $EXE
+
+log "------------------------------------------------------------"
+log "Run Conformer transducer (English)"
+log "------------------------------------------------------------"
+
+repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-conformer-en-2023-03-18
+log "Start testing ${repo_url}"
+repo=$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.onnx"
+cd test_wavs
+popd
+
+waves=(
+$repo/test_wavs/0.wav
+$repo/test_wavs/1.wav
+$repo/test_wavs/2.wav
+)
+
+for wave in ${waves[@]}; do
+  time $EXE \
+    $repo/tokens.txt \
+    $repo/encoder-epoch-99-avg-1.onnx \
+    $repo/decoder-epoch-99-avg-1.onnx \
+    $repo/joiner-epoch-99-avg-1.onnx \
+    $wave \
+    2
+done
+
+
+if command -v sox &> /dev/null; then
+  echo "test 8kHz"
+  sox $repo/test_wavs/0.wav -r 8000 8k.wav
+  time $EXE \
+    $repo/tokens.txt \
+    $repo/encoder-epoch-99-avg-1.onnx \
+    $repo/decoder-epoch-99-avg-1.onnx \
+    $repo/joiner-epoch-99-avg-1.onnx \
+    8k.wav \
+    2
+fi
+
+rm -rf $repo
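
A note on the new test script above (by all indications from the workflow changes below, it lives at .github/scripts/test-offline-transducer.sh): it expects the caller to export EXE and put the binary on PATH, which the CI jobs below do via export PATH=$PWD/build/bin:$PATH and export EXE=sherpa-onnx-offline. The trailing positional argument 2 passed to $EXE is presumably a thread count; the same commit lowers it from 4 to 2 throughout the online-transducer test script as well. The sox block resamples a 16 kHz test wave to 8 kHz to exercise the resampling path added in offline-stream.cc below.
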
@@ -40,7 +40,7 @@ for wave in ${waves[@]}; do
     $repo/decoder-epoch-99-avg-1.onnx \
     $repo/joiner-epoch-99-avg-1.onnx \
     $wave \
-    4
+    2
 done

 rm -rf $repo
@@ -72,7 +72,7 @@ for wave in ${waves[@]}; do
     $repo/decoder-epoch-11-avg-1.onnx \
     $repo/joiner-epoch-11-avg-1.onnx \
     $wave \
-    4
+    2
 done

 rm -rf $repo
@@ -104,7 +104,7 @@ for wave in ${waves[@]}; do
     $repo/decoder-epoch-99-avg-1.onnx \
     $repo/joiner-epoch-99-avg-1.onnx \
     $wave \
-    4
+    2
 done

 rm -rf $repo
@@ -138,7 +138,7 @@ for wave in ${waves[@]}; do
     $repo/decoder-epoch-99-avg-1.onnx \
     $repo/joiner-epoch-99-avg-1.onnx \
     $wave \
-    4
+    2
 done

 # Decode a URL
@@ -149,7 +149,7 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then
     $repo/decoder-epoch-99-avg-1.onnx \
     $repo/joiner-epoch-99-avg-1.onnx \
     https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \
-    4
+    2
 fi

 rm -rf $repo
@@ -7,11 +7,11 @@ on:
     paths:
       - '.github/workflows/linux.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-offline-transducer.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
       - 'sherpa-onnx/c-api/*'
-      - 'ffmpeg-examples/**'
       - 'c-api-examples/**'
   pull_request:
     branches:
@@ -19,11 +19,11 @@ on:
     paths:
       - '.github/workflows/linux.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-offline-transducer.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
       - 'sherpa-onnx/c-api/*'
-      - 'ffmpeg-examples/**'

 concurrency:
   group: linux-${{ github.ref }}
@@ -39,35 +39,26 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
+        build_type: [Release, Debug]

     steps:
       - uses: actions/checkout@v2
         with:
           fetch-depth: 0

-      - name: Install ffmpeg
+      - name: Install sox
         shell: bash
         run: |
-          sudo apt-get install -y software-properties-common
-          sudo add-apt-repository ppa:savoury1/ffmpeg4
-          sudo add-apt-repository ppa:savoury1/ffmpeg5
-
-          sudo apt-get install -y libavdevice-dev libavutil-dev ffmpeg
-          pkg-config --modversion libavutil
-          ffmpeg -version
-
-      - name: Show ffmpeg version
-        shell: bash
-        run: |
-          pkg-config --modversion libavutil
-          ffmpeg -version
+          sudo apt-get update
+          sudo apt-get install -y sox
+          sox -h

       - name: Configure CMake
         shell: bash
         run: |
           mkdir build
           cd build
-          cmake -D CMAKE_BUILD_TYPE=Release ..
+          cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} ..

       - name: Build sherpa-onnx for ubuntu
         shell: bash
@@ -78,21 +69,19 @@ jobs:
           ls -lh lib
           ls -lh bin

-          cd ../ffmpeg-examples
-          make
-
       - name: Display dependencies of sherpa-onnx for linux
         shell: bash
         run: |
           file build/bin/sherpa-onnx
           readelf -d build/bin/sherpa-onnx

-      - name: Test sherpa-onnx-ffmpeg
+      - name: Test offline transducer
+        shell: bash
         run: |
-          export PATH=$PWD/ffmpeg-examples:$PATH
-          export EXE=sherpa-onnx-ffmpeg
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline

-          .github/scripts/test-online-transducer.sh
+          .github/scripts/test-offline-transducer.sh

       - name: Test online transducer
         shell: bash
@@ -7,6 +7,7 @@ on:
     paths:
       - '.github/workflows/macos.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-offline-transducer.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -16,6 +17,7 @@ on:
     paths:
       - '.github/workflows/macos.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-offline-transducer.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -34,18 +36,25 @@ jobs:
       fail-fast: false
       matrix:
         os: [macos-latest]
+        build_type: [Release, Debug]

     steps:
       - uses: actions/checkout@v2
         with:
           fetch-depth: 0

+      - name: Install sox
+        shell: bash
+        run: |
+          brew install sox
+          sox -h
+
       - name: Configure CMake
         shell: bash
         run: |
           mkdir build
           cd build
-          cmake -D CMAKE_BUILD_TYPE=Release ..
+          cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} ..

       - name: Build sherpa-onnx for macos
         shell: bash
@@ -64,6 +73,14 @@ jobs:
           otool -L build/bin/sherpa-onnx
           otool -l build/bin/sherpa-onnx

+      - name: Test offline transducer
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx-offline
+
+          .github/scripts/test-offline-transducer.sh
+
       - name: Test online transducer
         shell: bash
         run: |
@@ -39,3 +39,5 @@ tags
 run-decode-file-python.sh
 android/SherpaOnnx/app/src/main/assets/
 *.ncnn.*
+run-sherpa-onnx-offline.sh
+sherpa-onnx-conformer-en-2023-03-18
@@ -13,7 +13,7 @@ endif()

 option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF)
 option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
-option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
+option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" OFF)
 option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
 option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
 option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
@@ -121,7 +121,7 @@ class MainActivity : AppCompatActivity() {
             val ret = audioRecord?.read(buffer, 0, buffer.size)
             if (ret != null && ret > 0) {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
-                model.acceptWaveform(samples, sampleRate=16000)
+                model.acceptWaveform(samples, sampleRate=sampleRateInHz)
                 while (model.isReady()) {
                     model.decode()
                 }
@@ -180,7 +180,7 @@ class MainActivity : AppCompatActivity() {
         val type = 0
         println("Select model type ${type}")
         val config = OnlineRecognizerConfig(
-            featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
+            featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
            modelConfig = getModelConfig(type = type)!!,
             endpointConfig = getEndpointConfig(),
             enableEndpoint = true,
@@ -8,7 +8,7 @@ class WaveReader {
         // No resampling is made.
         external fun readWave(
             assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f
-        ): FloatArray?
+        ): Array<Any>

         init {
             System.loadLibrary("sherpa-onnx-jni")
@@ -1,9 +1,9 @@
 function(download_kaldi_native_fbank)
   include(FetchContent)

-  set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.13.tar.gz")
-  set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.13.tar.gz")
-  set(kaldi_native_fbank_HASH "SHA256=1f4d228f9fe3e3e9f92a74a7eecd2489071a03982e4ba6d7c70fc5fa7444df57")
+  set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.14.tar.gz")
+  set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.14.tar.gz")
+  set(kaldi_native_fbank_HASH "SHA256=6a66638a111d3ce21fe6f29cbf9ab3dbcae2331c77391bf825927df5cbf2babe")

   set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
   set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
   # If you don't have access to the Internet,
   # please pre-download kaldi-native-fbank
   set(possible_file_locations
-    $ENV{HOME}/Downloads/kaldi-native-fbank-1.13.tar.gz
-    ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.13.tar.gz
-    ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.13.tar.gz
-    /tmp/kaldi-native-fbank-1.13.tar.gz
-    /star-fj/fangjun/download/github/kaldi-native-fbank-1.13.tar.gz
+    $ENV{HOME}/Downloads/kaldi-native-fbank-1.14.tar.gz
+    ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.14.tar.gz
+    ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.14.tar.gz
+    /tmp/kaldi-native-fbank-1.14.tar.gz
+    /star-fj/fangjun/download/github/kaldi-native-fbank-1.14.tar.gz
   )

   foreach(f IN LISTS possible_file_locations)
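
The bump of kaldi-native-fbank from v1.13 to v1.14 is load-bearing: the rewritten FeatureExtractor::Impl::GetFrames() below calls fbank_->Pop(), which (as far as I can tell) is what v1.14 adds, replacing the removed max_feature_vectors cap. For offline builds, pre-download the tarball to one of the listed locations, e.g. wget -O ~/Downloads/kaldi-native-fbank-1.14.tar.gz https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.14.tar.gz.
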
@@ -91,7 +91,6 @@ def create_recognizer():
         rule2_min_trailing_silence=1.2,
         rule3_min_utterance_length=300,  # it essentially disables this rule
         decoding_method=args.decoding_method,
-        max_feature_vectors=100,  # 1 second
     )
     return recognizer

@@ -86,7 +86,6 @@ def create_recognizer():
         sample_rate=16000,
         feature_dim=80,
         decoding_method=args.decoding_method,
-        max_feature_vectors=100,  # 1 second
     )
     return recognizer

@@ -6,6 +6,11 @@ set(sources
   features.cc
   file-utils.cc
   hypothesis.cc
+  offline-stream.cc
+  offline-transducer-greedy-search-decoder.cc
+  offline-transducer-model-config.cc
+  offline-transducer-model.cc
+  offline-recognizer.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-stream.cc
@@ -56,10 +61,13 @@ if(SHERPA_ONNX_ENABLE_CHECK)
 endif()

 add_executable(sherpa-onnx sherpa-onnx.cc)
+add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)

 target_link_libraries(sherpa-onnx sherpa-onnx-core)
+target_link_libraries(sherpa-onnx-offline sherpa-onnx-core)
 if(NOT WIN32)
   target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
+  target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
 endif()

 if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
@@ -68,7 +76,13 @@ else()
   install(TARGETS sherpa-onnx-core DESTINATION lib)
 endif()

-install(TARGETS sherpa-onnx DESTINATION bin)
+install(
+  TARGETS
+    sherpa-onnx
+    sherpa-onnx-offline
+  DESTINATION
+    bin
+)

 if(SHERPA_ONNX_HAS_ALSA)
   add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc)
@@ -19,7 +19,9 @@ namespace sherpa_onnx {
 void FeatureExtractorConfig::Register(ParseOptions *po) {
   po->Register("sample-rate", &sampling_rate,
                "Sampling rate of the input waveform. Must match the one "
-               "expected by the model.");
+               "expected by the model. Note: You can have a different "
+               "sample rate for the input waveform. We will do resampling "
+               "inside the feature extractor");

   po->Register("feat-dim", &feature_dim,
                "Feature dimension. Must match the one expected by the model.");
@@ -30,8 +32,7 @@ std::string FeatureExtractorConfig::ToString() const {

   os << "FeatureExtractorConfig(";
   os << "sampling_rate=" << sampling_rate << ", ";
-  os << "feature_dim=" << feature_dim << ", ";
-  os << "max_feature_vectors=" << max_feature_vectors << ")";
+  os << "feature_dim=" << feature_dim << ")";

   return os.str();
 }
@@ -43,8 +44,6 @@ class FeatureExtractor::Impl {
     opts_.frame_opts.snip_edges = false;
     opts_.frame_opts.samp_freq = config.sampling_rate;

-    opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
-
     opts_.mel_opts.num_bins = config.feature_dim;

     fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
@@ -95,7 +94,7 @@ class FeatureExtractor::Impl {
     fbank_->AcceptWaveform(sampling_rate, waveform, n);
   }

-  void InputFinished() {
+  void InputFinished() const {
     std::lock_guard<std::mutex> lock(mutex_);
     fbank_->InputFinished();
   }
@@ -110,12 +109,21 @@ class FeatureExtractor::Impl {
     return fbank_->IsLastFrame(frame);
   }

-  std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
-    if (frame_index + n > NumFramesReady()) {
-      fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
+  std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (frame_index + n > fbank_->NumFramesReady()) {
+      SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n,
+                       fbank_->NumFramesReady());
       exit(-1);
     }
-    std::lock_guard<std::mutex> lock(mutex_);
+
+    int32_t discard_num = frame_index - last_frame_index_;
+    if (discard_num < 0) {
+      SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
+                       last_frame_index_, frame_index);
+      exit(-1);
+    }
+    fbank_->Pop(discard_num);

     int32_t feature_dim = fbank_->Dim();
     std::vector<float> features(feature_dim * n);
@@ -128,12 +136,9 @@ class FeatureExtractor::Impl {
       p += feature_dim;
     }

-    return features;
-  }
+    last_frame_index_ = frame_index;

-  void Reset() {
-    std::lock_guard<std::mutex> lock(mutex_);
-    fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
+    return features;
   }

   int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
@@ -143,6 +148,7 @@ class FeatureExtractor::Impl {
   knf::FbankOptions opts_;
   mutable std::mutex mutex_;
   std::unique_ptr<LinearResample> resampler_;
+  int32_t last_frame_index_ = 0;
 };

 FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
@@ -151,11 +157,11 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
 FeatureExtractor::~FeatureExtractor() = default;

 void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
-                                      const float *waveform, int32_t n) {
+                                      const float *waveform, int32_t n) const {
   impl_->AcceptWaveform(sampling_rate, waveform, n);
 }

-void FeatureExtractor::InputFinished() { impl_->InputFinished(); }
+void FeatureExtractor::InputFinished() const { impl_->InputFinished(); }

 int32_t FeatureExtractor::NumFramesReady() const {
   return impl_->NumFramesReady();
@@ -170,8 +176,6 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
   return impl_->GetFrames(frame_index, n);
 }

-void FeatureExtractor::Reset() { impl_->Reset(); }
-
 int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); }

 }  // namespace sherpa_onnx
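
To make the new memory-management scheme concrete, here is a minimal sketch (not part of the commit) of the consume-and-pop pattern that the rewritten GetFrames() implements, written directly against knf::OnlineFbank. It assumes, as the code above does, that frames are requested in non-decreasing order of frame_index, that GetFrame() keeps taking absolute frame indexes after Pop(), and that Pop(n) discards the oldest n frames so the internal buffer stays bounded without a max_feature_vectors cap:

    // consume_frames.cc: standalone illustration, not from the commit.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    #include "kaldi-native-fbank/csrc/online-feature.h"

    // Copy frames [frame_index, frame_index + n) and drop everything that
    // precedes frame_index, mirroring GetFrames() in features.cc above.
    std::vector<float> ConsumeFrames(knf::OnlineFbank *fbank,
                                     int32_t *last_frame_index,
                                     int32_t frame_index, int32_t n) {
      // Frames before frame_index were handed out on an earlier call;
      // popping them keeps memory usage proportional to the lookahead.
      fbank->Pop(frame_index - *last_frame_index);

      int32_t dim = fbank->Dim();
      std::vector<float> features(static_cast<size_t>(n) * dim);
      float *p = features.data();
      for (int32_t i = frame_index; i != frame_index + n; ++i) {
        const float *f = fbank->GetFrame(i);  // absolute index, even after Pop()
        std::copy(f, f + dim, p);
        p += dim;
      }

      *last_frame_index = frame_index;
      return features;
    }

This is also why Reset() could be deleted: since consumed frames are popped eagerly, there is no unbounded buffer left to reset.
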
@@ -14,9 +14,12 @@
 namespace sherpa_onnx {

 struct FeatureExtractorConfig {
+  // Sampling rate used by the feature extractor. If it is different from
+  // the sampling rate of the input waveform, we will do resampling inside.
   int32_t sampling_rate = 16000;
+
+  // Feature dimension
   int32_t feature_dim = 80;
-  int32_t max_feature_vectors = -1;

   std::string ToString() const;
@@ -36,7 +39,8 @@ class FeatureExtractor {
                     the range [-1, 1].
      @param n Number of entries in waveform
    */
-  void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
+  void AcceptWaveform(int32_t sampling_rate, const float *waveform,
+                      int32_t n) const;

   /**
    * InputFinished() tells the class you won't be providing any
@@ -44,7 +48,7 @@ class FeatureExtractor {
    * of features, in the case where snip-edges == false; it also
    * affects the return value of IsLastFrame().
    */
-  void InputFinished();
+  void InputFinished() const;

   int32_t NumFramesReady() const;

@@ -62,8 +66,6 @@ class FeatureExtractor {
    */
   std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;

-  void Reset();
-
   /// Return feature dim of this extractor
   int32_t FeatureDim() const;

+// sherpa-onnx/csrc/offline-recognizer.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+
+#include <memory>
+#include <utility>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
+#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
+#include "sherpa-onnx/csrc/offline-transducer-model.h"
+#include "sherpa-onnx/csrc/pad-sequence.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+
+namespace sherpa_onnx {
+
+static OfflineRecognitionResult Convert(
+    const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table,
+    int32_t frame_shift_ms, int32_t subsampling_factor) {
+  OfflineRecognitionResult r;
+  r.tokens.reserve(src.tokens.size());
+  r.timestamps.reserve(src.timestamps.size());
+
+  std::string text;
+  for (auto i : src.tokens) {
+    auto sym = sym_table[i];
+    text.append(sym);
+
+    r.tokens.push_back(std::move(sym));
+  }
+  r.text = std::move(text);
+
+  float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
+  for (auto t : src.timestamps) {
+    float time = frame_shift_s * t;
+    r.timestamps.push_back(time);
+  }
+
+  return r;
+}
+
+void OfflineRecognizerConfig::Register(ParseOptions *po) {
+  feat_config.Register(po);
+  model_config.Register(po);
+
+  po->Register("decoding-method", &decoding_method,
+               "decoding method,"
+               "Valid values: greedy_search.");
+}
+
+bool OfflineRecognizerConfig::Validate() const {
+  return model_config.Validate();
+}
+
+std::string OfflineRecognizerConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineRecognizerConfig(";
+  os << "feat_config=" << feat_config.ToString() << ", ";
+  os << "model_config=" << model_config.ToString() << ", ";
+  os << "decoding_method=\"" << decoding_method << "\")";
+
+  return os.str();
+}
+
+class OfflineRecognizer::Impl {
+ public:
+  explicit Impl(const OfflineRecognizerConfig &config)
+      : config_(config),
+        symbol_table_(config_.model_config.tokens),
+        model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
+    if (config_.decoding_method == "greedy_search") {
+      decoder_ =
+          std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
+    } else if (config_.decoding_method == "modified_beam_search") {
+      SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
+      exit(-1);
+    } else {
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
+                       config_.decoding_method.c_str());
+      exit(-1);
+    }
+  }
+
+  std::unique_ptr<OfflineStream> CreateStream() const {
+    return std::make_unique<OfflineStream>(config_.feat_config);
+  }
+
+  void DecodeStreams(OfflineStream **ss, int32_t n) const {
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    int32_t feat_dim = ss[0]->FeatureDim();
+
+    std::vector<Ort::Value> features;
+
+    features.reserve(n);
+
+    std::vector<std::vector<float>> features_vec(n);
+    std::vector<int64_t> features_length_vec(n);
+    for (int32_t i = 0; i != n; ++i) {
+      auto f = ss[i]->GetFrames();
+      int32_t num_frames = f.size() / feat_dim;
+
+      features_length_vec[i] = num_frames;
+      features_vec[i] = std::move(f);
+
+      std::array<int64_t, 2> shape = {num_frames, feat_dim};
+
+      Ort::Value x = Ort::Value::CreateTensor(
+          memory_info, features_vec[i].data(), features_vec[i].size(),
+          shape.data(), shape.size());
+      features.push_back(std::move(x));
+    }
+
+    std::vector<const Ort::Value *> features_pointer(n);
+    for (int32_t i = 0; i != n; ++i) {
+      features_pointer[i] = &features[i];
+    }
+
+    std::array<int64_t, 1> features_length_shape = {n};
+    Ort::Value x_length = Ort::Value::CreateTensor(
+        memory_info, features_length_vec.data(), n,
+        features_length_shape.data(), features_length_shape.size());
+
+    Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
+                               -23.025850929940457f);
+
+    auto t = model_->RunEncoder(std::move(x), std::move(x_length));
+    auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
+
+    int32_t frame_shift_ms = 10;
+    for (int32_t i = 0; i != n; ++i) {
+      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
+                       model_->SubsamplingFactor());
+
+      ss[i]->SetResult(r);
+    }
+  }
+
+ private:
+  OfflineRecognizerConfig config_;
+  SymbolTable symbol_table_;
+  std::unique_ptr<OfflineTransducerModel> model_;
+  std::unique_ptr<OfflineTransducerDecoder> decoder_;
+};
+
+OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+OfflineRecognizer::~OfflineRecognizer() = default;
+
+std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
+  return impl_->CreateStream();
+}
+
+void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
+  impl_->DecodeStreams(ss, n);
+}
+
+}  // namespace sherpa_onnx
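
Two constants in offline-recognizer.cc deserve a gloss. First, the timestamp conversion in Convert(): with a 10 ms frame shift and a subsampling factor of 4, each encoder output frame spans 10/1000 × 4 = 0.04 s, so a token decoded at output frame t is stamped 0.04·t seconds (frame 25 → 1.0 s). Second, the padding value -23.025850929940457f passed to PadSequence() is ln(1e-10), i.e. the log-mel value of an essentially zero-energy frame, presumably chosen so that padded frames look like silence to the encoder.
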
+// sherpa-onnx/csrc/offline-recognizer.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-stream.h"
+#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineRecognitionResult {
+  // Recognition results.
+  // For English, it consists of space separated words.
+  // For Chinese, it consists of Chinese words without spaces.
+  std::string text;
+
+  // Decoded results at the token level.
+  // For instance, for BPE-based models it consists of a list of BPE tokens.
+  std::vector<std::string> tokens;
+
+  /// timestamps.size() == tokens.size()
+  /// timestamps[i] records the time in seconds when tokens[i] is decoded.
+  std::vector<float> timestamps;
+};
+
+struct OfflineRecognizerConfig {
+  OfflineFeatureExtractorConfig feat_config;
+  OfflineTransducerModelConfig model_config;
+
+  std::string decoding_method = "greedy_search";
+  // only greedy_search is implemented
+  // TODO(fangjun): Implement modified_beam_search
+
+  OfflineRecognizerConfig() = default;
+  OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
+                          const OfflineTransducerModelConfig &model_config,
+                          const std::string &decoding_method)
+      : feat_config(feat_config),
+        model_config(model_config),
+        decoding_method(decoding_method) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+class OfflineRecognizer {
+ public:
+  ~OfflineRecognizer();
+
+  explicit OfflineRecognizer(const OfflineRecognizerConfig &config);
+
+  /// Create a stream for decoding.
+  std::unique_ptr<OfflineStream> CreateStream() const;
+
+  /** Decode a single stream
+   *
+   * @param s The stream to decode.
+   */
+  void DecodeStream(OfflineStream *s) const {
+    OfflineStream *ss[1] = {s};
+    DecodeStreams(ss, 1);
+  }
+
+  /** Decode a list of streams.
+   *
+   * @param ss Pointer to an array of streams.
+   * @param n Size of the input array.
+   */
+  void DecodeStreams(OfflineStream **ss, int32_t n) const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_
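
Putting the pieces together, a hypothetical caller of this new API would look like the sketch below. It is written only against the headers added by this commit; the model paths are placeholders and the audio is a stand-in, not code from the commit:

    // decode_file.cc: illustrative only; paths and samples are made up.
    #include <cstdio>
    #include <vector>

    #include "sherpa-onnx/csrc/offline-recognizer.h"

    int main() {
      sherpa_onnx::OfflineRecognizerConfig config;
      config.model_config.tokens = "./tokens.txt";
      config.model_config.encoder_filename = "./encoder-epoch-99-avg-1.onnx";
      config.model_config.decoder_filename = "./decoder-epoch-99-avg-1.onnx";
      config.model_config.joiner_filename = "./joiner-epoch-99-avg-1.onnx";
      if (!config.Validate()) return -1;

      sherpa_onnx::OfflineRecognizer recognizer(config);
      auto s = recognizer.CreateStream();

      // Stand-in for real audio: 1 second of silence at 16 kHz, in [-1, 1].
      std::vector<float> samples(16000, 0.0f);
      // All samples must be supplied in one call (see offline-stream.h below).
      s->AcceptWaveform(16000, samples.data(),
                        static_cast<int32_t>(samples.size()));

      recognizer.DecodeStream(s.get());
      fprintf(stderr, "text: %s\n", s->GetResult().text.c_str());
      return 0;
    }

The one-shot AcceptWaveform() contract is what makes this "non-streaming": the stream computes all features up front and DecodeStreams() batch-pads them for a single encoder run.
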
+// sherpa-onnx/csrc/offline-stream.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-stream.h"
+
+#include <assert.h>
+
+#include <algorithm>
+
+#include "kaldi-native-fbank/csrc/online-feature.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/resample.h"
+
+namespace sherpa_onnx {
+
+void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
+  po->Register("sample-rate", &sampling_rate,
+               "Sampling rate of the input waveform. Must match the one "
+               "expected by the model. Note: You can have a different "
+               "sample rate for the input waveform. We will do resampling "
+               "inside the feature extractor");
+
+  po->Register("feat-dim", &feature_dim,
+               "Feature dimension. Must match the one expected by the model.");
+}
+
+std::string OfflineFeatureExtractorConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineFeatureExtractorConfig(";
+  os << "sampling_rate=" << sampling_rate << ", ";
+  os << "feature_dim=" << feature_dim << ")";
+
+  return os.str();
+}
+
+class OfflineStream::Impl {
+ public:
+  explicit Impl(const OfflineFeatureExtractorConfig &config) {
+    opts_.frame_opts.dither = 0;
+    opts_.frame_opts.snip_edges = false;
+    opts_.frame_opts.samp_freq = config.sampling_rate;
+    opts_.mel_opts.num_bins = config.feature_dim;
+
+    fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
+  }
+
+  void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
+    if (sampling_rate != opts_.frame_opts.samp_freq) {
+      SHERPA_ONNX_LOGE(
+          "Creating a resampler:\n"
+          "   in_sample_rate: %d\n"
+          "   output_sample_rate: %d\n",
+          sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
+
+      float min_freq =
+          std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
+      float lowpass_cutoff = 0.99 * 0.5 * min_freq;
+
+      int32_t lowpass_filter_width = 6;
+      auto resampler = std::make_unique<LinearResample>(
+          sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
+          lowpass_filter_width);
+      std::vector<float> samples;
+      resampler->Resample(waveform, n, true, &samples);
+      fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
+                             samples.size());
+      fbank_->InputFinished();
+      return;
+    }
+
+    fbank_->AcceptWaveform(sampling_rate, waveform, n);
+    fbank_->InputFinished();
+  }
+
+  int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
+
+  std::vector<float> GetFrames() const {
+    int32_t n = fbank_->NumFramesReady();
+    assert(n > 0 && "Please first call AcceptWaveform()");
+
+    int32_t feature_dim = FeatureDim();
+
+    std::vector<float> features(n * feature_dim);
+
+    float *p = features.data();
+
+    for (int32_t i = 0; i != n; ++i) {
+      const float *f = fbank_->GetFrame(i);
+      std::copy(f, f + feature_dim, p);
+      p += feature_dim;
+    }
+
+    return features;
+  }
+
+  void SetResult(const OfflineRecognitionResult &r) { r_ = r; }
+
+  const OfflineRecognitionResult &GetResult() const { return r_; }
+
+ private:
+  std::unique_ptr<knf::OnlineFbank> fbank_;
+  knf::FbankOptions opts_;
+  OfflineRecognitionResult r_;
+};
+
+OfflineStream::OfflineStream(
+    const OfflineFeatureExtractorConfig &config /*= {}*/)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+OfflineStream::~OfflineStream() = default;
+
+void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
+                                   int32_t n) const {
+  impl_->AcceptWaveform(sampling_rate, waveform, n);
+}
+
+int32_t OfflineStream::FeatureDim() const { return impl_->FeatureDim(); }
+
+std::vector<float> OfflineStream::GetFrames() const {
+  return impl_->GetFrames();
+}
+
+void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
+  impl_->SetResult(r);
+}
+
+const OfflineRecognitionResult &OfflineStream::GetResult() const {
+  return impl_->GetResult();
+}
+
+}  // namespace sherpa_onnx
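
The resampler parameters above follow the usual Kaldi choices: the cutoff sits just below the Nyquist frequency of the lower of the two rates. For the 8 kHz test wave the CI creates with sox, min_freq = 8000, so lowpass_cutoff = 0.99 × 0.5 × 8000 = 3960 Hz, and the signal is upsampled to the extractor's configured 16 kHz before fbank computation.
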
+// sherpa-onnx/csrc/offline-stream.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_
+#include <stdint.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+struct OfflineRecognitionResult;
+
+struct OfflineFeatureExtractorConfig {
+  // Sampling rate used by the feature extractor. If it is different from
+  // the sampling rate of the input waveform, we will do resampling inside.
+  int32_t sampling_rate = 16000;
+
+  // Feature dimension
+  int32_t feature_dim = 80;
+
+  std::string ToString() const;
+
+  void Register(ParseOptions *po);
+};
+
+class OfflineStream {
+ public:
+  explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {});
+  ~OfflineStream();
+
+  /**
+     @param sampling_rate The sampling_rate of the input waveform. If it does
+                          not equal to config.sampling_rate, we will do
+                          resampling inside.
+     @param waveform Pointer to a 1-D array of size n. It must be normalized to
+                     the range [-1, 1].
+     @param n Number of entries in waveform
+
+     Caution: You can only invoke this function once so you have to input
+              all the samples at once
+   */
+  void AcceptWaveform(int32_t sampling_rate, const float *waveform,
+                      int32_t n) const;
+
+  /// Return feature dim of this extractor
+  int32_t FeatureDim() const;
+
+  // Get all the feature frames of this stream in a 1-D array, which is
+  // flattened from a 2-D array of shape (num_frames, feat_dim).
+  std::vector<float> GetFrames() const;
+
+  /** Set the recognition result for this stream. */
+  void SetResult(const OfflineRecognitionResult &r);
+
+  /** Get the recognition result of this stream */
+  const OfflineRecognitionResult &GetResult() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_
+// sherpa-onnx/csrc/offline-transducer-decoder.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
+
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+
+namespace sherpa_onnx {
+
+struct OfflineTransducerDecoderResult {
+  /// The decoded token IDs
+  std::vector<int64_t> tokens;
+
+  /// timestamps[i] contains the output frame index where tokens[i] is decoded.
+  /// Note: The index is after subsampling
+  std::vector<int32_t> timestamps;
+};
+
+class OfflineTransducerDecoder {
+ public:
+  virtual ~OfflineTransducerDecoder() = default;
+
+  /** Run transducer beam search given the output from the encoder model.
+   *
+   * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
+   * @param encoder_out_length A 1-D tensor of shape (N,) containing number
+   *                           of valid frames in encoder_out before padding.
+   *
+   * @return Return a vector of size `N` containing the decoded results.
+   */
+  virtual std::vector<OfflineTransducerDecoderResult> Decode(
+      Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
+// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
+
+#include <algorithm>
+#include <iterator>
+#include <utility>
+
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/packed-sequence.h"
+#include "sherpa-onnx/csrc/slice.h"
+
+namespace sherpa_onnx {
+
+std::vector<OfflineTransducerDecoderResult>
+OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
+                                             Ort::Value encoder_out_length) {
+  PackedSequence packed_encoder_out = PackPaddedSequence(
+      model_->Allocator(), &encoder_out, &encoder_out_length);
+
+  int32_t batch_size =
+      static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
+
+  int32_t vocab_size = model_->VocabSize();
+  int32_t context_size = model_->ContextSize();
+
+  std::vector<OfflineTransducerDecoderResult> ans(batch_size);
+  for (auto &r : ans) {
+    // 0 is the ID of the blank token
+    r.tokens.resize(context_size, 0);
+  }
+
+  auto decoder_input = model_->BuildDecoderInput(ans, ans.size());
+  Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
+
+  int32_t start = 0;
+  int32_t t = 0;
+  for (auto n : packed_encoder_out.batch_sizes) {
+    Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
+    Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n);
+    start += n;
+    Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
+                                         std::move(cur_decoder_out));
+    const float *p_logit = logit.GetTensorData<float>();
+    bool emitted = false;
+    for (int32_t i = 0; i != n; ++i) {
+      auto y = static_cast<int32_t>(std::distance(
+          static_cast<const float *>(p_logit),
+          std::max_element(static_cast<const float *>(p_logit),
+                           static_cast<const float *>(p_logit) + vocab_size)));
+      p_logit += vocab_size;
+      if (y != 0) {
+        ans[i].tokens.push_back(y);
+        ans[i].timestamps.push_back(t);
+        emitted = true;
+      }
+    }
+    if (emitted) {
+      Ort::Value decoder_input = model_->BuildDecoderInput(ans, n);
+      decoder_out = model_->RunDecoder(std::move(decoder_input));
+    }
+    ++t;
+  }
+
+  for (auto &r : ans) {
+    r.tokens = {r.tokens.begin() + context_size, r.tokens.end()};
+  }
+
+  std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size);
+  for (int32_t i = 0; i != batch_size; ++i) {
+    unsorted_ans[packed_encoder_out.sorted_indexes[i]] = std::move(ans[i]);
+  }
+
+  return unsorted_ans;
+}
+
+}  // namespace sherpa_onnx
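
A worked example helps with the packed-sequence loop above, assuming PackPaddedSequence() follows the usual pack_padded_sequence convention of sorting streams by length, descending. With three streams of 5, 3, and 2 encoder frames, batch_sizes is {3, 3, 2, 1, 1} (10 entries in total, one per valid frame): at output frame t only the first batch_sizes[t] streams are still active, so each iteration runs the joiner on exactly the live rows [start, start + n) of the packed encoder output, and the decoder is re-run only when some stream emitted a non-blank token. sorted_indexes then undoes the sort so the results line up with the caller's original stream order.
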
+// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
+#include "sherpa-onnx/csrc/offline-transducer-model.h"
+
+namespace sherpa_onnx {
+
+class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
+ public:
+  explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model)
+      : model_(model) {}
+
+  std::vector<OfflineTransducerDecoderResult> Decode(
+      Ort::Value encoder_out, Ort::Value encoder_out_length) override;
+
+ private:
+  OfflineTransducerModel *model_;  // Not owned
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
+// sherpa-onnx/csrc/offline-transducer-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
+
+#include <sstream>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OfflineTransducerModelConfig::Register(ParseOptions *po) {
+  po->Register("encoder", &encoder_filename, "Path to encoder.onnx");
+  po->Register("decoder", &decoder_filename, "Path to decoder.onnx");
+  po->Register("joiner", &joiner_filename, "Path to joiner.onnx");
+  po->Register("tokens", &tokens, "Path to tokens.txt");
+  po->Register("num_threads", &num_threads,
+               "Number of threads to run the neural network");
+
+  po->Register("debug", &debug,
+               "true to print model information while loading it.");
+}
+
+bool OfflineTransducerModelConfig::Validate() const {
+  if (!FileExists(tokens)) {
+    SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
+    return false;
+  }
+
+  if (!FileExists(encoder_filename)) {
+    SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
+    return false;
+  }
+
+  if (!FileExists(decoder_filename)) {
+    SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
+    return false;
+  }
+
+  if (!FileExists(joiner_filename)) {
+    SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
+    return false;
+  }
+
+  if (num_threads < 1) {
+    SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflineTransducerModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineTransducerModelConfig(";
+  os << "encoder_filename=\"" << encoder_filename << "\", ";
+  os << "decoder_filename=\"" << decoder_filename << "\", ";
+  os << "joiner_filename=\"" << joiner_filename << "\", ";
+  os << "tokens=\"" << tokens << "\", ";
+  os << "num_threads=" << num_threads << ", ";
+  os << "debug=" << (debug ? "True" : "False") << ")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/offline-transducer-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineTransducerModelConfig {
+  std::string encoder_filename;
+  std::string decoder_filename;
+  std::string joiner_filename;
+  std::string tokens;
+  int32_t num_threads = 2;
+  bool debug = false;
+
+  OfflineTransducerModelConfig() = default;
+  OfflineTransducerModelConfig(const std::string &encoder_filename,
+                               const std::string &decoder_filename,
+                               const std::string &joiner_filename,
+                               const std::string &tokens, int32_t num_threads,
+                               bool debug)
+      : encoder_filename(encoder_filename),
+        decoder_filename(decoder_filename),
+        joiner_filename(joiner_filename),
+        tokens(tokens),
+        num_threads(num_threads),
+        debug(debug) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-transducer-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-transducer-model.h"
  6 +
  7 +#include <algorithm>
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
  13 +#include "sherpa-onnx/csrc/onnx-utils.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class OfflineTransducerModel::Impl {
  18 + public:
  19 + explicit Impl(const OfflineTransducerModelConfig &config)
  20 + : config_(config),
  21 + env_(ORT_LOGGING_LEVEL_WARNING),
  22 + sess_opts_{},
  23 + allocator_{} {
  24 + sess_opts_.SetIntraOpNumThreads(config.num_threads);
  25 + sess_opts_.SetInterOpNumThreads(config.num_threads);
  26 + {
  27 + auto buf = ReadFile(config.encoder_filename);
  28 + InitEncoder(buf.data(), buf.size());
  29 + }
  30 +
  31 + {
  32 + auto buf = ReadFile(config.decoder_filename);
  33 + InitDecoder(buf.data(), buf.size());
  34 + }
  35 +
  36 + {
  37 + auto buf = ReadFile(config.joiner_filename);
  38 + InitJoiner(buf.data(), buf.size());
  39 + }
  40 + }
  41 +
  42 + std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features,
  43 + Ort::Value features_length) {
  44 + std::array<Ort::Value, 2> encoder_inputs = {std::move(features),
  45 + std::move(features_length)};
  46 +
  47 + auto encoder_out = encoder_sess_->Run(
  48 + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
  49 + encoder_inputs.size(), encoder_output_names_ptr_.data(),
  50 + encoder_output_names_ptr_.size());
  51 +
  52 + return {std::move(encoder_out[0]), std::move(encoder_out[1])};
  53 + }
  54 +
  55 + Ort::Value RunDecoder(Ort::Value decoder_input) {
  56 + auto decoder_out = decoder_sess_->Run(
  57 + {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
  58 + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
  59 + return std::move(decoder_out[0]);
  60 + }
  61 +
  62 + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
  63 + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
  64 + std::move(decoder_out)};
  65 + auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(),
  66 + joiner_input.data(), joiner_input.size(),
  67 + joiner_output_names_ptr_.data(),
  68 + joiner_output_names_ptr_.size());
  69 +
  70 + return std::move(logit[0]);
  71 + }
  72 +
  73 + int32_t VocabSize() const { return vocab_size_; }
  74 + int32_t ContextSize() const { return context_size_; }
  75 + int32_t SubsamplingFactor() const { return 4; }
  76 + OrtAllocator *Allocator() const { return allocator_; }
  77 +
  78 + Ort::Value BuildDecoderInput(
  79 + const std::vector<OfflineTransducerDecoderResult> &results,
  80 + int32_t end_index) const {
  81 + assert(end_index <= results.size());
  82 +
  83 + int32_t batch_size = end_index;
  84 + int32_t context_size = ContextSize();
  85 + std::array<int64_t, 2> shape{batch_size, context_size};
  86 +
  87 + Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
  88 + Allocator(), shape.data(), shape.size());
  89 + int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
  90 +
  91 + for (int32_t i = 0; i != batch_size; ++i) {
  92 + const auto &r = results[i];
  93 + const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size;
  94 + const int64_t *end = r.tokens.data() + r.tokens.size();
  95 + std::copy(begin, end, p);
  96 + p += context_size;
  97 + }
  98 + return decoder_input;
  99 + }
  100 +
  101 + private:
  102 + void InitEncoder(void *model_data, size_t model_data_length) {
  103 + encoder_sess_ = std::make_unique<Ort::Session>(
  104 + env_, model_data, model_data_length, sess_opts_);
  105 +
  106 + GetInputNames(encoder_sess_.get(), &encoder_input_names_,
  107 + &encoder_input_names_ptr_);
  108 +
  109 + GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
  110 + &encoder_output_names_ptr_);
  111 +
  112 + // get meta data
  113 + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  114 + if (config_.debug) {
  115 + std::ostringstream os;
  116 + os << "---encoder---\n";
  117 + PrintModelMetadata(os, meta_data);
  118 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  119 + }
  120 + }
  121 +
  122 + void InitDecoder(void *model_data, size_t model_data_length) {
  123 + decoder_sess_ = std::make_unique<Ort::Session>(
  124 + env_, model_data, model_data_length, sess_opts_);
  125 +
  126 + GetInputNames(decoder_sess_.get(), &decoder_input_names_,
  127 + &decoder_input_names_ptr_);
  128 +
  129 + GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
  130 + &decoder_output_names_ptr_);
  131 +
  132 + // get meta data
  133 + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
  134 + if (config_.debug) {
  135 + std::ostringstream os;
  136 + os << "---decoder---\n";
  137 + PrintModelMetadata(os, meta_data);
  138 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  139 + }
  140 +
  141 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  142 + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  143 + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
  144 + }
  145 +
  146 + void InitJoiner(void *model_data, size_t model_data_length) {
  147 + joiner_sess_ = std::make_unique<Ort::Session>(
  148 + env_, model_data, model_data_length, sess_opts_);
  149 +
  150 + GetInputNames(joiner_sess_.get(), &joiner_input_names_,
  151 + &joiner_input_names_ptr_);
  152 +
  153 + GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
  154 + &joiner_output_names_ptr_);
  155 +
  156 + // get meta data
  157 + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
  158 + if (config_.debug) {
  159 + std::ostringstream os;
  160 + os << "---joiner---\n";
  161 + PrintModelMetadata(os, meta_data);
  162 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  163 + }
  164 + }
  165 +
  166 + private:
  167 + OfflineTransducerModelConfig config_;
  168 + Ort::Env env_;
  169 + Ort::SessionOptions sess_opts_;
  170 + Ort::AllocatorWithDefaultOptions allocator_;
  171 +
  172 + std::unique_ptr<Ort::Session> encoder_sess_;
  173 + std::unique_ptr<Ort::Session> decoder_sess_;
  174 + std::unique_ptr<Ort::Session> joiner_sess_;
  175 +
  176 + std::vector<std::string> encoder_input_names_;
  177 + std::vector<const char *> encoder_input_names_ptr_;
  178 +
  179 + std::vector<std::string> encoder_output_names_;
  180 + std::vector<const char *> encoder_output_names_ptr_;
  181 +
  182 + std::vector<std::string> decoder_input_names_;
  183 + std::vector<const char *> decoder_input_names_ptr_;
  184 +
  185 + std::vector<std::string> decoder_output_names_;
  186 + std::vector<const char *> decoder_output_names_ptr_;
  187 +
  188 + std::vector<std::string> joiner_input_names_;
  189 + std::vector<const char *> joiner_input_names_ptr_;
  190 +
  191 + std::vector<std::string> joiner_output_names_;
  192 + std::vector<const char *> joiner_output_names_ptr_;
  193 +
  194 + int32_t vocab_size_ = 0; // initialized in InitDecoder
  195 + int32_t context_size_ = 0; // initialized in InitDecoder
  196 +};
  197 +
  198 +OfflineTransducerModel::OfflineTransducerModel(
  199 + const OfflineTransducerModelConfig &config)
  200 + : impl_(std::make_unique<Impl>(config)) {}
  201 +
  202 +OfflineTransducerModel::~OfflineTransducerModel() = default;
  203 +
  204 +std::pair<Ort::Value, Ort::Value> OfflineTransducerModel::RunEncoder(
  205 + Ort::Value features, Ort::Value features_length) {
  206 + return impl_->RunEncoder(std::move(features), std::move(features_length));
  207 +}
  208 +
  209 +Ort::Value OfflineTransducerModel::RunDecoder(Ort::Value decoder_input) {
  210 + return impl_->RunDecoder(std::move(decoder_input));
  211 +}
  212 +
  213 +Ort::Value OfflineTransducerModel::RunJoiner(Ort::Value encoder_out,
  214 + Ort::Value decoder_out) {
  215 + return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
  216 +}
  217 +
  218 +int32_t OfflineTransducerModel::VocabSize() const { return impl_->VocabSize(); }
  219 +
  220 +int32_t OfflineTransducerModel::ContextSize() const {
  221 + return impl_->ContextSize();
  222 +}
  223 +
  224 +int32_t OfflineTransducerModel::SubsamplingFactor() const {
  225 + return impl_->SubsamplingFactor();
  226 +}
  227 +
  228 +OrtAllocator *OfflineTransducerModel::Allocator() const {
  229 + return impl_->Allocator();
  230 +}
  231 +
  232 +Ort::Value OfflineTransducerModel::BuildDecoderInput(
  233 + const std::vector<OfflineTransducerDecoderResult> &results,
  234 + int32_t end_index) const {
  235 + return impl_->BuildDecoderInput(results, end_index);
  236 +}
  237 +
  238 +} // namespace sherpa_onnx
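A note on BuildDecoderInput() above: only the most recent ContextSize() tokens of each hypothesis are copied into the returned (end_index, context_size) tensor. As a worked example with invented values, if context_size is 2 and a result holds tokens = {0, 0, 25, 103}, its decoder_input row is {25, 103}.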
  1 +// sherpa-onnx/csrc/offline-transducer-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "onnxruntime_cxx_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +struct OfflineTransducerDecoderResult;
  17 +
  18 +class OfflineTransducerModel {
  19 + public:
  20 + explicit OfflineTransducerModel(const OfflineTransducerModelConfig &config);
  21 + ~OfflineTransducerModel();
  22 +
  23 + /** Run the encoder.
  24 + *
  25 + * @param features A tensor of shape (N, T, C). It is changed in-place.
  26 + * @param features_length A 1-D tensor of shape (N,) containing number of
  27 + * valid frames in `features` before padding.
  28 + *
  29 + * @return Return a pair containing:
  30 + * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
  31 + * - encoder_out_length: A 1-D tensor of shape (N,) containing number
  32 + * of frames in `encoder_out` before padding.
  33 + */
  34 + std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features,
  35 + Ort::Value features_length);
  36 +
  37 + /** Run the decoder network.
  38 + *
  39 + * Caution: We assume there are no recurrent connections in the decoder and
  40 + * the decoder is stateless. See
  41 + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
  42 + * for an example.
  43 + *
  44 + * @param decoder_input It is usually of shape (N, context_size)
  45 + * @return Return a tensor of shape (N, decoder_dim).
  46 + */
  47 + Ort::Value RunDecoder(Ort::Value decoder_input);
  48 +
  49 + /** Run the joint network.
  50 + *
  51 + * @param encoder_out Output of the encoder network. A tensor of shape
  52 + * (N, joiner_dim).
  53 + * @param decoder_out Output of the decoder network. A tensor of shape
  54 + * (N, joiner_dim).
  55 + * @return Return a tensor of shape (N, vocab_size). In icefall, the
  56 + * last layer of the joint network is `nn.Linear`,
  57 + * not `nn.LogSoftmax`.
  58 + */
  59 + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out);
  60 +
  61 + /** Return the vocabulary size of the model.
  62 + */
  63 + int32_t VocabSize() const;
  64 +
  65 + /** Return the context_size of the decoder model.
  66 + */
  67 + int32_t ContextSize() const;
  68 +
  69 + /** Return the subsampling factor of the model.
  70 + */
  71 + int32_t SubsamplingFactor() const;
  72 +
  73 + /** Return an allocator for allocating memory.
  74 + */
  75 + OrtAllocator *Allocator() const;
  76 +
  77 + /** Build decoder_input from the current results.
  78 + *
  79 + * @param results Current decoded results.
  80 + * @param end_index We only use results[0:end_index] to build
  81 + * the decoder_input.
  82 + * @return Return a tensor of shape (end_index, ContextSize())
  83 + */
  84 + Ort::Value BuildDecoderInput(
  85 + const std::vector<OfflineTransducerDecoderResult> &results,
  86 + int32_t end_index) const;
  87 +
  88 + private:
  89 + class Impl;
  90 + std::unique_ptr<Impl> impl_;
  91 +};
  92 +
  93 +} // namespace sherpa_onnx
  94 +
  95 +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_
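To see how the API documented above fits together, here is a minimal batch-size-1 greedy-search sketch. It is not the decoder shipped in this PR (that lives in offline-transducer-greedy-search-decoder.cc); it assumes blank id 0 (as in icefall), assumes each per-frame slice of the encoder output can be fed to RunJoiner() directly, and re-runs the stateless decoder on every frame for brevity:

// Sketch under the stated assumptions; the real implementation added by
// this PR is sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc.
#include <algorithm>
#include <array>
#include <cstdint>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-model.h"

namespace sherpa_onnx {

OfflineTransducerDecoderResult GreedySearchSketch(
    OfflineTransducerModel *model, Ort::Value features,  // (1, T, C)
    Ort::Value features_length) {                        // (1,)
  auto t = model->RunEncoder(std::move(features), std::move(features_length));
  Ort::Value &encoder_out = t.first;  // (1, T', joiner_dim) assumed

  auto shape = encoder_out.GetTensorTypeAndShapeInfo().GetShape();
  int32_t num_frames = static_cast<int32_t>(shape[1]);  // N == 1, no padding
  int32_t dim = static_cast<int32_t>(shape[2]);

  // Prime the token history with ContextSize() blanks (blank id 0 assumed).
  std::vector<OfflineTransducerDecoderResult> results(1);
  results[0].tokens.resize(model->ContextSize(), 0);

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  float *p = encoder_out.GetTensorMutableData<float>();

  for (int32_t i = 0; i != num_frames; ++i) {
    // Shallow (1, dim) view of encoder_out[:, i, :]
    std::array<int64_t, 2> view_shape{1, dim};
    Ort::Value cur = Ort::Value::CreateTensor(
        memory_info, p + i * dim, dim, view_shape.data(), view_shape.size());

    // The decoder is stateless: its output depends only on the last
    // ContextSize() tokens, which BuildDecoderInput() copies. Re-running it
    // on every frame is wasteful but keeps the sketch short.
    Ort::Value decoder_out =
        model->RunDecoder(model->BuildDecoderInput(results, 1));
    Ort::Value logit = model->RunJoiner(std::move(cur), std::move(decoder_out));

    const float *q = logit.GetTensorData<float>();
    auto best = std::max_element(q, q + model->VocabSize()) - q;
    if (best != 0) {  // skip blanks
      results[0].tokens.push_back(best);
    }
  }

  // Note: the first ContextSize() entries are the priming blanks.
  return results[0];
}

}  // namespace sherpa_onnx

The per-frame view uses the same CreateTensor-over-existing-memory trick as PackedSequence::Get() further down in this diff.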
@@ -95,7 +95,7 @@ void OnlineLstmTransducerModel::InitEncoder(void *model_data, @@ -95,7 +95,7 @@ void OnlineLstmTransducerModel::InitEncoder(void *model_data,
95 std::ostringstream os; 95 std::ostringstream os;
96 os << "---encoder---\n"; 96 os << "---encoder---\n";
97 PrintModelMetadata(os, meta_data); 97 PrintModelMetadata(os, meta_data);
98 - fprintf(stderr, "%s\n", os.str().c_str()); 98 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
99 } 99 }
100 100
101 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below 101 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
@@ -123,7 +123,7 @@ void OnlineLstmTransducerModel::InitDecoder(void *model_data, @@ -123,7 +123,7 @@ void OnlineLstmTransducerModel::InitDecoder(void *model_data,
123 std::ostringstream os; 123 std::ostringstream os;
124 os << "---decoder---\n"; 124 os << "---decoder---\n";
125 PrintModelMetadata(os, meta_data); 125 PrintModelMetadata(os, meta_data);
126 - fprintf(stderr, "%s\n", os.str().c_str()); 126 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
127 } 127 }
128 128
129 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below 129 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
@@ -148,7 +148,7 @@ void OnlineLstmTransducerModel::InitJoiner(void *model_data, @@ -148,7 +148,7 @@ void OnlineLstmTransducerModel::InitJoiner(void *model_data,
148 std::ostringstream os; 148 std::ostringstream os;
149 os << "---joiner---\n"; 149 os << "---joiner---\n";
150 PrintModelMetadata(os, meta_data); 150 PrintModelMetadata(os, meta_data);
151 - fprintf(stderr, "%s\n", os.str().c_str()); 151 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
152 } 152 }
153 } 153 }
154 154
@@ -228,9 +228,6 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() { @@ -228,9 +228,6 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
228 std::pair<Ort::Value, std::vector<Ort::Value>> 228 std::pair<Ort::Value, std::vector<Ort::Value>>
229 OnlineLstmTransducerModel::RunEncoder(Ort::Value features, 229 OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
230 std::vector<Ort::Value> states) { 230 std::vector<Ort::Value> states) {
231 - auto memory_info =  
232 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
233 -  
234 std::array<Ort::Value, 3> encoder_inputs = { 231 std::array<Ort::Value, 3> encoder_inputs = {
235 std::move(features), std::move(states[0]), std::move(states[1])}; 232 std::move(features), std::move(states[0]), std::move(states[1])};
236 233
@@ -20,7 +20,7 @@ class OnlineStream::Impl { @@ -20,7 +20,7 @@ class OnlineStream::Impl {
20 feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); 20 feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
21 } 21 }
22 22
23 - void InputFinished() { feat_extractor_.InputFinished(); } 23 + void InputFinished() const { feat_extractor_.InputFinished(); }
24 24
25 int32_t NumFramesReady() const { 25 int32_t NumFramesReady() const {
26 return feat_extractor_.NumFramesReady() - start_frame_index_; 26 return feat_extractor_.NumFramesReady() - start_frame_index_;
@@ -68,11 +68,11 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) @@ -68,11 +68,11 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
68 OnlineStream::~OnlineStream() = default; 68 OnlineStream::~OnlineStream() = default;
69 69
70 void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, 70 void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
71 - int32_t n) { 71 + int32_t n) const {
72 impl_->AcceptWaveform(sampling_rate, waveform, n); 72 impl_->AcceptWaveform(sampling_rate, waveform, n);
73 } 73 }
74 74
75 -void OnlineStream::InputFinished() { impl_->InputFinished(); } 75 +void OnlineStream::InputFinished() const { impl_->InputFinished(); }
76 76
77 int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); } 77 int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); }
78 78
@@ -27,7 +27,8 @@ class OnlineStream { @@ -27,7 +27,8 @@ class OnlineStream {
27 the range [-1, 1]. 27 the range [-1, 1].
28 @param n Number of entries in waveform 28 @param n Number of entries in waveform
29 */ 29 */
30 - void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); 30 + void AcceptWaveform(int32_t sampling_rate, const float *waveform,
  31 + int32_t n) const;
31 32
32 /** 33 /**
33 * InputFinished() tells the class you won't be providing any 34 * InputFinished() tells the class you won't be providing any
@@ -35,7 +36,7 @@ class OnlineStream { @@ -35,7 +36,7 @@ class OnlineStream {
35 * of features, in the case where snip-edges == false; it also 36 * of features, in the case where snip-edges == false; it also
36 * affects the return value of IsLastFrame(). 37 * affects the return value of IsLastFrame().
37 */ 38 */
38 - void InputFinished(); 39 + void InputFinished() const;
39 40
40 int32_t NumFramesReady() const; 41 int32_t NumFramesReady() const;
41 42
@@ -248,14 +248,21 @@ int32_t main(int32_t argc, char *argv[]) { @@ -248,14 +248,21 @@ int32_t main(int32_t argc, char *argv[]) {
248 std::string wave_filename = po.GetArg(1); 248 std::string wave_filename = po.GetArg(1);
249 249
250 bool is_ok = false; 250 bool is_ok = false;
  251 + int32_t actual_sample_rate = -1;
251 std::vector<float> samples = 252 std::vector<float> samples =
252 - sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok); 253 + sherpa_onnx::ReadWave(wave_filename, &actual_sample_rate, &is_ok);
253 254
254 if (!is_ok) { 255 if (!is_ok) {
255 SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str()); 256 SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str());
256 return -1; 257 return -1;
257 } 258 }
258 259
  260 + if (actual_sample_rate != sample_rate) {
  261 + SHERPA_ONNX_LOGE("Expected sample rate: %d, given %d", sample_rate,
  262 + actual_sample_rate);
  263 + return -1;
  264 + }
  265 +
259 asio::io_context io_conn; // for network connections 266 asio::io_context io_conn; // for network connections
260 Client c(io_conn, server_ip, server_port, samples, samples_per_message, 267 Client c(io_conn, server_ip, server_port, samples, samples_per_message,
261 seconds_per_message); 268 seconds_per_message);
@@ -97,7 +97,7 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, @@ -97,7 +97,7 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
97 std::ostringstream os; 97 std::ostringstream os;
98 os << "---encoder---\n"; 98 os << "---encoder---\n";
99 PrintModelMetadata(os, meta_data); 99 PrintModelMetadata(os, meta_data);
100 - fprintf(stderr, "%s\n", os.str().c_str()); 100 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
101 } 101 }
102 102
103 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below 103 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
@@ -123,8 +123,8 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, @@ -123,8 +123,8 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
123 print(num_encoder_layers_, "num_encoder_layers"); 123 print(num_encoder_layers_, "num_encoder_layers");
124 print(cnn_module_kernels_, "cnn_module_kernels"); 124 print(cnn_module_kernels_, "cnn_module_kernels");
125 print(left_context_len_, "left_context_len"); 125 print(left_context_len_, "left_context_len");
126 - fprintf(stderr, "T: %d\n", T_);  
127 - fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_); 126 + SHERPA_ONNX_LOGE("T: %d", T_);
  127 + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
128 } 128 }
129 } 129 }
130 130
@@ -145,7 +145,7 @@ void OnlineZipformerTransducerModel::InitDecoder(void *model_data, @@ -145,7 +145,7 @@ void OnlineZipformerTransducerModel::InitDecoder(void *model_data,
145 std::ostringstream os; 145 std::ostringstream os;
146 os << "---decoder---\n"; 146 os << "---decoder---\n";
147 PrintModelMetadata(os, meta_data); 147 PrintModelMetadata(os, meta_data);
148 - fprintf(stderr, "%s\n", os.str().c_str()); 148 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
149 } 149 }
150 150
151 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below 151 Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
@@ -170,7 +170,7 @@ void OnlineZipformerTransducerModel::InitJoiner(void *model_data, @@ -170,7 +170,7 @@ void OnlineZipformerTransducerModel::InitJoiner(void *model_data,
170 std::ostringstream os; 170 std::ostringstream os;
171 os << "---joiner---\n"; 171 os << "---joiner---\n";
172 PrintModelMetadata(os, meta_data); 172 PrintModelMetadata(os, meta_data);
173 - fprintf(stderr, "%s\n", os.str().c_str()); 173 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
174 } 174 }
175 } 175 }
176 176
@@ -435,9 +435,6 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() { @@ -435,9 +435,6 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
435 std::pair<Ort::Value, std::vector<Ort::Value>> 435 std::pair<Ort::Value, std::vector<Ort::Value>>
436 OnlineZipformerTransducerModel::RunEncoder(Ort::Value features, 436 OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
437 std::vector<Ort::Value> states) { 437 std::vector<Ort::Value> states) {
438 - auto memory_info =  
439 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
440 -  
441 std::vector<Ort::Value> encoder_inputs; 438 std::vector<Ort::Value> encoder_inputs;
442 encoder_inputs.reserve(1 + states.size()); 439 encoder_inputs.reserve(1 + states.size());
443 440
@@ -41,7 +41,7 @@ PackedSequence PackPaddedSequence(OrtAllocator *allocator, @@ -41,7 +41,7 @@ PackedSequence PackPaddedSequence(OrtAllocator *allocator,
41 std::vector<int64_t> l_shape = length->GetTensorTypeAndShapeInfo().GetShape(); 41 std::vector<int64_t> l_shape = length->GetTensorTypeAndShapeInfo().GetShape();
42 42
43 assert(v_shape.size() == 3); 43 assert(v_shape.size() == 3);
44 - assert(l_shape.size() == 3); 44 + assert(l_shape.size() == 1);
45 assert(v_shape[0] == l_shape[0]); 45 assert(v_shape[0] == l_shape[0]);
46 46
47 std::vector<int32_t> indexes(v_shape[0]); 47 std::vector<int32_t> indexes(v_shape[0]);
@@ -13,7 +13,26 @@ namespace sherpa_onnx { @@ -13,7 +13,26 @@ namespace sherpa_onnx {
13 struct PackedSequence { 13 struct PackedSequence {
14 std::vector<int32_t> sorted_indexes; 14 std::vector<int32_t> sorted_indexes;
15 std::vector<int32_t> batch_sizes; 15 std::vector<int32_t> batch_sizes;
  16 +
  17 + // data is a 2-D tensor of shape (sum(batch_sizes), channels)
16 Ort::Value data{nullptr}; 18 Ort::Value data{nullptr};
  19 +
  20 + // Return a shallow copy of data[start:start+size, :]
  21 + Ort::Value Get(int32_t start, int32_t size) {
  22 + auto shape = data.GetTensorTypeAndShapeInfo().GetShape();
  23 +
  24 + std::array<int64_t, 2> ans_shape{size, shape[1]};
  25 +
  26 + float *p = data.GetTensorMutableData<float>();
  27 +
  28 + auto memory_info =
  29 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  30 +
  31 + // a shallow copy
  32 + return Ort::Value::CreateTensor(memory_info, p + start * shape[1],
  33 + size * shape[1], ans_shape.data(),
  34 + ans_shape.size());
  35 + }
17 }; 36 };
18 37
19 /** Similar to torch.nn.utils.rnn.pad_sequence but it supports only 38 /** Similar to torch.nn.utils.rnn.pad_sequence but it supports only
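Get() is shallow in the sense that the returned Ort::Value borrows the buffer owned by data rather than allocating its own, so it must not outlive the PackedSequence. A hypothetical use, where `padded` is the padded (B, T, C) tensor and `lengths` the (B,) length tensor handed to PackPaddedSequence() (names invented for the sketch):

// `packed.data` owns the packed buffer; Get() only borrows from it.
PackedSequence packed = PackPaddedSequence(allocator, &padded, &lengths);

// The first batch_sizes[0] rows of data belong to time step 0.
Ort::Value step0 = packed.Get(0, packed.batch_sizes[0]);
// step0 must not outlive `packed`.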
@@ -46,7 +46,7 @@ I Gcd(I m, I n) { @@ -46,7 +46,7 @@ I Gcd(I m, I n) {
46 // this function is copied from kaldi/src/base/kaldi-math.h 46 // this function is copied from kaldi/src/base/kaldi-math.h
47 if (m == 0 || n == 0) { 47 if (m == 0 || n == 0) {
48 if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. 48 if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
49 - fprintf(stderr, "Undefined GCD since m = 0, n = 0."); 49 + fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n");
50 exit(-1); 50 exit(-1);
51 } 51 }
52 return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m)); 52 return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m));
@@ -95,6 +95,10 @@ as the device_name. @@ -95,6 +95,10 @@ as the device_name.
95 95
96 fprintf(stderr, "%s\n", config.ToString().c_str()); 96 fprintf(stderr, "%s\n", config.ToString().c_str());
97 97
  98 + if (!config.Validate()) {
  99 + fprintf(stderr, "Errors in config!\n");
  100 + return -1;
  101 + }
98 sherpa_onnx::OnlineRecognizer recognizer(config); 102 sherpa_onnx::OnlineRecognizer recognizer(config);
99 103
100 int32_t expected_sample_rate = config.feat_config.sampling_rate; 104 int32_t expected_sample_rate = config.feat_config.sampling_rate;
@@ -86,6 +86,11 @@ for a list of pre-trained models to download. @@ -86,6 +86,11 @@ for a list of pre-trained models to download.
86 86
87 fprintf(stderr, "%s\n", config.ToString().c_str()); 87 fprintf(stderr, "%s\n", config.ToString().c_str());
88 88
  89 + if (!config.Validate()) {
  90 + fprintf(stderr, "Errors in config!\n");
  91 + return -1;
  92 + }
  93 +
89 sherpa_onnx::OnlineRecognizer recognizer(config); 94 sherpa_onnx::OnlineRecognizer recognizer(config);
90 auto s = recognizer.CreateStream(); 95 auto s = recognizer.CreateStream();
91 96
  1 +// sherpa-onnx/csrc/sherpa-onnx-offline.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include <stdio.h>
  6 +
  7 +#include <chrono> // NOLINT
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  12 +#include "sherpa-onnx/csrc/offline-stream.h"
  13 +#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
  14 +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
  15 +#include "sherpa-onnx/csrc/offline-transducer-model.h"
  16 +#include "sherpa-onnx/csrc/pad-sequence.h"
  17 +#include "sherpa-onnx/csrc/symbol-table.h"
  18 +#include "sherpa-onnx/csrc/wave-reader.h"
  19 +
  20 +int main(int32_t argc, char *argv[]) {
  21 + if (argc < 6 || argc > 8) {
  22 + const char *usage = R"usage(
  23 +Usage:
  24 + ./bin/sherpa-onnx-offline \
  25 + /path/to/tokens.txt \
  26 + /path/to/encoder.onnx \
  27 + /path/to/decoder.onnx \
  28 + /path/to/joiner.onnx \
  29 + /path/to/foo.wav [num_threads [decoding_method]]
  30 +
  31 +Default value for num_threads is 2.
  32 +Valid values for decoding_method: greedy_search.
  33 +foo.wav should be a single-channel, 16-bit PCM encoded wave file; its
  34 +sampling rate can be arbitrary and does not need to be 16kHz.
  35 +
  36 +Please refer to
  37 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  38 +for a list of pre-trained models to download.
  39 +)usage";
  40 + fprintf(stderr, "%s\n", usage);
  41 +
  42 + return 0;
  43 + }
  44 +
  45 + sherpa_onnx::OfflineRecognizerConfig config;
  46 +
  47 + config.model_config.tokens = argv[1];
  48 +
  49 + config.model_config.debug = false;
  50 + config.model_config.encoder_filename = argv[2];
  51 + config.model_config.decoder_filename = argv[3];
  52 + config.model_config.joiner_filename = argv[4];
  53 +
  54 + std::string wav_filename = argv[5];
  55 +
  56 + config.model_config.num_threads = 2;
  57 + if (argc >= 7 && atoi(argv[6]) > 0) {
  58 + config.model_config.num_threads = atoi(argv[6]);
  59 + }
  60 +
  61 + if (argc == 8) {
  62 + config.decoding_method = argv[7];
  63 + }
  64 +
  65 + fprintf(stderr, "%s\n", config.ToString().c_str());
  66 +
  67 + if (!config.Validate()) {
  68 + fprintf(stderr, "Errors in config!\n");
  69 + return -1;
  70 + }
  71 +
  72 + int32_t sampling_rate = -1;
  73 +
  74 + bool is_ok = false;
  75 + std::vector<float> samples =
  76 + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  77 + if (!is_ok) {
  78 + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
  79 + return -1;
  80 + }
  81 + fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
  82 +
  83 + float duration = samples.size() / static_cast<float>(sampling_rate);
  84 +
  85 + sherpa_onnx::OfflineRecognizer recognizer(config);
  86 + auto s = recognizer.CreateStream();
  87 +
  88 + auto begin = std::chrono::steady_clock::now();
  89 + fprintf(stderr, "Started\n");
  90 +
  91 + s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  92 +
  93 + recognizer.DecodeStream(s.get());
  94 +
  95 + fprintf(stderr, "Done!\n");
  96 +
  97 + fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
  98 + s->GetResult().text.c_str());
  99 +
  100 + auto end = std::chrono::steady_clock::now();
  101 + float elapsed_seconds =
  102 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  103 + .count() /
  104 + 1000.;
  105 +
  106 + fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
  107 + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
  108 +
  109 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  110 + float rtf = elapsed_seconds / duration;
  111 + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
  112 + elapsed_seconds, duration, rtf);
  113 +
  114 + return 0;
  115 +}
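A note on the RTF printed above: it is decoding time divided by audio duration, so values below 1 mean faster than real time. With invented numbers, decoding a 6.625 s wave in 2.0 s gives RTF = 2.0 / 6.625 ≈ 0.302.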
@@ -26,6 +26,8 @@ Usage: @@ -26,6 +26,8 @@ Usage:
26 26
27 Default value for num_threads is 2. 27 Default value for num_threads is 2.
28 Valid values for decoding_method: greedy_search (default), modified_beam_search. 28 Valid values for decoding_method: greedy_search (default), modified_beam_search.
  29 +foo.wav should be a single-channel, 16-bit PCM encoded wave file; its
  30 +sampling rate can be arbitrary and does not need to be 16kHz.
29 31
30 Please refer to 32 Please refer to
31 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 33 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -59,20 +61,26 @@ for a list of pre-trained models to download. @@ -59,20 +61,26 @@ for a list of pre-trained models to download.
59 61
60 fprintf(stderr, "%s\n", config.ToString().c_str()); 62 fprintf(stderr, "%s\n", config.ToString().c_str());
61 63
  64 + if (!config.Validate()) {
  65 + fprintf(stderr, "Errors in config!\n");
  66 + return -1;
  67 + }
  68 +
62 sherpa_onnx::OnlineRecognizer recognizer(config); 69 sherpa_onnx::OnlineRecognizer recognizer(config);
63 70
64 - int32_t expected_sampling_rate = config.feat_config.sampling_rate; 71 + int32_t sampling_rate = -1;
65 72
66 bool is_ok = false; 73 bool is_ok = false;
67 std::vector<float> samples = 74 std::vector<float> samples =
68 - sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); 75 + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
69 76
70 if (!is_ok) { 77 if (!is_ok) {
71 fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); 78 fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
72 return -1; 79 return -1;
73 } 80 }
  81 + fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
74 82
75 - float duration = samples.size() / static_cast<float>(expected_sampling_rate); 83 + float duration = samples.size() / static_cast<float>(sampling_rate);
76 84
77 fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); 85 fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
78 fprintf(stderr, "wav duration (s): %.3f\n", duration); 86 fprintf(stderr, "wav duration (s): %.3f\n", duration);
@@ -81,12 +89,13 @@ for a list of pre-trained models to download. @@ -81,12 +89,13 @@ for a list of pre-trained models to download.
81 fprintf(stderr, "Started\n"); 89 fprintf(stderr, "Started\n");
82 90
83 auto s = recognizer.CreateStream(); 91 auto s = recognizer.CreateStream();
84 - s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); 92 + s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
  93 +
  94 + std::vector<float> tail_paddings(static_cast<int>(0.2 * sampling_rate));
  95 + // Note: We can call AcceptWaveform() multiple times.
  96 + s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size());
85 97
86 - std::vector<float> tail_paddings(  
87 - static_cast<int>(0.2 * expected_sampling_rate));  
88 - s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),  
89 - tail_paddings.size()); 98 + // Call InputFinished() to indicate that no more audio samples will be provided
90 s->InputFinished(); 99 s->InputFinished();
91 100
92 while (recognizer.IsReady(s.get())) { 101 while (recognizer.IsReady(s.get())) {
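The 0.2 s of zero-valued tail padding above flushes samples still buffered in the feature extractor and the encoder's right context, so the last word of the utterance is decoded before the stream is finalized with InputFinished().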
@@ -30,4 +30,23 @@ TEST(Slice, Slice3D) { @@ -30,4 +30,23 @@ TEST(Slice, Slice3D) {
30 // TODO(fangjun): Check that the results are correct 30 // TODO(fangjun): Check that the results are correct
31 } 31 }
32 32
  33 +TEST(Slice, Slice2D) {
  34 + Ort::AllocatorWithDefaultOptions allocator;
  35 + std::array<int64_t, 2> shape{5, 8};
  36 + Ort::Value v =
  37 + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  38 + float *p = v.GetTensorMutableData<float>();
  39 +
  40 + std::iota(p, p + shape[0] * shape[1], 0);
  41 +
  42 + auto v1 = Slice(allocator, &v, 1, 3);
  43 + auto v2 = Slice(allocator, &v, 0, 2);
  44 +
  45 + Print2D(&v);
  46 + Print2D(&v1);
  47 + Print2D(&v2);
  48 +
  49 + // TODO(fangjun): Check that the results are correct
  50 +}
  51 +
33 } // namespace sherpa_onnx 52 } // namespace sherpa_onnx
@@ -24,7 +24,7 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, @@ -24,7 +24,7 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
24 24
25 assert(0 <= dim1_start); 25 assert(0 <= dim1_start);
26 assert(dim1_start < dim1_end); 26 assert(dim1_start < dim1_end);
27 - assert(dim1_end < shape[1]); 27 + assert(dim1_end <= shape[1]);
28 28
29 const T *src = v->GetTensorData<T>(); 29 const T *src = v->GetTensorData<T>();
30 30
@@ -46,8 +46,35 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, @@ -46,8 +46,35 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
46 return ans; 46 return ans;
47 } 47 }
48 48
  49 +template <typename T /*= float*/>
  50 +Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
  51 + int32_t dim0_start, int32_t dim0_end) {
  52 + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
  53 + assert(shape.size() == 2);
  54 +
  55 + assert(0 <= dim0_start);
  56 + assert(dim0_start < dim0_end);
  57 + assert(dim0_end <= shape[0]);
  58 +
  59 + const T *src = v->GetTensorData<T>();
  60 +
  61 + std::array<int64_t, 2> ans_shape{dim0_end - dim0_start, shape[1]};
  62 +
  63 + Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
  64 + ans_shape.size());
  65 + const T *start = src + dim0_start * shape[1];
  66 + const T *end = src + dim0_end * shape[1];
  67 + T *dst = ans.GetTensorMutableData<T>();
  68 + std::copy(start, end, dst);
  69 +
  70 + return ans;
  71 +}
  72 +
49 template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v, 73 template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
50 int32_t dim0_start, int32_t dim0_end, 74 int32_t dim0_start, int32_t dim0_end,
51 int32_t dim1_start, int32_t dim1_end); 75 int32_t dim1_start, int32_t dim1_end);
52 76
  77 +template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
  78 + int32_t dim0_start, int32_t dim0_end);
  79 +
53 } // namespace sherpa_onnx 80 } // namespace sherpa_onnx
@@ -8,12 +8,12 @@ @@ -8,12 +8,12 @@
8 8
9 namespace sherpa_onnx { 9 namespace sherpa_onnx {
10 10
11 -/** Get a deep copy by slicing v. 11 +/** Get a deep copy by slicing a 3-D tensor v.
12 * 12 *
13 - * It returns v[dim0_start:dim0_end, dim1_start:dim1_end] 13 + * It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :]
14 * 14 *
15 * @param allocator 15 * @param allocator
16 * @param v A 3-D tensor. Its data type is T. 16
17 * @param dim0_start Start index of the first dimension.. 17 * @param dim0_start Start index of the first dimension..
18 * @param dim0_end End index of the first dimension.. 18 * @param dim0_end End index of the first dimension..
19 * @param dim1_start Start index of the second dimension. 19 * @param dim1_start Start index of the second dimension.
@@ -26,6 +26,23 @@ template <typename T = float> @@ -26,6 +26,23 @@ template <typename T = float>
26 Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, 26 Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
27 int32_t dim0_start, int32_t dim0_end, int32_t dim1_start, 27 int32_t dim0_start, int32_t dim0_end, int32_t dim1_start,
28 int32_t dim1_end); 28 int32_t dim1_end);
  29 +
  30 +/** Get a deep copy by slicing a 2-D tensor v.
  31 + *
  32 + * It returns v[dim0_start:dim0_end, :]
  33 + *
  34 + * @param allocator
  35 + * @param v A 2-D tensor. Its data type is T.
  36 + * @param dim0_start Start index of the first dimension.
  37 + * @param dim0_end End index of the first dimension.
  38 + *
  39 + * @return Return a 2-D tensor of shape
  40 + * (dim0_end-dim0_start, v.shape[1])
  41 + */
  42 +template <typename T = float>
  43 +Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
  44 + int32_t dim0_start, int32_t dim0_end);
  45 +
29 } // namespace sherpa_onnx 46 } // namespace sherpa_onnx
30 47
31 #endif // SHERPA_ONNX_CSRC_SLICE_H_ 48 #endif // SHERPA_ONNX_CSRC_SLICE_H_
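As a worked example for the new overload: with the (5, 8) tensor filled by std::iota in slice-test.cc above, Slice(allocator, &v, 1, 3) deep-copies rows 1 and 2 into a new (2, 8) tensor whose first element is 8.0f.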
@@ -6,10 +6,11 @@ @@ -6,10 +6,11 @@
6 6
7 #include <cassert> 7 #include <cassert>
8 #include <fstream> 8 #include <fstream>
9 -#include <iostream>  
10 #include <utility> 9 #include <utility>
11 #include <vector> 10 #include <vector>
12 11
  12 +#include "sherpa-onnx/csrc/macros.h"
  13 +
13 namespace sherpa_onnx { 14 namespace sherpa_onnx {
14 namespace { 15 namespace {
15 // see http://soundfile.sapp.org/doc/WaveFormat/ 16 // see http://soundfile.sapp.org/doc/WaveFormat/
@@ -20,26 +21,34 @@ struct WaveHeader { @@ -20,26 +21,34 @@ struct WaveHeader {
20 bool Validate() const { 21 bool Validate() const {
21 // F F I R 22 // F F I R
22 if (chunk_id != 0x46464952) { 23 if (chunk_id != 0x46464952) {
  24 + SHERPA_ONNX_LOGE("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id);
23 return false; 25 return false;
24 } 26 }
25 // E V A W 27 // E V A W
26 if (format != 0x45564157) { 28 if (format != 0x45564157) {
  29 + SHERPA_ONNX_LOGE("Expected format WAVE. Given: 0x%08x\n", format);
27 return false; 30 return false;
28 } 31 }
29 32
30 if (subchunk1_id != 0x20746d66) { 33 if (subchunk1_id != 0x20746d66) {
  34 + SHERPA_ONNX_LOGE("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",
  35 + subchunk1_id);
31 return false; 36 return false;
32 } 37 }
33 38
34 if (subchunk1_size != 16) { // 16 for PCM 39 if (subchunk1_size != 16) { // 16 for PCM
  40 + SHERPA_ONNX_LOGE("Expected subchunk1_size 16. Given: %d\n",
  41 + subchunk1_size);
35 return false; 42 return false;
36 } 43 }
37 44
38 if (audio_format != 1) { // 1 for PCM 45 if (audio_format != 1) { // 1 for PCM
  46 + SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n", audio_format);
39 return false; 47 return false;
40 } 48 }
41 49
42 if (num_channels != 1) { // we support only single channel for now 50 if (num_channels != 1) { // we support only single channel for now
  51 + SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n", num_channels);
43 return false; 52 return false;
44 } 53 }
45 if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) { 54 if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
@@ -51,6 +60,8 @@ struct WaveHeader { @@ -51,6 +60,8 @@ struct WaveHeader {
51 } 60 }
52 61
53 if (bits_per_sample != 16) { // we support only 16 bits per sample 62 if (bits_per_sample != 16) { // we support only 16 bits per sample
  63 + SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n",
  64 + bits_per_sample);
54 return false; 65 return false;
55 } 66 }
56 67
@@ -62,7 +73,7 @@ struct WaveHeader { @@ -62,7 +73,7 @@ struct WaveHeader {
62 // and 73 // and
63 // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf 74 // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
64 void SeekToDataChunk(std::istream &is) { 75 void SeekToDataChunk(std::istream &is) {
65 - // a t a d 76 + // a t a d
66 while (is && subchunk2_id != 0x61746164) { 77 while (is && subchunk2_id != 0x61746164) {
67 // const char *p = reinterpret_cast<const char *>(&subchunk2_id); 78 // const char *p = reinterpret_cast<const char *>(&subchunk2_id);
68 // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], 79 // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
@@ -91,7 +102,7 @@ static_assert(sizeof(WaveHeader) == 44, ""); @@ -91,7 +102,7 @@ static_assert(sizeof(WaveHeader) == 44, "");
91 102
92 // Read a wave file of mono-channel. 103 // Read a wave file of mono-channel.
93 // Return its samples normalized to the range [-1, 1). 104 // Return its samples normalized to the range [-1, 1).
94 -std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, 105 +std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
95 bool *is_ok) { 106 bool *is_ok) {
96 WaveHeader header; 107 WaveHeader header;
97 is.read(reinterpret_cast<char *>(&header), sizeof(header)); 108 is.read(reinterpret_cast<char *>(&header), sizeof(header));
@@ -111,10 +122,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, @@ -111,10 +122,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
111 return {}; 122 return {};
112 } 123 }
113 124
114 - if (expected_sample_rate != header.sample_rate) {  
115 - *is_ok = false;  
116 - return {};  
117 - } 125 + *sampling_rate = header.sample_rate;
118 126
119 // header.subchunk2_size contains the number of bytes in the data. 127 // header.subchunk2_size contains the number of bytes in the data.
120 // As we assume each sample contains two bytes, so it is divided by 2 here 128 // As we assume each sample contains two bytes, so it is divided by 2 here
@@ -137,15 +145,15 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, @@ -137,15 +145,15 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
137 145
138 } // namespace 146 } // namespace
139 147
140 -std::vector<float> ReadWave(const std::string &filename,  
141 - float expected_sample_rate, bool *is_ok) { 148 +std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
  149 + bool *is_ok) {
142 std::ifstream is(filename, std::ifstream::binary); 150 std::ifstream is(filename, std::ifstream::binary);
143 - return ReadWave(is, expected_sample_rate, is_ok); 151 + return ReadWave(is, sampling_rate, is_ok);
144 } 152 }
145 153
146 -std::vector<float> ReadWave(std::istream &is, float expected_sample_rate, 154 +std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
147 bool *is_ok) { 155 bool *is_ok) {
148 - auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok); 156 + auto samples = ReadWaveImpl(is, sampling_rate, is_ok);
149 return samples; 157 return samples;
150 } 158 }
151 159
@@ -13,17 +13,17 @@ namespace sherpa_onnx { @@ -13,17 +13,17 @@ namespace sherpa_onnx {
13 13
14 /** Read a wave file with expected sample rate. 14 /** Read a wave file with expected sample rate.
15 15
16 - @param filename Path to a wave file. It MUST be single channel, PCM encoded.  
17 - @param expected_sample_rate Expected sample rate of the wave file. If the  
18 - sample rate don't match, it throws an exception. 16 + @param filename Path to a wave file. It MUST be single channel, 16-bit
  17 + PCM encoded.
  18 + @param sampling_rate On return, it contains the sampling rate of the file.
19 @param is_ok On return it is true if the reading succeeded; false otherwise. 19 @param is_ok On return it is true if the reading succeeded; false otherwise.
20 20
21 @return Return wave samples normalized to the range [-1, 1). 21 @return Return wave samples normalized to the range [-1, 1).
22 */ 22 */
23 -std::vector<float> ReadWave(const std::string &filename,  
24 - float expected_sample_rate, bool *is_ok); 23 +std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
  24 + bool *is_ok);
25 25
26 -std::vector<float> ReadWave(std::istream &is, float expected_sample_rate, 26 +std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
27 bool *is_ok); 27 bool *is_ok);
28 28
29 } // namespace sherpa_onnx 29 } // namespace sherpa_onnx
@@ -11,6 +11,7 @@ @@ -11,6 +11,7 @@
11 #include "jni.h" // NOLINT 11 #include "jni.h" // NOLINT
12 12
13 #include <strstream> 13 #include <strstream>
  14 +#include <utility>
14 15
15 #if __ANDROID_API__ >= 9 16 #if __ANDROID_API__ >= 9
16 #include "android/asset_manager.h" 17 #include "android/asset_manager.h"
@@ -43,14 +44,18 @@ class SherpaOnnx { @@ -43,14 +44,18 @@ class SherpaOnnx {
43 stream_(recognizer_.CreateStream()) { 44 stream_(recognizer_.CreateStream()) {
44 } 45 }
45 46
46 - void AcceptWaveform(int32_t sample_rate, const float *samples,  
47 - int32_t n) const { 47 + void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
  48 + if (input_sample_rate_ == -1) {
  49 + input_sample_rate_ = sample_rate;
  50 + }
  51 +
48 stream_->AcceptWaveform(sample_rate, samples, n); 52 stream_->AcceptWaveform(sample_rate, samples, n);
49 } 53 }
50 54
51 void InputFinished() const { 55 void InputFinished() const {
52 - std::vector<float> tail_padding(16000 * 0.32, 0);  
53 - stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size()); 56 + std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0);
  57 + stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
  58 + tail_padding.size());
54 stream_->InputFinished(); 59 stream_->InputFinished();
55 } 60 }
56 61
@@ -70,6 +75,7 @@ class SherpaOnnx { @@ -70,6 +75,7 @@ class SherpaOnnx {
70 private: 75 private:
71 sherpa_onnx::OnlineRecognizer recognizer_; 76 sherpa_onnx::OnlineRecognizer recognizer_;
72 std::unique_ptr<sherpa_onnx::OnlineStream> stream_; 77 std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
  78 + int32_t input_sample_rate_ = -1;
73 }; 79 };
74 80
75 static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { 81 static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
@@ -276,17 +282,24 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( @@ -276,17 +282,24 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
276 return env->NewStringUTF(text.c_str()); 282 return env->NewStringUTF(text.c_str());
277 } 283 }
278 284
  285 +// see
  286 +// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
  287 +static jobject NewInteger(JNIEnv *env, int32_t value) {
  288 + jclass cls = env->FindClass("java/lang/Integer");
  289 + jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
  290 + return env->NewObject(cls, constructor, value);
  291 +}
  292 +
279 SHERPA_ONNX_EXTERN_C 293 SHERPA_ONNX_EXTERN_C
280 -JNIEXPORT jfloatArray JNICALL 294 +JNIEXPORT jobjectArray JNICALL
281 Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( 295 Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
282 - JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename,  
283 - jfloat expected_sample_rate) { 296 + JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) {
284 const char *p_filename = env->GetStringUTFChars(filename, nullptr); 297 const char *p_filename = env->GetStringUTFChars(filename, nullptr);
285 #if __ANDROID_API__ >= 9 298 #if __ANDROID_API__ >= 9
286 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 299 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
287 if (!mgr) { 300 if (!mgr) {
288 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 301 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
289 - return nullptr; 302 + exit(-1);
290 } 303 }
291 304
292 std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename); 305 std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename);
@@ -297,16 +310,25 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( @@ -297,16 +310,25 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
297 #endif 310 #endif
298 311
299 bool is_ok = false; 312 bool is_ok = false;
  313 + int32_t sampling_rate = -1;
300 std::vector<float> samples = 314 std::vector<float> samples =
301 - sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok); 315 + sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok);
302 316
303 env->ReleaseStringUTFChars(filename, p_filename); 317 env->ReleaseStringUTFChars(filename, p_filename);
304 318
305 if (!is_ok) { 319 if (!is_ok) {
306 - return nullptr; 320 + SHERPA_ONNX_LOGE("Failed to read %s", p_filename);
  321 + exit(-1);
307 } 322 }
308 323
309 jfloatArray ans = env->NewFloatArray(samples.size()); 324 jfloatArray ans = env->NewFloatArray(samples.size());
310 env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data()); 325 env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data());
311 - return ans; 326 +
  327 + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
  328 + 2, env->FindClass("java/lang/Object"), nullptr);
  329 +
  330 + env->SetObjectArrayElement(obj_arr, 0, ans);
  331 + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate));
  332 +
  333 + return obj_arr;
312 } 334 }
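Since JNI has no out-parameters for primitives, readWave() now returns a two-element Object[]: element 0 holds the jfloatArray of samples and element 1 the sampling rate boxed by the NewInteger() helper above, so callers no longer pass an expected sample rate.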
@@ -11,12 +11,10 @@ namespace sherpa_onnx { @@ -11,12 +11,10 @@ namespace sherpa_onnx {
11 static void PybindFeatureExtractorConfig(py::module *m) { 11 static void PybindFeatureExtractorConfig(py::module *m) {
12 using PyClass = FeatureExtractorConfig; 12 using PyClass = FeatureExtractorConfig;
13 py::class_<PyClass>(*m, "FeatureExtractorConfig") 13 py::class_<PyClass>(*m, "FeatureExtractorConfig")
14 - .def(py::init<int32_t, int32_t, int32_t>(),  
15 - py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,  
16 - py::arg("max_feature_vectors") = -1) 14 + .def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,
  15 + py::arg("feature_dim") = 80)
17 .def_readwrite("sampling_rate", &PyClass::sampling_rate) 16 .def_readwrite("sampling_rate", &PyClass::sampling_rate)
18 .def_readwrite("feature_dim", &PyClass::feature_dim) 17 .def_readwrite("feature_dim", &PyClass::feature_dim)
19 - .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)  
20 .def("__str__", &PyClass::ToString); 18 .def("__str__", &PyClass::ToString);
21 } 19 }
22 20
@@ -34,7 +34,6 @@ class OnlineRecognizer(object): @@ -34,7 +34,6 @@ class OnlineRecognizer(object):
34 rule3_min_utterance_length: int = 20, 34 rule3_min_utterance_length: int = 20,
35 decoding_method: str = "greedy_search", 35 decoding_method: str = "greedy_search",
36 max_active_paths: int = 4, 36 max_active_paths: int = 4,
37 - max_feature_vectors: int = -1,  
38 ): 37 ):
39 """ 38 """
40 Please refer to 39 Please refer to
@@ -82,9 +81,6 @@ class OnlineRecognizer(object): @@ -82,9 +81,6 @@ class OnlineRecognizer(object):
82 max_active_paths: 81 max_active_paths:
83 Use only when decoding_method is modified_beam_search. It specifies 82 Use only when decoding_method is modified_beam_search. It specifies
84 the maximum number of active paths during beam search. 83 the maximum number of active paths during beam search.
85 - max_feature_vectors:  
86 - Number of feature vectors to cache. -1 means to cache all feature  
87 - frames that have been processed.  
88 """ 84 """
89 _assert_file_exists(tokens) 85 _assert_file_exists(tokens)
90 _assert_file_exists(encoder) 86 _assert_file_exists(encoder)
@@ -104,7 +100,6 @@ class OnlineRecognizer(object): @@ -104,7 +100,6 @@ class OnlineRecognizer(object):
104 feat_config = FeatureExtractorConfig( 100 feat_config = FeatureExtractorConfig(
105 sampling_rate=sample_rate, 101 sampling_rate=sample_rate,
106 feature_dim=feature_dim, 102 feature_dim=feature_dim,
107 - max_feature_vectors=max_feature_vectors,  
108 ) 103 )
109 104
110 endpoint_config = EndpointConfig( 105 endpoint_config = EndpointConfig(
@@ -8,18 +8,18 @@ @@ -8,18 +8,18 @@
8 8
9 import unittest 9 import unittest
10 10
11 -import sherpa_onnx 11 +import _sherpa_onnx
12 12
13 13
14 class TestFeatureExtractorConfig(unittest.TestCase): 14 class TestFeatureExtractorConfig(unittest.TestCase):
15 def test_default_constructor(self): 15 def test_default_constructor(self):
16 - config = sherpa_onnx.FeatureExtractorConfig() 16 + config = _sherpa_onnx.FeatureExtractorConfig()
17 assert config.sampling_rate == 16000, config.sampling_rate 17 assert config.sampling_rate == 16000, config.sampling_rate
18 assert config.feature_dim == 80, config.feature_dim 18 assert config.feature_dim == 80, config.feature_dim
19 print(config) 19 print(config)
20 20
21 def test_constructor(self): 21 def test_constructor(self):
22 - config = sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40) 22 + config = _sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40)
23 assert config.sampling_rate == 8000, config.sampling_rate 23 assert config.sampling_rate == 8000, config.sampling_rate
24 assert config.feature_dim == 40, config.feature_dim 24 assert config.feature_dim == 40, config.feature_dim
25 print(config) 25 print(config)
@@ -8,21 +8,23 @@ @@ -8,21 +8,23 @@
8 8
9 import unittest 9 import unittest
10 10
11 -import sherpa_onnx 11 +import _sherpa_onnx
12 12
13 13
14 class TestOnlineTransducerModelConfig(unittest.TestCase): 14 class TestOnlineTransducerModelConfig(unittest.TestCase):
15 def test_constructor(self): 15 def test_constructor(self):
16 - config = sherpa_onnx.OnlineTransducerModelConfig( 16 + config = _sherpa_onnx.OnlineTransducerModelConfig(
17 encoder_filename="encoder.onnx", 17 encoder_filename="encoder.onnx",
18 decoder_filename="decoder.onnx", 18 decoder_filename="decoder.onnx",
19 joiner_filename="joiner.onnx", 19 joiner_filename="joiner.onnx",
  20 + tokens="tokens.txt",
20 num_threads=8, 21 num_threads=8,
21 debug=True, 22 debug=True,
22 ) 23 )
23 assert config.encoder_filename == "encoder.onnx", config.encoder_filename 24 assert config.encoder_filename == "encoder.onnx", config.encoder_filename
24 assert config.decoder_filename == "decoder.onnx", config.decoder_filename 25 assert config.decoder_filename == "decoder.onnx", config.decoder_filename
25 assert config.joiner_filename == "joiner.onnx", config.joiner_filename 26 assert config.joiner_filename == "joiner.onnx", config.joiner_filename
  27 + assert config.tokens == "tokens.txt", config.tokens
26 assert config.num_threads == 8, config.num_threads 28 assert config.num_threads == 8, config.num_threads
27 assert config.debug is True, config.debug 29 assert config.debug is True, config.debug
28 print(config) 30 print(config)