正在显示
48 个修改的文件
包含
1526 行增加
和
150 行删除
| @@ -33,18 +33,20 @@ fun main() { | @@ -33,18 +33,20 @@ fun main() { | ||
| 33 | config = config, | 33 | config = config, |
| 34 | ) | 34 | ) |
| 35 | 35 | ||
| 36 | - var samples = WaveReader.readWave( | 36 | + var objArray = WaveReader.readWave( |
| 37 | assetManager = AssetManager(), | 37 | assetManager = AssetManager(), |
| 38 | filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav", | 38 | filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav", |
| 39 | ) | 39 | ) |
| 40 | + var samples : FloatArray = objArray[0] as FloatArray | ||
| 41 | + var sampleRate : Int = objArray[1] as Int | ||
| 40 | 42 | ||
| 41 | - model.acceptWaveform(samples!!, sampleRate=16000) | 43 | + model.acceptWaveform(samples, sampleRate=sampleRate) |
| 42 | while (model.isReady()) { | 44 | while (model.isReady()) { |
| 43 | model.decode() | 45 | model.decode() |
| 44 | } | 46 | } |
| 45 | 47 | ||
| 46 | - var tail_paddings = FloatArray(8000) // 0.5 seconds | ||
| 47 | - model.acceptWaveform(tail_paddings, sampleRate=16000) | 48 | + var tail_paddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds |
| 49 | + model.acceptWaveform(tail_paddings, sampleRate=sampleRate) | ||
| 48 | model.inputFinished() | 50 | model.inputFinished() |
| 49 | while (model.isReady()) { | 51 | while (model.isReady()) { |
| 50 | model.decode() | 52 | model.decode() |
.github/scripts/test-offline-transducer.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -e | ||
| 4 | + | ||
| 5 | +log() { | ||
| 6 | + # This function is from espnet | ||
| 7 | + local fname=${BASH_SOURCE[1]##*/} | ||
| 8 | + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | ||
| 9 | +} | ||
| 10 | + | ||
| 11 | +echo "EXE is $EXE" | ||
| 12 | +echo "PATH: $PATH" | ||
| 13 | + | ||
| 14 | +which $EXE | ||
| 15 | + | ||
| 16 | +log "------------------------------------------------------------" | ||
| 17 | +log "Run Conformer transducer (English)" | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | + | ||
| 20 | +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-conformer-en-2023-03-18 | ||
| 21 | +log "Start testing ${repo_url}" | ||
| 22 | +repo=$(basename $repo_url) | ||
| 23 | +log "Download pretrained model and test-data from $repo_url" | ||
| 24 | + | ||
| 25 | +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 26 | +pushd $repo | ||
| 27 | +git lfs pull --include "*.onnx" | ||
| 28 | +cd test_wavs | ||
| 29 | +popd | ||
| 30 | + | ||
| 31 | +waves=( | ||
| 32 | +$repo/test_wavs/0.wav | ||
| 33 | +$repo/test_wavs/1.wav | ||
| 34 | +$repo/test_wavs/2.wav | ||
| 35 | +) | ||
| 36 | + | ||
| 37 | +for wave in ${waves[@]}; do | ||
| 38 | + time $EXE \ | ||
| 39 | + $repo/tokens.txt \ | ||
| 40 | + $repo/encoder-epoch-99-avg-1.onnx \ | ||
| 41 | + $repo/decoder-epoch-99-avg-1.onnx \ | ||
| 42 | + $repo/joiner-epoch-99-avg-1.onnx \ | ||
| 43 | + $wave \ | ||
| 44 | + 2 | ||
| 45 | +done | ||
| 46 | + | ||
| 47 | + | ||
| 48 | +if command -v sox &> /dev/null; then | ||
| 49 | + echo "test 8kHz" | ||
| 50 | + sox $repo/test_wavs/0.wav -r 8000 8k.wav | ||
| 51 | + time $EXE \ | ||
| 52 | + $repo/tokens.txt \ | ||
| 53 | + $repo/encoder-epoch-99-avg-1.onnx \ | ||
| 54 | + $repo/decoder-epoch-99-avg-1.onnx \ | ||
| 55 | + $repo/joiner-epoch-99-avg-1.onnx \ | ||
| 56 | + 8k.wav \ | ||
| 57 | + 2 | ||
| 58 | +fi | ||
| 59 | + | ||
| 60 | +rm -rf $repo |
| @@ -40,7 +40,7 @@ for wave in ${waves[@]}; do | @@ -40,7 +40,7 @@ for wave in ${waves[@]}; do | ||
| 40 | $repo/decoder-epoch-99-avg-1.onnx \ | 40 | $repo/decoder-epoch-99-avg-1.onnx \ |
| 41 | $repo/joiner-epoch-99-avg-1.onnx \ | 41 | $repo/joiner-epoch-99-avg-1.onnx \ |
| 42 | $wave \ | 42 | $wave \ |
| 43 | - 4 | 43 | + 2 |
| 44 | done | 44 | done |
| 45 | 45 | ||
| 46 | rm -rf $repo | 46 | rm -rf $repo |
| @@ -72,7 +72,7 @@ for wave in ${waves[@]}; do | @@ -72,7 +72,7 @@ for wave in ${waves[@]}; do | ||
| 72 | $repo/decoder-epoch-11-avg-1.onnx \ | 72 | $repo/decoder-epoch-11-avg-1.onnx \ |
| 73 | $repo/joiner-epoch-11-avg-1.onnx \ | 73 | $repo/joiner-epoch-11-avg-1.onnx \ |
| 74 | $wave \ | 74 | $wave \ |
| 75 | - 4 | 75 | + 2 |
| 76 | done | 76 | done |
| 77 | 77 | ||
| 78 | rm -rf $repo | 78 | rm -rf $repo |
| @@ -104,7 +104,7 @@ for wave in ${waves[@]}; do | @@ -104,7 +104,7 @@ for wave in ${waves[@]}; do | ||
| 104 | $repo/decoder-epoch-99-avg-1.onnx \ | 104 | $repo/decoder-epoch-99-avg-1.onnx \ |
| 105 | $repo/joiner-epoch-99-avg-1.onnx \ | 105 | $repo/joiner-epoch-99-avg-1.onnx \ |
| 106 | $wave \ | 106 | $wave \ |
| 107 | - 4 | 107 | + 2 |
| 108 | done | 108 | done |
| 109 | 109 | ||
| 110 | rm -rf $repo | 110 | rm -rf $repo |
| @@ -138,7 +138,7 @@ for wave in ${waves[@]}; do | @@ -138,7 +138,7 @@ for wave in ${waves[@]}; do | ||
| 138 | $repo/decoder-epoch-99-avg-1.onnx \ | 138 | $repo/decoder-epoch-99-avg-1.onnx \ |
| 139 | $repo/joiner-epoch-99-avg-1.onnx \ | 139 | $repo/joiner-epoch-99-avg-1.onnx \ |
| 140 | $wave \ | 140 | $wave \ |
| 141 | - 4 | 141 | + 2 |
| 142 | done | 142 | done |
| 143 | 143 | ||
| 144 | # Decode a URL | 144 | # Decode a URL |
| @@ -149,7 +149,7 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then | @@ -149,7 +149,7 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then | ||
| 149 | $repo/decoder-epoch-99-avg-1.onnx \ | 149 | $repo/decoder-epoch-99-avg-1.onnx \ |
| 150 | $repo/joiner-epoch-99-avg-1.onnx \ | 150 | $repo/joiner-epoch-99-avg-1.onnx \ |
| 151 | https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \ | 151 | https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \ |
| 152 | - 4 | 152 | + 2 |
| 153 | fi | 153 | fi |
| 154 | 154 | ||
| 155 | rm -rf $repo | 155 | rm -rf $repo |
| @@ -7,11 +7,11 @@ on: | @@ -7,11 +7,11 @@ on: | ||
| 7 | paths: | 7 | paths: |
| 8 | - '.github/workflows/linux.yaml' | 8 | - '.github/workflows/linux.yaml' |
| 9 | - '.github/scripts/test-online-transducer.sh' | 9 | - '.github/scripts/test-online-transducer.sh' |
| 10 | + - '.github/scripts/test-offline-transducer.sh' | ||
| 10 | - 'CMakeLists.txt' | 11 | - 'CMakeLists.txt' |
| 11 | - 'cmake/**' | 12 | - 'cmake/**' |
| 12 | - 'sherpa-onnx/csrc/*' | 13 | - 'sherpa-onnx/csrc/*' |
| 13 | - 'sherpa-onnx/c-api/*' | 14 | - 'sherpa-onnx/c-api/*' |
| 14 | - - 'ffmpeg-examples/**' | ||
| 15 | - 'c-api-examples/**' | 15 | - 'c-api-examples/**' |
| 16 | pull_request: | 16 | pull_request: |
| 17 | branches: | 17 | branches: |
| @@ -19,11 +19,11 @@ on: | @@ -19,11 +19,11 @@ on: | ||
| 19 | paths: | 19 | paths: |
| 20 | - '.github/workflows/linux.yaml' | 20 | - '.github/workflows/linux.yaml' |
| 21 | - '.github/scripts/test-online-transducer.sh' | 21 | - '.github/scripts/test-online-transducer.sh' |
| 22 | + - '.github/scripts/test-offline-transducer.sh' | ||
| 22 | - 'CMakeLists.txt' | 23 | - 'CMakeLists.txt' |
| 23 | - 'cmake/**' | 24 | - 'cmake/**' |
| 24 | - 'sherpa-onnx/csrc/*' | 25 | - 'sherpa-onnx/csrc/*' |
| 25 | - 'sherpa-onnx/c-api/*' | 26 | - 'sherpa-onnx/c-api/*' |
| 26 | - - 'ffmpeg-examples/**' | ||
| 27 | 27 | ||
| 28 | concurrency: | 28 | concurrency: |
| 29 | group: linux-${{ github.ref }} | 29 | group: linux-${{ github.ref }} |
| @@ -39,35 +39,26 @@ jobs: | @@ -39,35 +39,26 @@ jobs: | ||
| 39 | fail-fast: false | 39 | fail-fast: false |
| 40 | matrix: | 40 | matrix: |
| 41 | os: [ubuntu-latest] | 41 | os: [ubuntu-latest] |
| 42 | + build_type: [Release, Debug] | ||
| 42 | 43 | ||
| 43 | steps: | 44 | steps: |
| 44 | - uses: actions/checkout@v2 | 45 | - uses: actions/checkout@v2 |
| 45 | with: | 46 | with: |
| 46 | fetch-depth: 0 | 47 | fetch-depth: 0 |
| 47 | 48 | ||
| 48 | - - name: Install ffmpeg | 49 | + - name: Install sox |
| 49 | shell: bash | 50 | shell: bash |
| 50 | run: | | 51 | run: | |
| 51 | - sudo apt-get install -y software-properties-common | ||
| 52 | - sudo add-apt-repository ppa:savoury1/ffmpeg4 | ||
| 53 | - sudo add-apt-repository ppa:savoury1/ffmpeg5 | ||
| 54 | - | ||
| 55 | - sudo apt-get install -y libavdevice-dev libavutil-dev ffmpeg | ||
| 56 | - pkg-config --modversion libavutil | ||
| 57 | - ffmpeg -version | ||
| 58 | - | ||
| 59 | - - name: Show ffmpeg version | ||
| 60 | - shell: bash | ||
| 61 | - run: | | ||
| 62 | - pkg-config --modversion libavutil | ||
| 63 | - ffmpeg -version | 52 | + sudo apt-get update |
| 53 | + sudo apt-get install -y sox | ||
| 54 | + sox -h | ||
| 64 | 55 | ||
| 65 | - name: Configure CMake | 56 | - name: Configure CMake |
| 66 | shell: bash | 57 | shell: bash |
| 67 | run: | | 58 | run: | |
| 68 | mkdir build | 59 | mkdir build |
| 69 | cd build | 60 | cd build |
| 70 | - cmake -D CMAKE_BUILD_TYPE=Release .. | 61 | + cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} .. |
| 71 | 62 | ||
| 72 | - name: Build sherpa-onnx for ubuntu | 63 | - name: Build sherpa-onnx for ubuntu |
| 73 | shell: bash | 64 | shell: bash |
| @@ -78,21 +69,19 @@ jobs: | @@ -78,21 +69,19 @@ jobs: | ||
| 78 | ls -lh lib | 69 | ls -lh lib |
| 79 | ls -lh bin | 70 | ls -lh bin |
| 80 | 71 | ||
| 81 | - cd ../ffmpeg-examples | ||
| 82 | - make | ||
| 83 | - | ||
| 84 | - name: Display dependencies of sherpa-onnx for linux | 72 | - name: Display dependencies of sherpa-onnx for linux |
| 85 | shell: bash | 73 | shell: bash |
| 86 | run: | | 74 | run: | |
| 87 | file build/bin/sherpa-onnx | 75 | file build/bin/sherpa-onnx |
| 88 | readelf -d build/bin/sherpa-onnx | 76 | readelf -d build/bin/sherpa-onnx |
| 89 | 77 | ||
| 90 | - - name: Test sherpa-onnx-ffmpeg | 78 | + - name: Test offline transducer |
| 79 | + shell: bash | ||
| 91 | run: | | 80 | run: | |
| 92 | - export PATH=$PWD/ffmpeg-examples:$PATH | ||
| 93 | - export EXE=sherpa-onnx-ffmpeg | 81 | + export PATH=$PWD/build/bin:$PATH |
| 82 | + export EXE=sherpa-onnx-offline | ||
| 94 | 83 | ||
| 95 | - .github/scripts/test-online-transducer.sh | 84 | + .github/scripts/test-offline-transducer.sh |
| 96 | 85 | ||
| 97 | - name: Test online transducer | 86 | - name: Test online transducer |
| 98 | shell: bash | 87 | shell: bash |
| @@ -7,6 +7,7 @@ on: | @@ -7,6 +7,7 @@ on: | ||
| 7 | paths: | 7 | paths: |
| 8 | - '.github/workflows/macos.yaml' | 8 | - '.github/workflows/macos.yaml' |
| 9 | - '.github/scripts/test-online-transducer.sh' | 9 | - '.github/scripts/test-online-transducer.sh' |
| 10 | + - '.github/scripts/test-offline-transducer.sh' | ||
| 10 | - 'CMakeLists.txt' | 11 | - 'CMakeLists.txt' |
| 11 | - 'cmake/**' | 12 | - 'cmake/**' |
| 12 | - 'sherpa-onnx/csrc/*' | 13 | - 'sherpa-onnx/csrc/*' |
| @@ -16,6 +17,7 @@ on: | @@ -16,6 +17,7 @@ on: | ||
| 16 | paths: | 17 | paths: |
| 17 | - '.github/workflows/macos.yaml' | 18 | - '.github/workflows/macos.yaml' |
| 18 | - '.github/scripts/test-online-transducer.sh' | 19 | - '.github/scripts/test-online-transducer.sh' |
| 20 | + - '.github/scripts/test-offline-transducer.sh' | ||
| 19 | - 'CMakeLists.txt' | 21 | - 'CMakeLists.txt' |
| 20 | - 'cmake/**' | 22 | - 'cmake/**' |
| 21 | - 'sherpa-onnx/csrc/*' | 23 | - 'sherpa-onnx/csrc/*' |
| @@ -34,18 +36,25 @@ jobs: | @@ -34,18 +36,25 @@ jobs: | ||
| 34 | fail-fast: false | 36 | fail-fast: false |
| 35 | matrix: | 37 | matrix: |
| 36 | os: [macos-latest] | 38 | os: [macos-latest] |
| 39 | + build_type: [Release, Debug] | ||
| 37 | 40 | ||
| 38 | steps: | 41 | steps: |
| 39 | - uses: actions/checkout@v2 | 42 | - uses: actions/checkout@v2 |
| 40 | with: | 43 | with: |
| 41 | fetch-depth: 0 | 44 | fetch-depth: 0 |
| 42 | 45 | ||
| 46 | + - name: Install sox | ||
| 47 | + shell: bash | ||
| 48 | + run: | | ||
| 49 | + brew install sox | ||
| 50 | + sox -h | ||
| 51 | + | ||
| 43 | - name: Configure CMake | 52 | - name: Configure CMake |
| 44 | shell: bash | 53 | shell: bash |
| 45 | run: | | 54 | run: | |
| 46 | mkdir build | 55 | mkdir build |
| 47 | cd build | 56 | cd build |
| 48 | - cmake -D CMAKE_BUILD_TYPE=Release .. | 57 | + cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} .. |
| 49 | 58 | ||
| 50 | - name: Build sherpa-onnx for macos | 59 | - name: Build sherpa-onnx for macos |
| 51 | shell: bash | 60 | shell: bash |
| @@ -64,6 +73,14 @@ jobs: | @@ -64,6 +73,14 @@ jobs: | ||
| 64 | otool -L build/bin/sherpa-onnx | 73 | otool -L build/bin/sherpa-onnx |
| 65 | otool -l build/bin/sherpa-onnx | 74 | otool -l build/bin/sherpa-onnx |
| 66 | 75 | ||
| 76 | + - name: Test offline transducer | ||
| 77 | + shell: bash | ||
| 78 | + run: | | ||
| 79 | + export PATH=$PWD/build/bin:$PATH | ||
| 80 | + export EXE=sherpa-onnx-offline | ||
| 81 | + | ||
| 82 | + .github/scripts/test-offline-transducer.sh | ||
| 83 | + | ||
| 67 | - name: Test online transducer | 84 | - name: Test online transducer |
| 68 | shell: bash | 85 | shell: bash |
| 69 | run: | | 86 | run: | |
| @@ -13,7 +13,7 @@ endif() | @@ -13,7 +13,7 @@ endif() | ||
| 13 | 13 | ||
| 14 | option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF) | 14 | option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF) |
| 15 | option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) | 15 | option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF) |
| 16 | -option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON) | 16 | +option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" OFF) |
| 17 | option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) | 17 | option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF) |
| 18 | option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) | 18 | option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) |
| 19 | option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) | 19 | option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) |
| @@ -121,7 +121,7 @@ class MainActivity : AppCompatActivity() { | @@ -121,7 +121,7 @@ class MainActivity : AppCompatActivity() { | ||
| 121 | val ret = audioRecord?.read(buffer, 0, buffer.size) | 121 | val ret = audioRecord?.read(buffer, 0, buffer.size) |
| 122 | if (ret != null && ret > 0) { | 122 | if (ret != null && ret > 0) { |
| 123 | val samples = FloatArray(ret) { buffer[it] / 32768.0f } | 123 | val samples = FloatArray(ret) { buffer[it] / 32768.0f } |
| 124 | - model.acceptWaveform(samples, sampleRate=16000) | 124 | + model.acceptWaveform(samples, sampleRate=sampleRateInHz) |
| 125 | while (model.isReady()) { | 125 | while (model.isReady()) { |
| 126 | model.decode() | 126 | model.decode() |
| 127 | } | 127 | } |
| @@ -180,7 +180,7 @@ class MainActivity : AppCompatActivity() { | @@ -180,7 +180,7 @@ class MainActivity : AppCompatActivity() { | ||
| 180 | val type = 0 | 180 | val type = 0 |
| 181 | println("Select model type ${type}") | 181 | println("Select model type ${type}") |
| 182 | val config = OnlineRecognizerConfig( | 182 | val config = OnlineRecognizerConfig( |
| 183 | - featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80), | 183 | + featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), |
| 184 | modelConfig = getModelConfig(type = type)!!, | 184 | modelConfig = getModelConfig(type = type)!!, |
| 185 | endpointConfig = getEndpointConfig(), | 185 | endpointConfig = getEndpointConfig(), |
| 186 | enableEndpoint = true, | 186 | enableEndpoint = true, |
| @@ -8,7 +8,7 @@ class WaveReader { | @@ -8,7 +8,7 @@ class WaveReader { | ||
| 8 | // No resampling is made. | 8 | // No resampling is made. |
| 9 | external fun readWave( | 9 | external fun readWave( |
| 10 | assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f | 10 | assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f |
| 11 | - ): FloatArray? | 11 | + ): Array<Any> |
| 12 | 12 | ||
| 13 | init { | 13 | init { |
| 14 | System.loadLibrary("sherpa-onnx-jni") | 14 | System.loadLibrary("sherpa-onnx-jni") |
| 1 | function(download_kaldi_native_fbank) | 1 | function(download_kaldi_native_fbank) |
| 2 | include(FetchContent) | 2 | include(FetchContent) |
| 3 | 3 | ||
| 4 | - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.13.tar.gz") | ||
| 5 | - set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.13.tar.gz") | ||
| 6 | - set(kaldi_native_fbank_HASH "SHA256=1f4d228f9fe3e3e9f92a74a7eecd2489071a03982e4ba6d7c70fc5fa7444df57") | 4 | + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.14.tar.gz") |
| 5 | + set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.14.tar.gz") | ||
| 6 | + set(kaldi_native_fbank_HASH "SHA256=6a66638a111d3ce21fe6f29cbf9ab3dbcae2331c77391bf825927df5cbf2babe") | ||
| 7 | 7 | ||
| 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) | 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) |
| 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) |
| @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) | @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) | ||
| 12 | # If you don't have access to the Internet, | 12 | # If you don't have access to the Internet, |
| 13 | # please pre-download kaldi-native-fbank | 13 | # please pre-download kaldi-native-fbank |
| 14 | set(possible_file_locations | 14 | set(possible_file_locations |
| 15 | - $ENV{HOME}/Downloads/kaldi-native-fbank-1.13.tar.gz | ||
| 16 | - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.13.tar.gz | ||
| 17 | - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.13.tar.gz | ||
| 18 | - /tmp/kaldi-native-fbank-1.13.tar.gz | ||
| 19 | - /star-fj/fangjun/download/github/kaldi-native-fbank-1.13.tar.gz | 15 | + $ENV{HOME}/Downloads/kaldi-native-fbank-1.14.tar.gz |
| 16 | + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.14.tar.gz | ||
| 17 | + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.14.tar.gz | ||
| 18 | + /tmp/kaldi-native-fbank-1.14.tar.gz | ||
| 19 | + /star-fj/fangjun/download/github/kaldi-native-fbank-1.14.tar.gz | ||
| 20 | ) | 20 | ) |
| 21 | 21 | ||
| 22 | foreach(f IN LISTS possible_file_locations) | 22 | foreach(f IN LISTS possible_file_locations) |
| @@ -91,7 +91,6 @@ def create_recognizer(): | @@ -91,7 +91,6 @@ def create_recognizer(): | ||
| 91 | rule2_min_trailing_silence=1.2, | 91 | rule2_min_trailing_silence=1.2, |
| 92 | rule3_min_utterance_length=300, # it essentially disables this rule | 92 | rule3_min_utterance_length=300, # it essentially disables this rule |
| 93 | decoding_method=args.decoding_method, | 93 | decoding_method=args.decoding_method, |
| 94 | - max_feature_vectors=100, # 1 second | ||
| 95 | ) | 94 | ) |
| 96 | return recognizer | 95 | return recognizer |
| 97 | 96 |
| @@ -86,7 +86,6 @@ def create_recognizer(): | @@ -86,7 +86,6 @@ def create_recognizer(): | ||
| 86 | sample_rate=16000, | 86 | sample_rate=16000, |
| 87 | feature_dim=80, | 87 | feature_dim=80, |
| 88 | decoding_method=args.decoding_method, | 88 | decoding_method=args.decoding_method, |
| 89 | - max_feature_vectors=100, # 1 second | ||
| 90 | ) | 89 | ) |
| 91 | return recognizer | 90 | return recognizer |
| 92 | 91 |
| @@ -6,6 +6,11 @@ set(sources | @@ -6,6 +6,11 @@ set(sources | ||
| 6 | features.cc | 6 | features.cc |
| 7 | file-utils.cc | 7 | file-utils.cc |
| 8 | hypothesis.cc | 8 | hypothesis.cc |
| 9 | + offline-stream.cc | ||
| 10 | + offline-transducer-greedy-search-decoder.cc | ||
| 11 | + offline-transducer-model-config.cc | ||
| 12 | + offline-transducer-model.cc | ||
| 13 | + offline-recognizer.cc | ||
| 9 | online-lstm-transducer-model.cc | 14 | online-lstm-transducer-model.cc |
| 10 | online-recognizer.cc | 15 | online-recognizer.cc |
| 11 | online-stream.cc | 16 | online-stream.cc |
| @@ -56,10 +61,13 @@ if(SHERPA_ONNX_ENABLE_CHECK) | @@ -56,10 +61,13 @@ if(SHERPA_ONNX_ENABLE_CHECK) | ||
| 56 | endif() | 61 | endif() |
| 57 | 62 | ||
| 58 | add_executable(sherpa-onnx sherpa-onnx.cc) | 63 | add_executable(sherpa-onnx sherpa-onnx.cc) |
| 64 | +add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) | ||
| 59 | 65 | ||
| 60 | target_link_libraries(sherpa-onnx sherpa-onnx-core) | 66 | target_link_libraries(sherpa-onnx sherpa-onnx-core) |
| 67 | +target_link_libraries(sherpa-onnx-offline sherpa-onnx-core) | ||
| 61 | if(NOT WIN32) | 68 | if(NOT WIN32) |
| 62 | target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") | 69 | target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") |
| 70 | + target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") | ||
| 63 | endif() | 71 | endif() |
| 64 | 72 | ||
| 65 | if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32) | 73 | if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32) |
| @@ -68,7 +76,13 @@ else() | @@ -68,7 +76,13 @@ else() | ||
| 68 | install(TARGETS sherpa-onnx-core DESTINATION lib) | 76 | install(TARGETS sherpa-onnx-core DESTINATION lib) |
| 69 | endif() | 77 | endif() |
| 70 | 78 | ||
| 71 | -install(TARGETS sherpa-onnx DESTINATION bin) | 79 | +install( |
| 80 | + TARGETS | ||
| 81 | + sherpa-onnx | ||
| 82 | + sherpa-onnx-offline | ||
| 83 | + DESTINATION | ||
| 84 | + bin | ||
| 85 | +) | ||
| 72 | 86 | ||
| 73 | if(SHERPA_ONNX_HAS_ALSA) | 87 | if(SHERPA_ONNX_HAS_ALSA) |
| 74 | add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc) | 88 | add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc) |
| @@ -19,7 +19,9 @@ namespace sherpa_onnx { | @@ -19,7 +19,9 @@ namespace sherpa_onnx { | ||
| 19 | void FeatureExtractorConfig::Register(ParseOptions *po) { | 19 | void FeatureExtractorConfig::Register(ParseOptions *po) { |
| 20 | po->Register("sample-rate", &sampling_rate, | 20 | po->Register("sample-rate", &sampling_rate, |
| 21 | "Sampling rate of the input waveform. Must match the one " | 21 | "Sampling rate of the input waveform. Must match the one " |
| 22 | - "expected by the model."); | 22 | + "expected by the model. Note: You can have a different " |
| 23 | + "sample rate for the input waveform. We will do resampling " | ||
| 24 | + "inside the feature extractor"); | ||
| 23 | 25 | ||
| 24 | po->Register("feat-dim", &feature_dim, | 26 | po->Register("feat-dim", &feature_dim, |
| 25 | "Feature dimension. Must match the one expected by the model."); | 27 | "Feature dimension. Must match the one expected by the model."); |
| @@ -30,8 +32,7 @@ std::string FeatureExtractorConfig::ToString() const { | @@ -30,8 +32,7 @@ std::string FeatureExtractorConfig::ToString() const { | ||
| 30 | 32 | ||
| 31 | os << "FeatureExtractorConfig("; | 33 | os << "FeatureExtractorConfig("; |
| 32 | os << "sampling_rate=" << sampling_rate << ", "; | 34 | os << "sampling_rate=" << sampling_rate << ", "; |
| 33 | - os << "feature_dim=" << feature_dim << ", "; | ||
| 34 | - os << "max_feature_vectors=" << max_feature_vectors << ")"; | 35 | + os << "feature_dim=" << feature_dim << ")"; |
| 35 | 36 | ||
| 36 | return os.str(); | 37 | return os.str(); |
| 37 | } | 38 | } |
| @@ -43,8 +44,6 @@ class FeatureExtractor::Impl { | @@ -43,8 +44,6 @@ class FeatureExtractor::Impl { | ||
| 43 | opts_.frame_opts.snip_edges = false; | 44 | opts_.frame_opts.snip_edges = false; |
| 44 | opts_.frame_opts.samp_freq = config.sampling_rate; | 45 | opts_.frame_opts.samp_freq = config.sampling_rate; |
| 45 | 46 | ||
| 46 | - opts_.frame_opts.max_feature_vectors = config.max_feature_vectors; | ||
| 47 | - | ||
| 48 | opts_.mel_opts.num_bins = config.feature_dim; | 47 | opts_.mel_opts.num_bins = config.feature_dim; |
| 49 | 48 | ||
| 50 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 49 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| @@ -95,7 +94,7 @@ class FeatureExtractor::Impl { | @@ -95,7 +94,7 @@ class FeatureExtractor::Impl { | ||
| 95 | fbank_->AcceptWaveform(sampling_rate, waveform, n); | 94 | fbank_->AcceptWaveform(sampling_rate, waveform, n); |
| 96 | } | 95 | } |
| 97 | 96 | ||
| 98 | - void InputFinished() { | 97 | + void InputFinished() const { |
| 99 | std::lock_guard<std::mutex> lock(mutex_); | 98 | std::lock_guard<std::mutex> lock(mutex_); |
| 100 | fbank_->InputFinished(); | 99 | fbank_->InputFinished(); |
| 101 | } | 100 | } |
| @@ -110,12 +109,21 @@ class FeatureExtractor::Impl { | @@ -110,12 +109,21 @@ class FeatureExtractor::Impl { | ||
| 110 | return fbank_->IsLastFrame(frame); | 109 | return fbank_->IsLastFrame(frame); |
| 111 | } | 110 | } |
| 112 | 111 | ||
| 113 | - std::vector<float> GetFrames(int32_t frame_index, int32_t n) const { | ||
| 114 | - if (frame_index + n > NumFramesReady()) { | ||
| 115 | - fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady()); | 112 | + std::vector<float> GetFrames(int32_t frame_index, int32_t n) { |
| 113 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 114 | + if (frame_index + n > fbank_->NumFramesReady()) { | ||
| 115 | + SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, | ||
| 116 | + fbank_->NumFramesReady()); | ||
| 116 | exit(-1); | 117 | exit(-1); |
| 117 | } | 118 | } |
| 118 | - std::lock_guard<std::mutex> lock(mutex_); | 119 | + |
| 120 | + int32_t discard_num = frame_index - last_frame_index_; | ||
| 121 | + if (discard_num < 0) { | ||
| 122 | + SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d", | ||
| 123 | + last_frame_index_, frame_index); | ||
| 124 | + exit(-1); | ||
| 125 | + } | ||
| 126 | + fbank_->Pop(discard_num); | ||
| 119 | 127 | ||
| 120 | int32_t feature_dim = fbank_->Dim(); | 128 | int32_t feature_dim = fbank_->Dim(); |
| 121 | std::vector<float> features(feature_dim * n); | 129 | std::vector<float> features(feature_dim * n); |
| @@ -128,12 +136,9 @@ class FeatureExtractor::Impl { | @@ -128,12 +136,9 @@ class FeatureExtractor::Impl { | ||
| 128 | p += feature_dim; | 136 | p += feature_dim; |
| 129 | } | 137 | } |
| 130 | 138 | ||
| 131 | - return features; | ||
| 132 | - } | 139 | + last_frame_index_ = frame_index; |
| 133 | 140 | ||
| 134 | - void Reset() { | ||
| 135 | - std::lock_guard<std::mutex> lock(mutex_); | ||
| 136 | - fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 141 | + return features; |
| 137 | } | 142 | } |
| 138 | 143 | ||
| 139 | int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } | 144 | int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } |
| @@ -143,6 +148,7 @@ class FeatureExtractor::Impl { | @@ -143,6 +148,7 @@ class FeatureExtractor::Impl { | ||
| 143 | knf::FbankOptions opts_; | 148 | knf::FbankOptions opts_; |
| 144 | mutable std::mutex mutex_; | 149 | mutable std::mutex mutex_; |
| 145 | std::unique_ptr<LinearResample> resampler_; | 150 | std::unique_ptr<LinearResample> resampler_; |
| 151 | + int32_t last_frame_index_ = 0; | ||
| 146 | }; | 152 | }; |
| 147 | 153 | ||
| 148 | FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) | 154 | FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) |
| @@ -151,11 +157,11 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) | @@ -151,11 +157,11 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) | ||
| 151 | FeatureExtractor::~FeatureExtractor() = default; | 157 | FeatureExtractor::~FeatureExtractor() = default; |
| 152 | 158 | ||
| 153 | void FeatureExtractor::AcceptWaveform(int32_t sampling_rate, | 159 | void FeatureExtractor::AcceptWaveform(int32_t sampling_rate, |
| 154 | - const float *waveform, int32_t n) { | 160 | + const float *waveform, int32_t n) const { |
| 155 | impl_->AcceptWaveform(sampling_rate, waveform, n); | 161 | impl_->AcceptWaveform(sampling_rate, waveform, n); |
| 156 | } | 162 | } |
| 157 | 163 | ||
| 158 | -void FeatureExtractor::InputFinished() { impl_->InputFinished(); } | 164 | +void FeatureExtractor::InputFinished() const { impl_->InputFinished(); } |
| 159 | 165 | ||
| 160 | int32_t FeatureExtractor::NumFramesReady() const { | 166 | int32_t FeatureExtractor::NumFramesReady() const { |
| 161 | return impl_->NumFramesReady(); | 167 | return impl_->NumFramesReady(); |
| @@ -170,8 +176,6 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, | @@ -170,8 +176,6 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, | ||
| 170 | return impl_->GetFrames(frame_index, n); | 176 | return impl_->GetFrames(frame_index, n); |
| 171 | } | 177 | } |
| 172 | 178 | ||
| 173 | -void FeatureExtractor::Reset() { impl_->Reset(); } | ||
| 174 | - | ||
| 175 | int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); } | 179 | int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); } |
| 176 | 180 | ||
| 177 | } // namespace sherpa_onnx | 181 | } // namespace sherpa_onnx |
| @@ -14,9 +14,12 @@ | @@ -14,9 +14,12 @@ | ||
| 14 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | 15 | ||
| 16 | struct FeatureExtractorConfig { | 16 | struct FeatureExtractorConfig { |
| 17 | + // Sampling rate used by the feature extractor. If it is different from | ||
| 18 | + // the sampling rate of the input waveform, we will do resampling inside. | ||
| 17 | int32_t sampling_rate = 16000; | 19 | int32_t sampling_rate = 16000; |
| 20 | + | ||
| 21 | + // Feature dimension | ||
| 18 | int32_t feature_dim = 80; | 22 | int32_t feature_dim = 80; |
| 19 | - int32_t max_feature_vectors = -1; | ||
| 20 | 23 | ||
| 21 | std::string ToString() const; | 24 | std::string ToString() const; |
| 22 | 25 | ||
| @@ -36,7 +39,8 @@ class FeatureExtractor { | @@ -36,7 +39,8 @@ class FeatureExtractor { | ||
| 36 | the range [-1, 1]. | 39 | the range [-1, 1]. |
| 37 | @param n Number of entries in waveform | 40 | @param n Number of entries in waveform |
| 38 | */ | 41 | */ |
| 39 | - void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); | 42 | + void AcceptWaveform(int32_t sampling_rate, const float *waveform, |
| 43 | + int32_t n) const; | ||
| 40 | 44 | ||
| 41 | /** | 45 | /** |
| 42 | * InputFinished() tells the class you won't be providing any | 46 | * InputFinished() tells the class you won't be providing any |
| @@ -44,7 +48,7 @@ class FeatureExtractor { | @@ -44,7 +48,7 @@ class FeatureExtractor { | ||
| 44 | * of features, in the case where snip-edges == false; it also | 48 | * of features, in the case where snip-edges == false; it also |
| 45 | * affects the return value of IsLastFrame(). | 49 | * affects the return value of IsLastFrame(). |
| 46 | */ | 50 | */ |
| 47 | - void InputFinished(); | 51 | + void InputFinished() const; |
| 48 | 52 | ||
| 49 | int32_t NumFramesReady() const; | 53 | int32_t NumFramesReady() const; |
| 50 | 54 | ||
| @@ -62,8 +66,6 @@ class FeatureExtractor { | @@ -62,8 +66,6 @@ class FeatureExtractor { | ||
| 62 | */ | 66 | */ |
| 63 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const; | 67 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const; |
| 64 | 68 | ||
| 65 | - void Reset(); | ||
| 66 | - | ||
| 67 | /// Return feature dim of this extractor | 69 | /// Return feature dim of this extractor |
| 68 | int32_t FeatureDim() const; | 70 | int32_t FeatureDim() const; |
| 69 | 71 |
sherpa-onnx/csrc/offline-recognizer.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-recognizer.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-transducer-model.h" | ||
| 14 | +#include "sherpa-onnx/csrc/pad-sequence.h" | ||
| 15 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +static OfflineRecognitionResult Convert( | ||
| 20 | + const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table, | ||
| 21 | + int32_t frame_shift_ms, int32_t subsampling_factor) { | ||
| 22 | + OfflineRecognitionResult r; | ||
| 23 | + r.tokens.reserve(src.tokens.size()); | ||
| 24 | + r.timestamps.reserve(src.timestamps.size()); | ||
| 25 | + | ||
| 26 | + std::string text; | ||
| 27 | + for (auto i : src.tokens) { | ||
| 28 | + auto sym = sym_table[i]; | ||
| 29 | + text.append(sym); | ||
| 30 | + | ||
| 31 | + r.tokens.push_back(std::move(sym)); | ||
| 32 | + } | ||
| 33 | + r.text = std::move(text); | ||
| 34 | + | ||
| 35 | + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | ||
| 36 | + for (auto t : src.timestamps) { | ||
| 37 | + float time = frame_shift_s * t; | ||
| 38 | + r.timestamps.push_back(time); | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + return r; | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +void OfflineRecognizerConfig::Register(ParseOptions *po) { | ||
| 45 | + feat_config.Register(po); | ||
| 46 | + model_config.Register(po); | ||
| 47 | + | ||
| 48 | + po->Register("decoding-method", &decoding_method, | ||
| 49 | + "decoding method," | ||
| 50 | + "Valid values: greedy_search."); | ||
| 51 | +} | ||
| 52 | + | ||
| 53 | +bool OfflineRecognizerConfig::Validate() const { | ||
| 54 | + return model_config.Validate(); | ||
| 55 | +} | ||
| 56 | + | ||
| 57 | +std::string OfflineRecognizerConfig::ToString() const { | ||
| 58 | + std::ostringstream os; | ||
| 59 | + | ||
| 60 | + os << "OfflineRecognizerConfig("; | ||
| 61 | + os << "feat_config=" << feat_config.ToString() << ", "; | ||
| 62 | + os << "model_config=" << model_config.ToString() << ", "; | ||
| 63 | + os << "decoding_method=\"" << decoding_method << "\")"; | ||
| 64 | + | ||
| 65 | + return os.str(); | ||
| 66 | +} | ||
| 67 | + | ||
| 68 | +class OfflineRecognizer::Impl { | ||
| 69 | + public: | ||
| 70 | + explicit Impl(const OfflineRecognizerConfig &config) | ||
| 71 | + : config_(config), | ||
| 72 | + symbol_table_(config_.model_config.tokens), | ||
| 73 | + model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { | ||
| 74 | + if (config_.decoding_method == "greedy_search") { | ||
| 75 | + decoder_ = | ||
| 76 | + std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); | ||
| 77 | + } else if (config_.decoding_method == "modified_beam_search") { | ||
| 78 | + SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented"); | ||
| 79 | + exit(-1); | ||
| 80 | + } else { | ||
| 81 | + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | ||
| 82 | + config_.decoding_method.c_str()); | ||
| 83 | + exit(-1); | ||
| 84 | + } | ||
| 85 | + } | ||
| 86 | + | ||
| 87 | + std::unique_ptr<OfflineStream> CreateStream() const { | ||
| 88 | + return std::make_unique<OfflineStream>(config_.feat_config); | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + void DecodeStreams(OfflineStream **ss, int32_t n) const { | ||
| 92 | + auto memory_info = | ||
| 93 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 94 | + | ||
| 95 | + int32_t feat_dim = ss[0]->FeatureDim(); | ||
| 96 | + | ||
| 97 | + std::vector<Ort::Value> features; | ||
| 98 | + | ||
| 99 | + features.reserve(n); | ||
| 100 | + | ||
| 101 | + std::vector<std::vector<float>> features_vec(n); | ||
| 102 | + std::vector<int64_t> features_length_vec(n); | ||
| 103 | + for (int32_t i = 0; i != n; ++i) { | ||
| 104 | + auto f = ss[i]->GetFrames(); | ||
| 105 | + int32_t num_frames = f.size() / feat_dim; | ||
| 106 | + | ||
| 107 | + features_length_vec[i] = num_frames; | ||
| 108 | + features_vec[i] = std::move(f); | ||
| 109 | + | ||
| 110 | + std::array<int64_t, 2> shape = {num_frames, feat_dim}; | ||
| 111 | + | ||
| 112 | + Ort::Value x = Ort::Value::CreateTensor( | ||
| 113 | + memory_info, features_vec[i].data(), features_vec[i].size(), | ||
| 114 | + shape.data(), shape.size()); | ||
| 115 | + features.push_back(std::move(x)); | ||
| 116 | + } | ||
| 117 | + | ||
| 118 | + std::vector<const Ort::Value *> features_pointer(n); | ||
| 119 | + for (int32_t i = 0; i != n; ++i) { | ||
| 120 | + features_pointer[i] = &features[i]; | ||
| 121 | + } | ||
| 122 | + | ||
| 123 | + std::array<int64_t, 1> features_length_shape = {n}; | ||
| 124 | + Ort::Value x_length = Ort::Value::CreateTensor( | ||
| 125 | + memory_info, features_length_vec.data(), n, | ||
| 126 | + features_length_shape.data(), features_length_shape.size()); | ||
| 127 | + | ||
| 128 | + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, | ||
| 129 | + -23.025850929940457f); | ||
| 130 | + | ||
| 131 | + auto t = model_->RunEncoder(std::move(x), std::move(x_length)); | ||
| 132 | + auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); | ||
| 133 | + | ||
| 134 | + int32_t frame_shift_ms = 10; | ||
| 135 | + for (int32_t i = 0; i != n; ++i) { | ||
| 136 | + auto r = Convert(results[i], symbol_table_, frame_shift_ms, | ||
| 137 | + model_->SubsamplingFactor()); | ||
| 138 | + | ||
| 139 | + ss[i]->SetResult(r); | ||
| 140 | + } | ||
| 141 | + } | ||
| 142 | + | ||
| 143 | + private: | ||
| 144 | + OfflineRecognizerConfig config_; | ||
| 145 | + SymbolTable symbol_table_; | ||
| 146 | + std::unique_ptr<OfflineTransducerModel> model_; | ||
| 147 | + std::unique_ptr<OfflineTransducerDecoder> decoder_; | ||
| 148 | +}; | ||
| 149 | + | ||
| 150 | +OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) | ||
| 151 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 152 | + | ||
| 153 | +OfflineRecognizer::~OfflineRecognizer() = default; | ||
| 154 | + | ||
| 155 | +std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const { | ||
| 156 | + return impl_->CreateStream(); | ||
| 157 | +} | ||
| 158 | + | ||
| 159 | +void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const { | ||
| 160 | + impl_->DecodeStreams(ss, n); | ||
| 161 | +} | ||
| 162 | + | ||
| 163 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-recognizer.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-recognizer.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/offline-stream.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-transducer-model-config.h" | ||
| 14 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +struct OfflineRecognitionResult { | ||
| 19 | + // Recognition results. | ||
| 20 | + // For English, it consists of space separated words. | ||
| 21 | + // For Chinese, it consists of Chinese words without spaces. | ||
| 22 | + std::string text; | ||
| 23 | + | ||
| 24 | + // Decoded results at the token level. | ||
| 25 | + // For instance, for BPE-based models it consists of a list of BPE tokens. | ||
| 26 | + std::vector<std::string> tokens; | ||
| 27 | + | ||
| 28 | + /// timestamps.size() == tokens.size() | ||
| 29 | + /// timestamps[i] records the time in seconds when tokens[i] is decoded. | ||
| 30 | + std::vector<float> timestamps; | ||
| 31 | +}; | ||
| 32 | + | ||
| 33 | +struct OfflineRecognizerConfig { | ||
| 34 | + OfflineFeatureExtractorConfig feat_config; | ||
| 35 | + OfflineTransducerModelConfig model_config; | ||
| 36 | + | ||
| 37 | + std::string decoding_method = "greedy_search"; | ||
| 38 | + // only greedy_search is implemented | ||
| 39 | + // TODO(fangjun): Implement modified_beam_search | ||
| 40 | + | ||
| 41 | + OfflineRecognizerConfig() = default; | ||
| 42 | + OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, | ||
| 43 | + const OfflineTransducerModelConfig &model_config, | ||
| 44 | + const std::string &decoding_method) | ||
| 45 | + : feat_config(feat_config), | ||
| 46 | + model_config(model_config), | ||
| 47 | + decoding_method(decoding_method) {} | ||
| 48 | + | ||
| 49 | + void Register(ParseOptions *po); | ||
| 50 | + bool Validate() const; | ||
| 51 | + | ||
| 52 | + std::string ToString() const; | ||
| 53 | +}; | ||
| 54 | + | ||
| 55 | +class OfflineRecognizer { | ||
| 56 | + public: | ||
| 57 | + ~OfflineRecognizer(); | ||
| 58 | + | ||
| 59 | + explicit OfflineRecognizer(const OfflineRecognizerConfig &config); | ||
| 60 | + | ||
| 61 | + /// Create a stream for decoding. | ||
| 62 | + std::unique_ptr<OfflineStream> CreateStream() const; | ||
| 63 | + | ||
| 64 | + /** Decode a single stream | ||
| 65 | + * | ||
| 66 | + * @param s The stream to decode. | ||
| 67 | + */ | ||
| 68 | + void DecodeStream(OfflineStream *s) const { | ||
| 69 | + OfflineStream *ss[1] = {s}; | ||
| 70 | + DecodeStreams(ss, 1); | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + /** Decode a list of streams. | ||
| 74 | + * | ||
| 75 | + * @param ss Pointer to an array of streams. | ||
| 76 | + * @param n Size of the input array. | ||
| 77 | + */ | ||
| 78 | + void DecodeStreams(OfflineStream **ss, int32_t n) const; | ||
| 79 | + | ||
| 80 | + private: | ||
| 81 | + class Impl; | ||
| 82 | + std::unique_ptr<Impl> impl_; | ||
| 83 | +}; | ||
| 84 | + | ||
| 85 | +} // namespace sherpa_onnx | ||
| 86 | + | ||
| 87 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_ |
sherpa-onnx/csrc/offline-stream.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-stream.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-stream.h" | ||
| 6 | + | ||
| 7 | +#include <assert.h> | ||
| 8 | + | ||
| 9 | +#include <algorithm> | ||
| 10 | + | ||
| 11 | +#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 14 | +#include "sherpa-onnx/csrc/resample.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { | ||
| 19 | + po->Register("sample-rate", &sampling_rate, | ||
| 20 | + "Sampling rate of the input waveform. Must match the one " | ||
| 21 | + "expected by the model. Note: You can have a different " | ||
| 22 | + "sample rate for the input waveform. We will do resampling " | ||
| 23 | + "inside the feature extractor"); | ||
| 24 | + | ||
| 25 | + po->Register("feat-dim", &feature_dim, | ||
| 26 | + "Feature dimension. Must match the one expected by the model."); | ||
| 27 | +} | ||
| 28 | + | ||
| 29 | +std::string OfflineFeatureExtractorConfig::ToString() const { | ||
| 30 | + std::ostringstream os; | ||
| 31 | + | ||
| 32 | + os << "OfflineFeatureExtractorConfig("; | ||
| 33 | + os << "sampling_rate=" << sampling_rate << ", "; | ||
| 34 | + os << "feature_dim=" << feature_dim << ")"; | ||
| 35 | + | ||
| 36 | + return os.str(); | ||
| 37 | +} | ||
| 38 | + | ||
| 39 | +class OfflineStream::Impl { | ||
| 40 | + public: | ||
| 41 | + explicit Impl(const OfflineFeatureExtractorConfig &config) { | ||
| 42 | + opts_.frame_opts.dither = 0; | ||
| 43 | + opts_.frame_opts.snip_edges = false; | ||
| 44 | + opts_.frame_opts.samp_freq = config.sampling_rate; | ||
| 45 | + opts_.mel_opts.num_bins = config.feature_dim; | ||
| 46 | + | ||
| 47 | + fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { | ||
| 51 | + if (sampling_rate != opts_.frame_opts.samp_freq) { | ||
| 52 | + SHERPA_ONNX_LOGE( | ||
| 53 | + "Creating a resampler:\n" | ||
| 54 | + " in_sample_rate: %d\n" | ||
| 55 | + " output_sample_rate: %d\n", | ||
| 56 | + sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq)); | ||
| 57 | + | ||
| 58 | + float min_freq = | ||
| 59 | + std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq); | ||
| 60 | + float lowpass_cutoff = 0.99 * 0.5 * min_freq; | ||
| 61 | + | ||
| 62 | + int32_t lowpass_filter_width = 6; | ||
| 63 | + auto resampler = std::make_unique<LinearResample>( | ||
| 64 | + sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff, | ||
| 65 | + lowpass_filter_width); | ||
| 66 | + std::vector<float> samples; | ||
| 67 | + resampler->Resample(waveform, n, true, &samples); | ||
| 68 | + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), | ||
| 69 | + samples.size()); | ||
| 70 | + fbank_->InputFinished(); | ||
| 71 | + return; | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + fbank_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 75 | + fbank_->InputFinished(); | ||
| 76 | + } | ||
| 77 | + | ||
| 78 | + int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } | ||
| 79 | + | ||
| 80 | + std::vector<float> GetFrames() const { | ||
| 81 | + int32_t n = fbank_->NumFramesReady(); | ||
| 82 | + assert(n > 0 && "Please first call AcceptWaveform()"); | ||
| 83 | + | ||
| 84 | + int32_t feature_dim = FeatureDim(); | ||
| 85 | + | ||
| 86 | + std::vector<float> features(n * feature_dim); | ||
| 87 | + | ||
| 88 | + float *p = features.data(); | ||
| 89 | + | ||
| 90 | + for (int32_t i = 0; i != n; ++i) { | ||
| 91 | + const float *f = fbank_->GetFrame(i); | ||
| 92 | + std::copy(f, f + feature_dim, p); | ||
| 93 | + p += feature_dim; | ||
| 94 | + } | ||
| 95 | + | ||
| 96 | + return features; | ||
| 97 | + } | ||
| 98 | + | ||
| 99 | + void SetResult(const OfflineRecognitionResult &r) { r_ = r; } | ||
| 100 | + | ||
| 101 | + const OfflineRecognitionResult &GetResult() const { return r_; } | ||
| 102 | + | ||
| 103 | + private: | ||
| 104 | + std::unique_ptr<knf::OnlineFbank> fbank_; | ||
| 105 | + knf::FbankOptions opts_; | ||
| 106 | + OfflineRecognitionResult r_; | ||
| 107 | +}; | ||
| 108 | + | ||
| 109 | +OfflineStream::OfflineStream( | ||
| 110 | + const OfflineFeatureExtractorConfig &config /*= {}*/) | ||
| 111 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 112 | + | ||
| 113 | +OfflineStream::~OfflineStream() = default; | ||
| 114 | + | ||
| 115 | +void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, | ||
| 116 | + int32_t n) const { | ||
| 117 | + impl_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 118 | +} | ||
| 119 | + | ||
| 120 | +int32_t OfflineStream::FeatureDim() const { return impl_->FeatureDim(); } | ||
| 121 | + | ||
| 122 | +std::vector<float> OfflineStream::GetFrames() const { | ||
| 123 | + return impl_->GetFrames(); | ||
| 124 | +} | ||
| 125 | + | ||
| 126 | +void OfflineStream::SetResult(const OfflineRecognitionResult &r) { | ||
| 127 | + impl_->SetResult(r); | ||
| 128 | +} | ||
| 129 | + | ||
| 130 | +const OfflineRecognitionResult &OfflineStream::GetResult() const { | ||
| 131 | + return impl_->GetResult(); | ||
| 132 | +} | ||
| 133 | + | ||
| 134 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-stream.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-stream.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_ | ||
| 7 | +#include <stdint.h> | ||
| 8 | + | ||
| 9 | +#include <memory> | ||
| 10 | +#include <string> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | +struct OfflineRecognitionResult; | ||
| 17 | + | ||
| 18 | +struct OfflineFeatureExtractorConfig { | ||
| 19 | + // Sampling rate used by the feature extractor. If it is different from | ||
| 20 | + // the sampling rate of the input waveform, we will do resampling inside. | ||
| 21 | + int32_t sampling_rate = 16000; | ||
| 22 | + | ||
| 23 | + // Feature dimension | ||
| 24 | + int32_t feature_dim = 80; | ||
| 25 | + | ||
| 26 | + std::string ToString() const; | ||
| 27 | + | ||
| 28 | + void Register(ParseOptions *po); | ||
| 29 | +}; | ||
| 30 | + | ||
| 31 | +class OfflineStream { | ||
| 32 | + public: | ||
| 33 | + explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}); | ||
| 34 | + ~OfflineStream(); | ||
| 35 | + | ||
| 36 | + /** | ||
| 37 | + @param sampling_rate The sampling_rate of the input waveform. If it does | ||
| 38 | + not equal to config.sampling_rate, we will do | ||
| 39 | + resampling inside. | ||
| 40 | + @param waveform Pointer to a 1-D array of size n. It must be normalized to | ||
| 41 | + the range [-1, 1]. | ||
| 42 | + @param n Number of entries in waveform | ||
| 43 | + | ||
| 44 | + Caution: You can only invoke this function once so you have to input | ||
| 45 | + all the samples at once | ||
| 46 | + */ | ||
| 47 | + void AcceptWaveform(int32_t sampling_rate, const float *waveform, | ||
| 48 | + int32_t n) const; | ||
| 49 | + | ||
| 50 | + /// Return feature dim of this extractor | ||
| 51 | + int32_t FeatureDim() const; | ||
| 52 | + | ||
| 53 | + // Get all the feature frames of this stream in a 1-D array, which is | ||
| 54 | + // flattened from a 2-D array of shape (num_frames, feat_dim). | ||
| 55 | + std::vector<float> GetFrames() const; | ||
| 56 | + | ||
| 57 | + /** Set the recognition result for this stream. */ | ||
| 58 | + void SetResult(const OfflineRecognitionResult &r); | ||
| 59 | + | ||
| 60 | + /** Get the recognition result of this stream */ | ||
| 61 | + const OfflineRecognitionResult &GetResult() const; | ||
| 62 | + | ||
| 63 | + private: | ||
| 64 | + class Impl; | ||
| 65 | + std::unique_ptr<Impl> impl_; | ||
| 66 | +}; | ||
| 67 | + | ||
| 68 | +} // namespace sherpa_onnx | ||
| 69 | + | ||
| 70 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_ |
| 1 | +// sherpa-onnx/csrc/offline-transducer-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OfflineTransducerDecoderResult { | ||
| 15 | + /// The decoded token IDs | ||
| 16 | + std::vector<int64_t> tokens; | ||
| 17 | + | ||
| 18 | + /// timestamps[i] contains the output frame index where tokens[i] is decoded. | ||
| 19 | + /// Note: The index is after subsampling | ||
| 20 | + std::vector<int32_t> timestamps; | ||
| 21 | +}; | ||
| 22 | + | ||
| 23 | +class OfflineTransducerDecoder { | ||
| 24 | + public: | ||
| 25 | + virtual ~OfflineTransducerDecoder() = default; | ||
| 26 | + | ||
| 27 | + /** Run transducer beam search given the output from the encoder model. | ||
| 28 | + * | ||
| 29 | + * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim) | ||
| 30 | + * @param encoder_out_length A 1-D tensor of shape (N,) containing number | ||
| 31 | + * of valid frames in encoder_out before padding. | ||
| 32 | + * | ||
| 33 | + * @return Return a vector of size `N` containing the decoded results. | ||
| 34 | + */ | ||
| 35 | + virtual std::vector<OfflineTransducerDecoderResult> Decode( | ||
| 36 | + Ort::Value encoder_out, Ort::Value encoder_out_length) = 0; | ||
| 37 | +}; | ||
| 38 | + | ||
| 39 | +} // namespace sherpa_onnx | ||
| 40 | + | ||
| 41 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <iterator> | ||
| 9 | +#include <utility> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 12 | +#include "sherpa-onnx/csrc/packed-sequence.h" | ||
| 13 | +#include "sherpa-onnx/csrc/slice.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +std::vector<OfflineTransducerDecoderResult> | ||
| 18 | +OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, | ||
| 19 | + Ort::Value encoder_out_length) { | ||
| 20 | + PackedSequence packed_encoder_out = PackPaddedSequence( | ||
| 21 | + model_->Allocator(), &encoder_out, &encoder_out_length); | ||
| 22 | + | ||
| 23 | + int32_t batch_size = | ||
| 24 | + static_cast<int32_t>(packed_encoder_out.sorted_indexes.size()); | ||
| 25 | + | ||
| 26 | + int32_t vocab_size = model_->VocabSize(); | ||
| 27 | + int32_t context_size = model_->ContextSize(); | ||
| 28 | + | ||
| 29 | + std::vector<OfflineTransducerDecoderResult> ans(batch_size); | ||
| 30 | + for (auto &r : ans) { | ||
| 31 | + // 0 is the ID of the blank token | ||
| 32 | + r.tokens.resize(context_size, 0); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + auto decoder_input = model_->BuildDecoderInput(ans, ans.size()); | ||
| 36 | + Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); | ||
| 37 | + | ||
| 38 | + int32_t start = 0; | ||
| 39 | + int32_t t = 0; | ||
| 40 | + for (auto n : packed_encoder_out.batch_sizes) { | ||
| 41 | + Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n); | ||
| 42 | + Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n); | ||
| 43 | + start += n; | ||
| 44 | + Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out), | ||
| 45 | + std::move(cur_decoder_out)); | ||
| 46 | + const float *p_logit = logit.GetTensorData<float>(); | ||
| 47 | + bool emitted = false; | ||
| 48 | + for (int32_t i = 0; i != n; ++i) { | ||
| 49 | + auto y = static_cast<int32_t>(std::distance( | ||
| 50 | + static_cast<const float *>(p_logit), | ||
| 51 | + std::max_element(static_cast<const float *>(p_logit), | ||
| 52 | + static_cast<const float *>(p_logit) + vocab_size))); | ||
| 53 | + p_logit += vocab_size; | ||
| 54 | + if (y != 0) { | ||
| 55 | + ans[i].tokens.push_back(y); | ||
| 56 | + ans[i].timestamps.push_back(t); | ||
| 57 | + emitted = true; | ||
| 58 | + } | ||
| 59 | + } | ||
| 60 | + if (emitted) { | ||
| 61 | + Ort::Value decoder_input = model_->BuildDecoderInput(ans, n); | ||
| 62 | + decoder_out = model_->RunDecoder(std::move(decoder_input)); | ||
| 63 | + } | ||
| 64 | + ++t; | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + for (auto &r : ans) { | ||
| 68 | + r.tokens = {r.tokens.begin() + context_size, r.tokens.end()}; | ||
| 69 | + } | ||
| 70 | + | ||
| 71 | + std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size); | ||
| 72 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 73 | + unsorted_ans[packed_encoder_out.sorted_indexes[i]] = std::move(ans[i]); | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + return unsorted_ans; | ||
| 77 | +} | ||
| 78 | + | ||
| 79 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 11 | +#include "sherpa-onnx/csrc/offline-transducer-model.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | ||
| 16 | + public: | ||
| 17 | + explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model) | ||
| 18 | + : model_(model) {} | ||
| 19 | + | ||
| 20 | + std::vector<OfflineTransducerDecoderResult> Decode( | ||
| 21 | + Ort::Value encoder_out, Ort::Value encoder_out_length) override; | ||
| 22 | + | ||
| 23 | + private: | ||
| 24 | + OfflineTransducerModel *model_; // Not owned | ||
| 25 | +}; | ||
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | + | ||
| 29 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/offline-transducer-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/offline-transducer-model-config.h" | ||
| 5 | + | ||
| 6 | +#include <sstream> | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 9 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void OfflineTransducerModelConfig::Register(ParseOptions *po) { | ||
| 14 | + po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); | ||
| 15 | + po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); | ||
| 16 | + po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); | ||
| 17 | + po->Register("tokens", &tokens, "Path to tokens.txt"); | ||
| 18 | + po->Register("num_threads", &num_threads, | ||
| 19 | + "Number of threads to run the neural network"); | ||
| 20 | + | ||
| 21 | + po->Register("debug", &debug, | ||
| 22 | + "true to print model information while loading it."); | ||
| 23 | +} | ||
| 24 | + | ||
| 25 | +bool OfflineTransducerModelConfig::Validate() const { | ||
| 26 | + if (!FileExists(tokens)) { | ||
| 27 | + SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); | ||
| 28 | + return false; | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + if (!FileExists(encoder_filename)) { | ||
| 32 | + SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); | ||
| 33 | + return false; | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + if (!FileExists(decoder_filename)) { | ||
| 37 | + SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str()); | ||
| 38 | + return false; | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + if (!FileExists(joiner_filename)) { | ||
| 42 | + SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str()); | ||
| 43 | + return false; | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + if (num_threads < 1) { | ||
| 47 | + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); | ||
| 48 | + return false; | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + return true; | ||
| 52 | +} | ||
| 53 | + | ||
| 54 | +std::string OfflineTransducerModelConfig::ToString() const { | ||
| 55 | + std::ostringstream os; | ||
| 56 | + | ||
| 57 | + os << "OfflineTransducerModelConfig("; | ||
| 58 | + os << "encoder_filename=\"" << encoder_filename << "\", "; | ||
| 59 | + os << "decoder_filename=\"" << decoder_filename << "\", "; | ||
| 60 | + os << "joiner_filename=\"" << joiner_filename << "\", "; | ||
| 61 | + os << "tokens=\"" << tokens << "\", "; | ||
| 62 | + os << "num_threads=" << num_threads << ", "; | ||
| 63 | + os << "debug=" << (debug ? "True" : "False") << ")"; | ||
| 64 | + | ||
| 65 | + return os.str(); | ||
| 66 | +} | ||
| 67 | + | ||
| 68 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-transducer-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineTransducerModelConfig { | ||
| 14 | + std::string encoder_filename; | ||
| 15 | + std::string decoder_filename; | ||
| 16 | + std::string joiner_filename; | ||
| 17 | + std::string tokens; | ||
| 18 | + int32_t num_threads = 2; | ||
| 19 | + bool debug = false; | ||
| 20 | + | ||
| 21 | + OfflineTransducerModelConfig() = default; | ||
| 22 | + OfflineTransducerModelConfig(const std::string &encoder_filename, | ||
| 23 | + const std::string &decoder_filename, | ||
| 24 | + const std::string &joiner_filename, | ||
| 25 | + const std::string &tokens, int32_t num_threads, | ||
| 26 | + bool debug) | ||
| 27 | + : encoder_filename(encoder_filename), | ||
| 28 | + decoder_filename(decoder_filename), | ||
| 29 | + joiner_filename(joiner_filename), | ||
| 30 | + tokens(tokens), | ||
| 31 | + num_threads(num_threads), | ||
| 32 | + debug(debug) {} | ||
| 33 | + | ||
| 34 | + void Register(ParseOptions *po); | ||
| 35 | + bool Validate() const; | ||
| 36 | + | ||
| 37 | + std::string ToString() const; | ||
| 38 | +}; | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx | ||
| 41 | + | ||
| 42 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/offline-transducer-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-transducer-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-transducer-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <string> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 13 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class OfflineTransducerModel::Impl { | ||
| 18 | + public: | ||
| 19 | + explicit Impl(const OfflineTransducerModelConfig &config) | ||
| 20 | + : config_(config), | ||
| 21 | + env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 22 | + sess_opts_{}, | ||
| 23 | + allocator_{} { | ||
| 24 | + sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 25 | + sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 26 | + { | ||
| 27 | + auto buf = ReadFile(config.encoder_filename); | ||
| 28 | + InitEncoder(buf.data(), buf.size()); | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + { | ||
| 32 | + auto buf = ReadFile(config.decoder_filename); | ||
| 33 | + InitDecoder(buf.data(), buf.size()); | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + { | ||
| 37 | + auto buf = ReadFile(config.joiner_filename); | ||
| 38 | + InitJoiner(buf.data(), buf.size()); | ||
| 39 | + } | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features, | ||
| 43 | + Ort::Value features_length) { | ||
| 44 | + std::array<Ort::Value, 2> encoder_inputs = {std::move(features), | ||
| 45 | + std::move(features_length)}; | ||
| 46 | + | ||
| 47 | + auto encoder_out = encoder_sess_->Run( | ||
| 48 | + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), | ||
| 49 | + encoder_inputs.size(), encoder_output_names_ptr_.data(), | ||
| 50 | + encoder_output_names_ptr_.size()); | ||
| 51 | + | ||
| 52 | + return {std::move(encoder_out[0]), std::move(encoder_out[1])}; | ||
| 53 | + } | ||
| 54 | + | ||
| 55 | + Ort::Value RunDecoder(Ort::Value decoder_input) { | ||
| 56 | + auto decoder_out = decoder_sess_->Run( | ||
| 57 | + {}, decoder_input_names_ptr_.data(), &decoder_input, 1, | ||
| 58 | + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); | ||
| 59 | + return std::move(decoder_out[0]); | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) { | ||
| 63 | + std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out), | ||
| 64 | + std::move(decoder_out)}; | ||
| 65 | + auto logit = joiner_sess_->Run({}, joiner_input_names_ptr_.data(), | ||
| 66 | + joiner_input.data(), joiner_input.size(), | ||
| 67 | + joiner_output_names_ptr_.data(), | ||
| 68 | + joiner_output_names_ptr_.size()); | ||
| 69 | + | ||
| 70 | + return std::move(logit[0]); | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + int32_t VocabSize() const { return vocab_size_; } | ||
| 74 | + int32_t ContextSize() const { return context_size_; } | ||
| 75 | + int32_t SubsamplingFactor() const { return 4; } | ||
| 76 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 77 | + | ||
| 78 | + Ort::Value BuildDecoderInput( | ||
| 79 | + const std::vector<OfflineTransducerDecoderResult> &results, | ||
| 80 | + int32_t end_index) const { | ||
| 81 | + assert(end_index <= results.size()); | ||
| 82 | + | ||
| 83 | + int32_t batch_size = end_index; | ||
| 84 | + int32_t context_size = ContextSize(); | ||
| 85 | + std::array<int64_t, 2> shape{batch_size, context_size}; | ||
| 86 | + | ||
| 87 | + Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>( | ||
| 88 | + Allocator(), shape.data(), shape.size()); | ||
| 89 | + int64_t *p = decoder_input.GetTensorMutableData<int64_t>(); | ||
| 90 | + | ||
| 91 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 92 | + const auto &r = results[i]; | ||
| 93 | + const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size; | ||
| 94 | + const int64_t *end = r.tokens.data() + r.tokens.size(); | ||
| 95 | + std::copy(begin, end, p); | ||
| 96 | + p += context_size; | ||
| 97 | + } | ||
| 98 | + return decoder_input; | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + private: | ||
| 102 | + void InitEncoder(void *model_data, size_t model_data_length) { | ||
| 103 | + encoder_sess_ = std::make_unique<Ort::Session>( | ||
| 104 | + env_, model_data, model_data_length, sess_opts_); | ||
| 105 | + | ||
| 106 | + GetInputNames(encoder_sess_.get(), &encoder_input_names_, | ||
| 107 | + &encoder_input_names_ptr_); | ||
| 108 | + | ||
| 109 | + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, | ||
| 110 | + &encoder_output_names_ptr_); | ||
| 111 | + | ||
| 112 | + // get meta data | ||
| 113 | + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); | ||
| 114 | + if (config_.debug) { | ||
| 115 | + std::ostringstream os; | ||
| 116 | + os << "---encoder---\n"; | ||
| 117 | + PrintModelMetadata(os, meta_data); | ||
| 118 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 119 | + } | ||
| 120 | + } | ||
| 121 | + | ||
| 122 | + void InitDecoder(void *model_data, size_t model_data_length) { | ||
| 123 | + decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 124 | + env_, model_data, model_data_length, sess_opts_); | ||
| 125 | + | ||
| 126 | + GetInputNames(decoder_sess_.get(), &decoder_input_names_, | ||
| 127 | + &decoder_input_names_ptr_); | ||
| 128 | + | ||
| 129 | + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, | ||
| 130 | + &decoder_output_names_ptr_); | ||
| 131 | + | ||
| 132 | + // get meta data | ||
| 133 | + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata(); | ||
| 134 | + if (config_.debug) { | ||
| 135 | + std::ostringstream os; | ||
| 136 | + os << "---decoder---\n"; | ||
| 137 | + PrintModelMetadata(os, meta_data); | ||
| 138 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 139 | + } | ||
| 140 | + | ||
| 141 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 142 | + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); | ||
| 143 | + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); | ||
| 144 | + } | ||
| 145 | + | ||
| 146 | + void InitJoiner(void *model_data, size_t model_data_length) { | ||
| 147 | + joiner_sess_ = std::make_unique<Ort::Session>( | ||
| 148 | + env_, model_data, model_data_length, sess_opts_); | ||
| 149 | + | ||
| 150 | + GetInputNames(joiner_sess_.get(), &joiner_input_names_, | ||
| 151 | + &joiner_input_names_ptr_); | ||
| 152 | + | ||
| 153 | + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, | ||
| 154 | + &joiner_output_names_ptr_); | ||
| 155 | + | ||
| 156 | + // get meta data | ||
| 157 | + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata(); | ||
| 158 | + if (config_.debug) { | ||
| 159 | + std::ostringstream os; | ||
| 160 | + os << "---joiner---\n"; | ||
| 161 | + PrintModelMetadata(os, meta_data); | ||
| 162 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 163 | + } | ||
| 164 | + } | ||
| 165 | + | ||
| 166 | + private: | ||
| 167 | + OfflineTransducerModelConfig config_; | ||
| 168 | + Ort::Env env_; | ||
| 169 | + Ort::SessionOptions sess_opts_; | ||
| 170 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 171 | + | ||
| 172 | + std::unique_ptr<Ort::Session> encoder_sess_; | ||
| 173 | + std::unique_ptr<Ort::Session> decoder_sess_; | ||
| 174 | + std::unique_ptr<Ort::Session> joiner_sess_; | ||
| 175 | + | ||
| 176 | + std::vector<std::string> encoder_input_names_; | ||
| 177 | + std::vector<const char *> encoder_input_names_ptr_; | ||
| 178 | + | ||
| 179 | + std::vector<std::string> encoder_output_names_; | ||
| 180 | + std::vector<const char *> encoder_output_names_ptr_; | ||
| 181 | + | ||
| 182 | + std::vector<std::string> decoder_input_names_; | ||
| 183 | + std::vector<const char *> decoder_input_names_ptr_; | ||
| 184 | + | ||
| 185 | + std::vector<std::string> decoder_output_names_; | ||
| 186 | + std::vector<const char *> decoder_output_names_ptr_; | ||
| 187 | + | ||
| 188 | + std::vector<std::string> joiner_input_names_; | ||
| 189 | + std::vector<const char *> joiner_input_names_ptr_; | ||
| 190 | + | ||
| 191 | + std::vector<std::string> joiner_output_names_; | ||
| 192 | + std::vector<const char *> joiner_output_names_ptr_; | ||
| 193 | + | ||
| 194 | + int32_t vocab_size_ = 0; // initialized in InitDecoder | ||
| 195 | + int32_t context_size_ = 0; // initialized in InitDecoder | ||
| 196 | +}; | ||
| 197 | + | ||
| 198 | +OfflineTransducerModel::OfflineTransducerModel( | ||
| 199 | + const OfflineTransducerModelConfig &config) | ||
| 200 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 201 | + | ||
| 202 | +OfflineTransducerModel::~OfflineTransducerModel() = default; | ||
| 203 | + | ||
| 204 | +std::pair<Ort::Value, Ort::Value> OfflineTransducerModel::RunEncoder( | ||
| 205 | + Ort::Value features, Ort::Value features_length) { | ||
| 206 | + return impl_->RunEncoder(std::move(features), std::move(features_length)); | ||
| 207 | +} | ||
| 208 | + | ||
| 209 | +Ort::Value OfflineTransducerModel::RunDecoder(Ort::Value decoder_input) { | ||
| 210 | + return impl_->RunDecoder(std::move(decoder_input)); | ||
| 211 | +} | ||
| 212 | + | ||
| 213 | +Ort::Value OfflineTransducerModel::RunJoiner(Ort::Value encoder_out, | ||
| 214 | + Ort::Value decoder_out) { | ||
| 215 | + return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); | ||
| 216 | +} | ||
| 217 | + | ||
| 218 | +int32_t OfflineTransducerModel::VocabSize() const { return impl_->VocabSize(); } | ||
| 219 | + | ||
| 220 | +int32_t OfflineTransducerModel::ContextSize() const { | ||
| 221 | + return impl_->ContextSize(); | ||
| 222 | +} | ||
| 223 | + | ||
| 224 | +int32_t OfflineTransducerModel::SubsamplingFactor() const { | ||
| 225 | + return impl_->SubsamplingFactor(); | ||
| 226 | +} | ||
| 227 | + | ||
| 228 | +OrtAllocator *OfflineTransducerModel::Allocator() const { | ||
| 229 | + return impl_->Allocator(); | ||
| 230 | +} | ||
| 231 | + | ||
| 232 | +Ort::Value OfflineTransducerModel::BuildDecoderInput( | ||
| 233 | + const std::vector<OfflineTransducerDecoderResult> &results, | ||
| 234 | + int32_t end_index) const { | ||
| 235 | + return impl_->BuildDecoderInput(results, end_index); | ||
| 236 | +} | ||
| 237 | + | ||
| 238 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-transducer-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-transducer-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/offline-transducer-model-config.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +struct OfflineTransducerDecoderResult; | ||
| 17 | + | ||
| 18 | +class OfflineTransducerModel { | ||
| 19 | + public: | ||
| 20 | + explicit OfflineTransducerModel(const OfflineTransducerModelConfig &config); | ||
| 21 | + ~OfflineTransducerModel(); | ||
| 22 | + | ||
| 23 | + /** Run the encoder. | ||
| 24 | + * | ||
| 25 | + * @param features A tensor of shape (N, T, C). It is changed in-place. | ||
| 26 | + * @param features_length A 1-D tensor of shape (N,) containing number of | ||
| 27 | + * valid frames in `features` before padding. | ||
| 28 | + * | ||
| 29 | + * @return Return a pair containing: | ||
| 30 | + * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim) | ||
| 31 | + * - encoder_out_length: A 1-D tensor of shape (N,) containing number | ||
| 32 | + * of frames in `encoder_out` before padding. | ||
| 33 | + */ | ||
| 34 | + std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features, | ||
| 35 | + Ort::Value features_length); | ||
| 36 | + | ||
| 37 | + /** Run the decoder network. | ||
| 38 | + * | ||
| 39 | + * Caution: We assume there are no recurrent connections in the decoder and | ||
| 40 | + * the decoder is stateless. See | ||
| 41 | + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py | ||
| 42 | + * for an example | ||
| 43 | + * | ||
| 44 | + * @param decoder_input It is usually of shape (N, context_size) | ||
| 45 | + * @return Return a tensor of shape (N, decoder_dim). | ||
| 46 | + */ | ||
| 47 | + Ort::Value RunDecoder(Ort::Value decoder_input); | ||
| 48 | + | ||
| 49 | + /** Run the joint network. | ||
| 50 | + * | ||
| 51 | + * @param encoder_out Output of the encoder network. A tensor of shape | ||
| 52 | + * (N, joiner_dim). | ||
| 53 | + * @param decoder_out Output of the decoder network. A tensor of shape | ||
| 54 | + * (N, joiner_dim). | ||
| 55 | + * @return Return a tensor of shape (N, vocab_size). In icefall, the last | ||
| 56 | + * layer of the joint network is `nn.Linear`, | ||
| 57 | + * not `nn.LogSoftmax`. | ||
| 58 | + */ | ||
| 59 | + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out); | ||
| 60 | + | ||
| 61 | + /** Return the vocabulary size of the model | ||
| 62 | + */ | ||
| 63 | + int32_t VocabSize() const; | ||
| 64 | + | ||
| 65 | + /** Return the context_size of the decoder model. | ||
| 66 | + */ | ||
| 67 | + int32_t ContextSize() const; | ||
| 68 | + | ||
| 69 | + /** Return the subsampling factor of the model. | ||
| 70 | + */ | ||
| 71 | + int32_t SubsamplingFactor() const; | ||
| 72 | + | ||
| 73 | + /** Return an allocator for allocating memory | ||
| 74 | + */ | ||
| 75 | + OrtAllocator *Allocator() const; | ||
| 76 | + | ||
| 77 | + /** Build decoder_input from the current results. | ||
| 78 | + * | ||
| 79 | + * @param results Current decoded results. | ||
| 80 | + * @param end_index We only use results[0:end_index] to build | ||
| 81 | + * the decoder_input. | ||
| 82 | + * @return Return a tensor of shape (results.size(), ContextSize()) | ||
| 83 | + */ | ||
| 84 | + Ort::Value BuildDecoderInput( | ||
| 85 | + const std::vector<OfflineTransducerDecoderResult> &results, | ||
| 86 | + int32_t end_index) const; | ||
| 87 | + | ||
| 88 | + private: | ||
| 89 | + class Impl; | ||
| 90 | + std::unique_ptr<Impl> impl_; | ||
| 91 | +}; | ||
| 92 | + | ||
| 93 | +} // namespace sherpa_onnx | ||
| 94 | + | ||
| 95 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_ |
| @@ -95,7 +95,7 @@ void OnlineLstmTransducerModel::InitEncoder(void *model_data, | @@ -95,7 +95,7 @@ void OnlineLstmTransducerModel::InitEncoder(void *model_data, | ||
| 95 | std::ostringstream os; | 95 | std::ostringstream os; |
| 96 | os << "---encoder---\n"; | 96 | os << "---encoder---\n"; |
| 97 | PrintModelMetadata(os, meta_data); | 97 | PrintModelMetadata(os, meta_data); |
| 98 | - fprintf(stderr, "%s\n", os.str().c_str()); | 98 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); |
| 99 | } | 99 | } |
| 100 | 100 | ||
| 101 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | 101 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below |
| @@ -123,7 +123,7 @@ void OnlineLstmTransducerModel::InitDecoder(void *model_data, | @@ -123,7 +123,7 @@ void OnlineLstmTransducerModel::InitDecoder(void *model_data, | ||
| 123 | std::ostringstream os; | 123 | std::ostringstream os; |
| 124 | os << "---decoder---\n"; | 124 | os << "---decoder---\n"; |
| 125 | PrintModelMetadata(os, meta_data); | 125 | PrintModelMetadata(os, meta_data); |
| 126 | - fprintf(stderr, "%s\n", os.str().c_str()); | 126 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); |
| 127 | } | 127 | } |
| 128 | 128 | ||
| 129 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | 129 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below |
| @@ -148,7 +148,7 @@ void OnlineLstmTransducerModel::InitJoiner(void *model_data, | @@ -148,7 +148,7 @@ void OnlineLstmTransducerModel::InitJoiner(void *model_data, | ||
| 148 | std::ostringstream os; | 148 | std::ostringstream os; |
| 149 | os << "---joiner---\n"; | 149 | os << "---joiner---\n"; |
| 150 | PrintModelMetadata(os, meta_data); | 150 | PrintModelMetadata(os, meta_data); |
| 151 | - fprintf(stderr, "%s\n", os.str().c_str()); | 151 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); |
| 152 | } | 152 | } |
| 153 | } | 153 | } |
| 154 | 154 | ||
| @@ -228,9 +228,6 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() { | @@ -228,9 +228,6 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() { | ||
| 228 | std::pair<Ort::Value, std::vector<Ort::Value>> | 228 | std::pair<Ort::Value, std::vector<Ort::Value>> |
| 229 | OnlineLstmTransducerModel::RunEncoder(Ort::Value features, | 229 | OnlineLstmTransducerModel::RunEncoder(Ort::Value features, |
| 230 | std::vector<Ort::Value> states) { | 230 | std::vector<Ort::Value> states) { |
| 231 | - auto memory_info = | ||
| 232 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 233 | - | ||
| 234 | std::array<Ort::Value, 3> encoder_inputs = { | 231 | std::array<Ort::Value, 3> encoder_inputs = { |
| 235 | std::move(features), std::move(states[0]), std::move(states[1])}; | 232 | std::move(features), std::move(states[0]), std::move(states[1])}; |
| 236 | 233 |
| @@ -20,7 +20,7 @@ class OnlineStream::Impl { | @@ -20,7 +20,7 @@ class OnlineStream::Impl { | ||
| 20 | feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); | 20 | feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); |
| 21 | } | 21 | } |
| 22 | 22 | ||
| 23 | - void InputFinished() { feat_extractor_.InputFinished(); } | 23 | + void InputFinished() const { feat_extractor_.InputFinished(); } |
| 24 | 24 | ||
| 25 | int32_t NumFramesReady() const { | 25 | int32_t NumFramesReady() const { |
| 26 | return feat_extractor_.NumFramesReady() - start_frame_index_; | 26 | return feat_extractor_.NumFramesReady() - start_frame_index_; |
| @@ -68,11 +68,11 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) | @@ -68,11 +68,11 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) | ||
| 68 | OnlineStream::~OnlineStream() = default; | 68 | OnlineStream::~OnlineStream() = default; |
| 69 | 69 | ||
| 70 | void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, | 70 | void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, |
| 71 | - int32_t n) { | 71 | + int32_t n) const { |
| 72 | impl_->AcceptWaveform(sampling_rate, waveform, n); | 72 | impl_->AcceptWaveform(sampling_rate, waveform, n); |
| 73 | } | 73 | } |
| 74 | 74 | ||
| 75 | -void OnlineStream::InputFinished() { impl_->InputFinished(); } | 75 | +void OnlineStream::InputFinished() const { impl_->InputFinished(); } |
| 76 | 76 | ||
| 77 | int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); } | 77 | int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); } |
| 78 | 78 |
| @@ -27,7 +27,8 @@ class OnlineStream { | @@ -27,7 +27,8 @@ class OnlineStream { | ||
| 27 | the range [-1, 1]. | 27 | the range [-1, 1]. |
| 28 | @param n Number of entries in waveform | 28 | @param n Number of entries in waveform |
| 29 | */ | 29 | */ |
| 30 | - void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); | 30 | + void AcceptWaveform(int32_t sampling_rate, const float *waveform, |
| 31 | + int32_t n) const; | ||
| 31 | 32 | ||
| 32 | /** | 33 | /** |
| 33 | * InputFinished() tells the class you won't be providing any | 34 | * InputFinished() tells the class you won't be providing any |
| @@ -35,7 +36,7 @@ class OnlineStream { | @@ -35,7 +36,7 @@ class OnlineStream { | ||
| 35 | * of features, in the case where snip-edges == false; it also | 36 | * of features, in the case where snip-edges == false; it also |
| 36 | * affects the return value of IsLastFrame(). | 37 | * affects the return value of IsLastFrame(). |
| 37 | */ | 38 | */ |
| 38 | - void InputFinished(); | 39 | + void InputFinished() const; |
| 39 | 40 | ||
| 40 | int32_t NumFramesReady() const; | 41 | int32_t NumFramesReady() const; |
| 41 | 42 |
| @@ -248,14 +248,21 @@ int32_t main(int32_t argc, char *argv[]) { | @@ -248,14 +248,21 @@ int32_t main(int32_t argc, char *argv[]) { | ||
| 248 | std::string wave_filename = po.GetArg(1); | 248 | std::string wave_filename = po.GetArg(1); |
| 249 | 249 | ||
| 250 | bool is_ok = false; | 250 | bool is_ok = false; |
| 251 | + int32_t actual_sample_rate = -1; | ||
| 251 | std::vector<float> samples = | 252 | std::vector<float> samples = |
| 252 | - sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok); | 253 | + sherpa_onnx::ReadWave(wave_filename, &actual_sample_rate, &is_ok); |
| 253 | 254 | ||
| 254 | if (!is_ok) { | 255 | if (!is_ok) { |
| 255 | SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str()); | 256 | SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str()); |
| 256 | return -1; | 257 | return -1; |
| 257 | } | 258 | } |
| 258 | 259 | ||
| 260 | + if (actual_sample_rate != sample_rate) { | ||
| 261 | + SHERPA_ONNX_LOGE("Expected sample rate: %d, given %d", sample_rate, | ||
| 262 | + actual_sample_rate); | ||
| 263 | + return -1; | ||
| 264 | + } | ||
| 265 | + | ||
| 259 | asio::io_context io_conn; // for network connections | 266 | asio::io_context io_conn; // for network connections |
| 260 | Client c(io_conn, server_ip, server_port, samples, samples_per_message, | 267 | Client c(io_conn, server_ip, server_port, samples, samples_per_message, |
| 261 | seconds_per_message); | 268 | seconds_per_message); |
| @@ -97,7 +97,7 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, | @@ -97,7 +97,7 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, | ||
| 97 | std::ostringstream os; | 97 | std::ostringstream os; |
| 98 | os << "---encoder---\n"; | 98 | os << "---encoder---\n"; |
| 99 | PrintModelMetadata(os, meta_data); | 99 | PrintModelMetadata(os, meta_data); |
| 100 | - fprintf(stderr, "%s\n", os.str().c_str()); | 100 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); |
| 101 | } | 101 | } |
| 102 | 102 | ||
| 103 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | 103 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below |
| @@ -123,8 +123,8 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, | @@ -123,8 +123,8 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data, | ||
| 123 | print(num_encoder_layers_, "num_encoder_layers"); | 123 | print(num_encoder_layers_, "num_encoder_layers"); |
| 124 | print(cnn_module_kernels_, "cnn_module_kernels"); | 124 | print(cnn_module_kernels_, "cnn_module_kernels"); |
| 125 | print(left_context_len_, "left_context_len"); | 125 | print(left_context_len_, "left_context_len"); |
| 126 | - fprintf(stderr, "T: %d\n", T_); | ||
| 127 | - fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_); | 126 | + SHERPA_ONNX_LOGE("T: %d", T_); |
| 127 | + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); | ||
| 128 | } | 128 | } |
| 129 | } | 129 | } |
| 130 | 130 | ||
| @@ -145,7 +145,7 @@ void OnlineZipformerTransducerModel::InitDecoder(void *model_data, | @@ -145,7 +145,7 @@ void OnlineZipformerTransducerModel::InitDecoder(void *model_data, | ||
| 145 | std::ostringstream os; | 145 | std::ostringstream os; |
| 146 | os << "---decoder---\n"; | 146 | os << "---decoder---\n"; |
| 147 | PrintModelMetadata(os, meta_data); | 147 | PrintModelMetadata(os, meta_data); |
| 148 | - fprintf(stderr, "%s\n", os.str().c_str()); | 148 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); |
| 149 | } | 149 | } |
| 150 | 150 | ||
| 151 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | 151 | Ort::AllocatorWithDefaultOptions allocator; // used in the macro below |
| @@ -170,7 +170,7 @@ void OnlineZipformerTransducerModel::InitJoiner(void *model_data, | @@ -170,7 +170,7 @@ void OnlineZipformerTransducerModel::InitJoiner(void *model_data, | ||
| 170 | std::ostringstream os; | 170 | std::ostringstream os; |
| 171 | os << "---joiner---\n"; | 171 | os << "---joiner---\n"; |
| 172 | PrintModelMetadata(os, meta_data); | 172 | PrintModelMetadata(os, meta_data); |
| 173 | - fprintf(stderr, "%s\n", os.str().c_str()); | 173 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); |
| 174 | } | 174 | } |
| 175 | } | 175 | } |
| 176 | 176 | ||
| @@ -435,9 +435,6 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() { | @@ -435,9 +435,6 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() { | ||
| 435 | std::pair<Ort::Value, std::vector<Ort::Value>> | 435 | std::pair<Ort::Value, std::vector<Ort::Value>> |
| 436 | OnlineZipformerTransducerModel::RunEncoder(Ort::Value features, | 436 | OnlineZipformerTransducerModel::RunEncoder(Ort::Value features, |
| 437 | std::vector<Ort::Value> states) { | 437 | std::vector<Ort::Value> states) { |
| 438 | - auto memory_info = | ||
| 439 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 440 | - | ||
| 441 | std::vector<Ort::Value> encoder_inputs; | 438 | std::vector<Ort::Value> encoder_inputs; |
| 442 | encoder_inputs.reserve(1 + states.size()); | 439 | encoder_inputs.reserve(1 + states.size()); |
| 443 | 440 |
| @@ -41,7 +41,7 @@ PackedSequence PackPaddedSequence(OrtAllocator *allocator, | @@ -41,7 +41,7 @@ PackedSequence PackPaddedSequence(OrtAllocator *allocator, | ||
| 41 | std::vector<int64_t> l_shape = length->GetTensorTypeAndShapeInfo().GetShape(); | 41 | std::vector<int64_t> l_shape = length->GetTensorTypeAndShapeInfo().GetShape(); |
| 42 | 42 | ||
| 43 | assert(v_shape.size() == 3); | 43 | assert(v_shape.size() == 3); |
| 44 | - assert(l_shape.size() == 3); | 44 | + assert(l_shape.size() == 1); |
| 45 | assert(v_shape[0] == l_shape[0]); | 45 | assert(v_shape[0] == l_shape[0]); |
| 46 | 46 | ||
| 47 | std::vector<int32_t> indexes(v_shape[0]); | 47 | std::vector<int32_t> indexes(v_shape[0]); |
| @@ -13,7 +13,26 @@ namespace sherpa_onnx { | @@ -13,7 +13,26 @@ namespace sherpa_onnx { | ||
| 13 | struct PackedSequence { | 13 | struct PackedSequence { |
| 14 | std::vector<int32_t> sorted_indexes; | 14 | std::vector<int32_t> sorted_indexes; |
| 15 | std::vector<int32_t> batch_sizes; | 15 | std::vector<int32_t> batch_sizes; |
| 16 | + | ||
| 17 | + // data is a 2-D tensor of shape (sum(batch_sizes), channels) | ||
| 16 | Ort::Value data{nullptr}; | 18 | Ort::Value data{nullptr}; |
| 19 | + | ||
| 20 | + // Return a shallow copy of data[start:start+size, :] | ||
| 21 | + Ort::Value Get(int32_t start, int32_t size) { | ||
| 22 | + auto shape = data.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 23 | + | ||
| 24 | + std::array<int64_t, 2> ans_shape{size, shape[1]}; | ||
| 25 | + | ||
| 26 | + float *p = data.GetTensorMutableData<float>(); | ||
| 27 | + | ||
| 28 | + auto memory_info = | ||
| 29 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 30 | + | ||
| 31 | + // a shallow copy | ||
| 32 | + return Ort::Value::CreateTensor(memory_info, p + start * shape[1], | ||
| 33 | + size * shape[1], ans_shape.data(), | ||
| 34 | + ans_shape.size()); | ||
| 35 | + } | ||
| 17 | }; | 36 | }; |
| 18 | 37 | ||
| 19 | /** Similar to torch.nn.utils.rnn.pad_sequence but it supports only | 38 | /** Similar to torch.nn.utils.rnn.pad_sequence but it supports only |
| @@ -46,7 +46,7 @@ I Gcd(I m, I n) { | @@ -46,7 +46,7 @@ I Gcd(I m, I n) { | ||
| 46 | // this function is copied from kaldi/src/base/kaldi-math.h | 46 | // this function is copied from kaldi/src/base/kaldi-math.h |
| 47 | if (m == 0 || n == 0) { | 47 | if (m == 0 || n == 0) { |
| 48 | if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. | 48 | if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. |
| 49 | - fprintf(stderr, "Undefined GCD since m = 0, n = 0."); | 49 | + fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n"); |
| 50 | exit(-1); | 50 | exit(-1); |
| 51 | } | 51 | } |
| 52 | return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m)); | 52 | return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m)); |
| @@ -95,6 +95,10 @@ as the device_name. | @@ -95,6 +95,10 @@ as the device_name. | ||
| 95 | 95 | ||
| 96 | fprintf(stderr, "%s\n", config.ToString().c_str()); | 96 | fprintf(stderr, "%s\n", config.ToString().c_str()); |
| 97 | 97 | ||
| 98 | + if (!config.Validate()) { | ||
| 99 | + fprintf(stderr, "Errors in config!\n"); | ||
| 100 | + return -1; | ||
| 101 | + } | ||
| 98 | sherpa_onnx::OnlineRecognizer recognizer(config); | 102 | sherpa_onnx::OnlineRecognizer recognizer(config); |
| 99 | 103 | ||
| 100 | int32_t expected_sample_rate = config.feat_config.sampling_rate; | 104 | int32_t expected_sample_rate = config.feat_config.sampling_rate; |
| @@ -86,6 +86,11 @@ for a list of pre-trained models to download. | @@ -86,6 +86,11 @@ for a list of pre-trained models to download. | ||
| 86 | 86 | ||
| 87 | fprintf(stderr, "%s\n", config.ToString().c_str()); | 87 | fprintf(stderr, "%s\n", config.ToString().c_str()); |
| 88 | 88 | ||
| 89 | + if (!config.Validate()) { | ||
| 90 | + fprintf(stderr, "Errors in config!\n"); | ||
| 91 | + return -1; | ||
| 92 | + } | ||
| 93 | + | ||
| 89 | sherpa_onnx::OnlineRecognizer recognizer(config); | 94 | sherpa_onnx::OnlineRecognizer recognizer(config); |
| 90 | auto s = recognizer.CreateStream(); | 95 | auto s = recognizer.CreateStream(); |
| 91 | 96 |
sherpa-onnx/csrc/sherpa-onnx-offline.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-offline.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include <stdio.h> | ||
| 6 | + | ||
| 7 | +#include <chrono> // NOLINT | ||
| 8 | +#include <string> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-stream.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 14 | +#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" | ||
| 15 | +#include "sherpa-onnx/csrc/offline-transducer-model.h" | ||
| 16 | +#include "sherpa-onnx/csrc/pad-sequence.h" | ||
| 17 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 18 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 19 | + | ||
| 20 | +int main(int32_t argc, char *argv[]) { | ||
| 21 | + if (argc < 6 || argc > 8) { | ||
| 22 | + const char *usage = R"usage( | ||
| 23 | +Usage: | ||
| 24 | + ./bin/sherpa-onnx-offline \ | ||
| 25 | + /path/to/tokens.txt \ | ||
| 26 | + /path/to/encoder.onnx \ | ||
| 27 | + /path/to/decoder.onnx \ | ||
| 28 | + /path/to/joiner.onnx \ | ||
| 29 | + /path/to/foo.wav [num_threads [decoding_method]] | ||
| 30 | + | ||
| 31 | +Default value for num_threads is 2. | ||
| 32 | +Valid values for decoding_method: greedy_search. | ||
| 33 | +foo.wav should be of single channel, 16-bit PCM encoded wave file; its | ||
| 34 | +sampling rate can be arbitrary and does not need to be 16kHz. | ||
| 35 | + | ||
| 36 | +Please refer to | ||
| 37 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 38 | +for a list of pre-trained models to download. | ||
| 39 | +)usage"; | ||
| 40 | + fprintf(stderr, "%s\n", usage); | ||
| 41 | + | ||
| 42 | + return 0; | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + sherpa_onnx::OfflineRecognizerConfig config; | ||
| 46 | + | ||
| 47 | + config.model_config.tokens = argv[1]; | ||
| 48 | + | ||
| 49 | + config.model_config.debug = false; | ||
| 50 | + config.model_config.encoder_filename = argv[2]; | ||
| 51 | + config.model_config.decoder_filename = argv[3]; | ||
| 52 | + config.model_config.joiner_filename = argv[4]; | ||
| 53 | + | ||
| 54 | + std::string wav_filename = argv[5]; | ||
| 55 | + | ||
| 56 | + config.model_config.num_threads = 2; | ||
| 57 | + if (argc == 7 && atoi(argv[6]) > 0) { | ||
| 58 | + config.model_config.num_threads = atoi(argv[6]); | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | + if (argc == 8) { | ||
| 62 | + config.decoding_method = argv[7]; | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + fprintf(stderr, "%s\n", config.ToString().c_str()); | ||
| 66 | + | ||
| 67 | + if (!config.Validate()) { | ||
| 68 | + fprintf(stderr, "Errors in config!\n"); | ||
| 69 | + return -1; | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + int32_t sampling_rate = -1; | ||
| 73 | + | ||
| 74 | + bool is_ok = false; | ||
| 75 | + std::vector<float> samples = | ||
| 76 | + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||
| 77 | + if (!is_ok) { | ||
| 78 | + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||
| 79 | + return -1; | ||
| 80 | + } | ||
| 81 | + fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); | ||
| 82 | + | ||
| 83 | + float duration = samples.size() / static_cast<float>(sampling_rate); | ||
| 84 | + | ||
| 85 | + sherpa_onnx::OfflineRecognizer recognizer(config); | ||
| 86 | + auto s = recognizer.CreateStream(); | ||
| 87 | + | ||
| 88 | + auto begin = std::chrono::steady_clock::now(); | ||
| 89 | + fprintf(stderr, "Started\n"); | ||
| 90 | + | ||
| 91 | + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||
| 92 | + | ||
| 93 | + recognizer.DecodeStream(s.get()); | ||
| 94 | + | ||
| 95 | + fprintf(stderr, "Done!\n"); | ||
| 96 | + | ||
| 97 | + fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), | ||
| 98 | + s->GetResult().text.c_str()); | ||
| 99 | + | ||
| 100 | + auto end = std::chrono::steady_clock::now(); | ||
| 101 | + float elapsed_seconds = | ||
| 102 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 103 | + .count() / | ||
| 104 | + 1000.; | ||
| 105 | + | ||
| 106 | + fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); | ||
| 107 | + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); | ||
| 108 | + | ||
| 109 | + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 110 | + float rtf = elapsed_seconds / duration; | ||
| 111 | + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||
| 112 | + elapsed_seconds, duration, rtf); | ||
| 113 | + | ||
| 114 | + return 0; | ||
| 115 | +} |
| @@ -26,6 +26,8 @@ Usage: | @@ -26,6 +26,8 @@ Usage: | ||
| 26 | 26 | ||
| 27 | Default value for num_threads is 2. | 27 | Default value for num_threads is 2. |
| 28 | Valid values for decoding_method: greedy_search (default), modified_beam_search. | 28 | Valid values for decoding_method: greedy_search (default), modified_beam_search. |
| 29 | +foo.wav should be of single channel, 16-bit PCM encoded wave file; its | ||
| 30 | +sampling rate can be arbitrary and does not need to be 16kHz. | ||
| 29 | 31 | ||
| 30 | Please refer to | 32 | Please refer to |
| 31 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | 33 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| @@ -59,20 +61,26 @@ for a list of pre-trained models to download. | @@ -59,20 +61,26 @@ for a list of pre-trained models to download. | ||
| 59 | 61 | ||
| 60 | fprintf(stderr, "%s\n", config.ToString().c_str()); | 62 | fprintf(stderr, "%s\n", config.ToString().c_str()); |
| 61 | 63 | ||
| 64 | + if (!config.Validate()) { | ||
| 65 | + fprintf(stderr, "Errors in config!\n"); | ||
| 66 | + return -1; | ||
| 67 | + } | ||
| 68 | + | ||
| 62 | sherpa_onnx::OnlineRecognizer recognizer(config); | 69 | sherpa_onnx::OnlineRecognizer recognizer(config); |
| 63 | 70 | ||
| 64 | - int32_t expected_sampling_rate = config.feat_config.sampling_rate; | 71 | + int32_t sampling_rate = -1; |
| 65 | 72 | ||
| 66 | bool is_ok = false; | 73 | bool is_ok = false; |
| 67 | std::vector<float> samples = | 74 | std::vector<float> samples = |
| 68 | - sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); | 75 | + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); |
| 69 | 76 | ||
| 70 | if (!is_ok) { | 77 | if (!is_ok) { |
| 71 | fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | 78 | fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); |
| 72 | return -1; | 79 | return -1; |
| 73 | } | 80 | } |
| 81 | + fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); | ||
| 74 | 82 | ||
| 75 | - float duration = samples.size() / static_cast<float>(expected_sampling_rate); | 83 | + float duration = samples.size() / static_cast<float>(sampling_rate); |
| 76 | 84 | ||
| 77 | fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); | 85 | fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); |
| 78 | fprintf(stderr, "wav duration (s): %.3f\n", duration); | 86 | fprintf(stderr, "wav duration (s): %.3f\n", duration); |
| @@ -81,12 +89,13 @@ for a list of pre-trained models to download. | @@ -81,12 +89,13 @@ for a list of pre-trained models to download. | ||
| 81 | fprintf(stderr, "Started\n"); | 89 | fprintf(stderr, "Started\n"); |
| 82 | 90 | ||
| 83 | auto s = recognizer.CreateStream(); | 91 | auto s = recognizer.CreateStream(); |
| 84 | - s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); | 92 | + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); |
| 93 | + | ||
| 94 | + std::vector<float> tail_paddings(static_cast<int>(0.2 * sampling_rate)); | ||
| 95 | + // Note: We can call AcceptWaveform() multiple times. | ||
| 96 | + s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size()); | ||
| 85 | 97 | ||
| 86 | - std::vector<float> tail_paddings( | ||
| 87 | - static_cast<int>(0.2 * expected_sampling_rate)); | ||
| 88 | - s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(), | ||
| 89 | - tail_paddings.size()); | 98 | + // Call InputFinished() to indicate that no audio samples are available |
| 90 | s->InputFinished(); | 99 | s->InputFinished(); |
| 91 | 100 | ||
| 92 | while (recognizer.IsReady(s.get())) { | 101 | while (recognizer.IsReady(s.get())) { |
| @@ -30,4 +30,23 @@ TEST(Slice, Slice3D) { | @@ -30,4 +30,23 @@ TEST(Slice, Slice3D) { | ||
| 30 | // TODO(fangjun): Check that the results are correct | 30 | // TODO(fangjun): Check that the results are correct |
| 31 | } | 31 | } |
| 32 | 32 | ||
| 33 | +TEST(Slice, Slice2D) { | ||
| 34 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 35 | + std::array<int64_t, 2> shape{5, 8}; | ||
| 36 | + Ort::Value v = | ||
| 37 | + Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size()); | ||
| 38 | + float *p = v.GetTensorMutableData<float>(); | ||
| 39 | + | ||
| 40 | + std::iota(p, p + shape[0] * shape[1], 0); | ||
| 41 | + | ||
| 42 | + auto v1 = Slice(allocator, &v, 1, 3); | ||
| 43 | + auto v2 = Slice(allocator, &v, 0, 2); | ||
| 44 | + | ||
| 45 | + Print2D(&v); | ||
| 46 | + Print2D(&v1); | ||
| 47 | + Print2D(&v2); | ||
| 48 | + | ||
| 49 | + // TODO(fangjun): Check that the results are correct | ||
| 50 | +} | ||
| 51 | + | ||
| 33 | } // namespace sherpa_onnx | 52 | } // namespace sherpa_onnx |
| @@ -24,7 +24,7 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, | @@ -24,7 +24,7 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, | ||
| 24 | 24 | ||
| 25 | assert(0 <= dim1_start); | 25 | assert(0 <= dim1_start); |
| 26 | assert(dim1_start < dim1_end); | 26 | assert(dim1_start < dim1_end); |
| 27 | - assert(dim1_end < shape[1]); | 27 | + assert(dim1_end <= shape[1]); |
| 28 | 28 | ||
| 29 | const T *src = v->GetTensorData<T>(); | 29 | const T *src = v->GetTensorData<T>(); |
| 30 | 30 | ||
| @@ -46,8 +46,35 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, | @@ -46,8 +46,35 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, | ||
| 46 | return ans; | 46 | return ans; |
| 47 | } | 47 | } |
| 48 | 48 | ||
| 49 | +template <typename T /*= float*/> | ||
| 50 | +Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, | ||
| 51 | + int32_t dim0_start, int32_t dim0_end) { | ||
| 52 | + std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 53 | + assert(shape.size() == 2); | ||
| 54 | + | ||
| 55 | + assert(0 <= dim0_start); | ||
| 56 | + assert(dim0_start < dim0_end); | ||
| 57 | + assert(dim0_end <= shape[0]); | ||
| 58 | + | ||
| 59 | + const T *src = v->GetTensorData<T>(); | ||
| 60 | + | ||
| 61 | + std::array<int64_t, 2> ans_shape{dim0_end - dim0_start, shape[1]}; | ||
| 62 | + | ||
| 63 | + Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(), | ||
| 64 | + ans_shape.size()); | ||
| 65 | + const T *start = v->GetTensorData<T>() + dim0_start * shape[1]; | ||
| 66 | + const T *end = v->GetTensorData<T>() + dim0_end * shape[1]; | ||
| 67 | + T *dst = ans.GetTensorMutableData<T>(); | ||
| 68 | + std::copy(start, end, dst); | ||
| 69 | + | ||
| 70 | + return ans; | ||
| 71 | +} | ||
| 72 | + | ||
| 49 | template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v, | 73 | template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v, |
| 50 | int32_t dim0_start, int32_t dim0_end, | 74 | int32_t dim0_start, int32_t dim0_end, |
| 51 | int32_t dim1_start, int32_t dim1_end); | 75 | int32_t dim1_start, int32_t dim1_end); |
| 52 | 76 | ||
| 77 | +template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v, | ||
| 78 | + int32_t dim0_start, int32_t dim0_end); | ||
| 79 | + | ||
| 53 | } // namespace sherpa_onnx | 80 | } // namespace sherpa_onnx |
| @@ -8,12 +8,12 @@ | @@ -8,12 +8,12 @@ | ||
| 8 | 8 | ||
| 9 | namespace sherpa_onnx { | 9 | namespace sherpa_onnx { |
| 10 | 10 | ||
| 11 | -/** Get a deep copy by slicing v. | 11 | +/** Get a deep copy by slicing a 3-D tensor v. |
| 12 | * | 12 | * |
| 13 | - * It returns v[dim0_start:dim0_end, dim1_start:dim1_end] | 13 | + * It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :] |
| 14 | * | 14 | * |
| 15 | * @param allocator | 15 | * @param allocator |
| 16 | - * @param v A 3-D tensor. Its data type is T. | 16 | + * @param v A 3-D tensor. Its data type is T. |
| 17 | * @param dim0_start Start index of the first dimension. | 17 | * @param dim0_start Start index of the first dimension. |
| 18 | * @param dim0_end End index of the first dimension. | 18 | * @param dim0_end End index of the first dimension. |
| 19 | * @param dim1_start Start index of the second dimension. | 19 | * @param dim1_start Start index of the second dimension. |
| @@ -26,6 +26,23 @@ template <typename T = float> | @@ -26,6 +26,23 @@ template <typename T = float> | ||
| 26 | Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, | 26 | Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, |
| 27 | int32_t dim0_start, int32_t dim0_end, int32_t dim1_start, | 27 | int32_t dim0_start, int32_t dim0_end, int32_t dim1_start, |
| 28 | int32_t dim1_end); | 28 | int32_t dim1_end); |
| 29 | + | ||
| 30 | +/** Get a deep copy by slicing a 2-D tensor v. | ||
| 31 | + * | ||
| 32 | + * It returns v[dim0_start:dim0_end, :] | ||
| 33 | + * | ||
| 34 | + * @param allocator | ||
| 35 | + * @param v A 2-D tensor. Its data type is T. | ||
| 36 | + * @param dim0_start Start index of the first dimension. | ||
| 37 | + * @param dim0_end End index of the first dimension. | ||
| 38 | + * | ||
| 39 | + * @return Return a 2-D tensor of shape | ||
| 40 | + * (dim0_end-dim0_start, v.shape[1]) | ||
| 41 | + */ | ||
| 42 | +template <typename T = float> | ||
| 43 | +Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v, | ||
| 44 | + int32_t dim0_start, int32_t dim0_end); | ||
| 45 | + | ||
| 29 | } // namespace sherpa_onnx | 46 | } // namespace sherpa_onnx |
| 30 | 47 | ||
| 31 | #endif // SHERPA_ONNX_CSRC_SLICE_H_ | 48 | #endif // SHERPA_ONNX_CSRC_SLICE_H_ |
| @@ -6,10 +6,11 @@ | @@ -6,10 +6,11 @@ | ||
| 6 | 6 | ||
| 7 | #include <cassert> | 7 | #include <cassert> |
| 8 | #include <fstream> | 8 | #include <fstream> |
| 9 | -#include <iostream> | ||
| 10 | #include <utility> | 9 | #include <utility> |
| 11 | #include <vector> | 10 | #include <vector> |
| 12 | 11 | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | + | ||
| 13 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 14 | namespace { | 15 | namespace { |
| 15 | // see http://soundfile.sapp.org/doc/WaveFormat/ | 16 | // see http://soundfile.sapp.org/doc/WaveFormat/ |
| @@ -20,26 +21,34 @@ struct WaveHeader { | @@ -20,26 +21,34 @@ struct WaveHeader { | ||
| 20 | bool Validate() const { | 21 | bool Validate() const { |
| 21 | // F F I R | 22 | // F F I R |
| 22 | if (chunk_id != 0x46464952) { | 23 | if (chunk_id != 0x46464952) { |
| 24 | + SHERPA_ONNX_LOGE("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id); | ||
| 23 | return false; | 25 | return false; |
| 24 | } | 26 | } |
| 25 | // E V A W | 27 | // E V A W |
| 26 | if (format != 0x45564157) { | 28 | if (format != 0x45564157) { |
| 29 | + SHERPA_ONNX_LOGE("Expected format WAVE. Given: 0x%08x\n", format); | ||
| 27 | return false; | 30 | return false; |
| 28 | } | 31 | } |
| 29 | 32 | ||
| 30 | if (subchunk1_id != 0x20746d66) { | 33 | if (subchunk1_id != 0x20746d66) { |
| 34 | + SHERPA_ONNX_LOGE("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n", | ||
| 35 | + subchunk1_id); | ||
| 31 | return false; | 36 | return false; |
| 32 | } | 37 | } |
| 33 | 38 | ||
| 34 | if (subchunk1_size != 16) { // 16 for PCM | 39 | if (subchunk1_size != 16) { // 16 for PCM |
| 40 | + SHERPA_ONNX_LOGE("Expected subchunk1_size 16. Given: %d\n", | ||
| 41 | + subchunk1_size); | ||
| 35 | return false; | 42 | return false; |
| 36 | } | 43 | } |
| 37 | 44 | ||
| 38 | if (audio_format != 1) { // 1 for PCM | 45 | if (audio_format != 1) { // 1 for PCM |
| 46 | + SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n", audio_format); | ||
| 39 | return false; | 47 | return false; |
| 40 | } | 48 | } |
| 41 | 49 | ||
| 42 | if (num_channels != 1) { // we support only single channel for now | 50 | if (num_channels != 1) { // we support only single channel for now |
| 51 | + SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n", num_channels); | ||
| 43 | return false; | 52 | return false; |
| 44 | } | 53 | } |
| 45 | if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) { | 54 | if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) { |
| @@ -51,6 +60,8 @@ struct WaveHeader { | @@ -51,6 +60,8 @@ struct WaveHeader { | ||
| 51 | } | 60 | } |
| 52 | 61 | ||
| 53 | if (bits_per_sample != 16) { // we support only 16 bits per sample | 62 | if (bits_per_sample != 16) { // we support only 16 bits per sample |
| 63 | + SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n", | ||
| 64 | + bits_per_sample); | ||
| 54 | return false; | 65 | return false; |
| 55 | } | 66 | } |
| 56 | 67 | ||
| @@ -62,7 +73,7 @@ struct WaveHeader { | @@ -62,7 +73,7 @@ struct WaveHeader { | ||
| 62 | // and | 73 | // and |
| 63 | // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf | 74 | // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf |
| 64 | void SeekToDataChunk(std::istream &is) { | 75 | void SeekToDataChunk(std::istream &is) { |
| 65 | - // a t a d | 76 | + // a t a d |
| 66 | while (is && subchunk2_id != 0x61746164) { | 77 | while (is && subchunk2_id != 0x61746164) { |
| 67 | // const char *p = reinterpret_cast<const char *>(&subchunk2_id); | 78 | // const char *p = reinterpret_cast<const char *>(&subchunk2_id); |
| 68 | // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], | 79 | // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], |
| @@ -91,7 +102,7 @@ static_assert(sizeof(WaveHeader) == 44, ""); | @@ -91,7 +102,7 @@ static_assert(sizeof(WaveHeader) == 44, ""); | ||
| 91 | 102 | ||
| 92 | // Read a wave file of mono-channel. | 103 | // Read a wave file of mono-channel. |
| 93 | // Return its samples normalized to the range [-1, 1). | 104 | // Return its samples normalized to the range [-1, 1). |
| 94 | -std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, | 105 | +std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, |
| 95 | bool *is_ok) { | 106 | bool *is_ok) { |
| 96 | WaveHeader header; | 107 | WaveHeader header; |
| 97 | is.read(reinterpret_cast<char *>(&header), sizeof(header)); | 108 | is.read(reinterpret_cast<char *>(&header), sizeof(header)); |
| @@ -111,10 +122,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, | @@ -111,10 +122,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, | ||
| 111 | return {}; | 122 | return {}; |
| 112 | } | 123 | } |
| 113 | 124 | ||
| 114 | - if (expected_sample_rate != header.sample_rate) { | ||
| 115 | - *is_ok = false; | ||
| 116 | - return {}; | ||
| 117 | - } | 125 | + *sampling_rate = header.sample_rate; |
| 118 | 126 | ||
| 119 | // header.subchunk2_size contains the number of bytes in the data. | 127 | // header.subchunk2_size contains the number of bytes in the data. |
| 120 | // As we assume each sample contains two bytes, so it is divided by 2 here | 128 | // As we assume each sample contains two bytes, so it is divided by 2 here |
| @@ -137,15 +145,15 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, | @@ -137,15 +145,15 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate, | ||
| 137 | 145 | ||
| 138 | } // namespace | 146 | } // namespace |
| 139 | 147 | ||
| 140 | -std::vector<float> ReadWave(const std::string &filename, | ||
| 141 | - float expected_sample_rate, bool *is_ok) { | 148 | +std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate, |
| 149 | + bool *is_ok) { | ||
| 142 | std::ifstream is(filename, std::ifstream::binary); | 150 | std::ifstream is(filename, std::ifstream::binary); |
| 143 | - return ReadWave(is, expected_sample_rate, is_ok); | 151 | + return ReadWave(is, sampling_rate, is_ok); |
| 144 | } | 152 | } |
| 145 | 153 | ||
| 146 | -std::vector<float> ReadWave(std::istream &is, float expected_sample_rate, | 154 | +std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate, |
| 147 | bool *is_ok) { | 155 | bool *is_ok) { |
| 148 | - auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok); | 156 | + auto samples = ReadWaveImpl(is, sampling_rate, is_ok); |
| 149 | return samples; | 157 | return samples; |
| 150 | } | 158 | } |
| 151 | 159 |
| @@ -13,17 +13,17 @@ namespace sherpa_onnx { | @@ -13,17 +13,17 @@ namespace sherpa_onnx { | ||
| 13 | 13 | ||
| 14 | /** Read a wave file with expected sample rate. | 14 | /** Read a wave file with expected sample rate. |
| 15 | 15 | ||
| 16 | - @param filename Path to a wave file. It MUST be single channel, PCM encoded. | ||
| 17 | - @param expected_sample_rate Expected sample rate of the wave file. If the | ||
| 18 | - sample rate don't match, it throws an exception. | 16 | + @param filename Path to a wave file. It MUST be single channel, 16-bit |
| 17 | + PCM encoded. | ||
| 18 | + @param sampling_rate On return, it contains the sampling rate of the file. | ||
| 19 | @param is_ok On return it is true if the reading succeeded; false otherwise. | 19 | @param is_ok On return it is true if the reading succeeded; false otherwise. |
| 20 | 20 | ||
| 21 | @return Return wave samples normalized to the range [-1, 1). | 21 | @return Return wave samples normalized to the range [-1, 1). |
| 22 | */ | 22 | */ |
| 23 | -std::vector<float> ReadWave(const std::string &filename, | ||
| 24 | - float expected_sample_rate, bool *is_ok); | 23 | +std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate, |
| 24 | + bool *is_ok); | ||
| 25 | 25 | ||
| 26 | -std::vector<float> ReadWave(std::istream &is, float expected_sample_rate, | 26 | +std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate, |
| 27 | bool *is_ok); | 27 | bool *is_ok); |
| 28 | 28 | ||
| 29 | } // namespace sherpa_onnx | 29 | } // namespace sherpa_onnx |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "jni.h" // NOLINT | 11 | #include "jni.h" // NOLINT |
| 12 | 12 | ||
| 13 | #include <strstream> | 13 | #include <strstream> |
| 14 | +#include <utility> | ||
| 14 | 15 | ||
| 15 | #if __ANDROID_API__ >= 9 | 16 | #if __ANDROID_API__ >= 9 |
| 16 | #include "android/asset_manager.h" | 17 | #include "android/asset_manager.h" |
| @@ -43,14 +44,18 @@ class SherpaOnnx { | @@ -43,14 +44,18 @@ class SherpaOnnx { | ||
| 43 | stream_(recognizer_.CreateStream()) { | 44 | stream_(recognizer_.CreateStream()) { |
| 44 | } | 45 | } |
| 45 | 46 | ||
| 46 | - void AcceptWaveform(int32_t sample_rate, const float *samples, | ||
| 47 | - int32_t n) const { | 47 | + void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) { |
| 48 | + if (input_sample_rate_ == -1) { | ||
| 49 | + input_sample_rate_ = sample_rate; | ||
| 50 | + } | ||
| 51 | + | ||
| 48 | stream_->AcceptWaveform(sample_rate, samples, n); | 52 | stream_->AcceptWaveform(sample_rate, samples, n); |
| 49 | } | 53 | } |
| 50 | 54 | ||
| 51 | void InputFinished() const { | 55 | void InputFinished() const { |
| 52 | - std::vector<float> tail_padding(16000 * 0.32, 0); | ||
| 53 | - stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size()); | 56 | + std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0); |
| 57 | + stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(), | ||
| 58 | + tail_padding.size()); | ||
| 54 | stream_->InputFinished(); | 59 | stream_->InputFinished(); |
| 55 | } | 60 | } |
| 56 | 61 | ||
| @@ -70,6 +75,7 @@ class SherpaOnnx { | @@ -70,6 +75,7 @@ class SherpaOnnx { | ||
| 70 | private: | 75 | private: |
| 71 | sherpa_onnx::OnlineRecognizer recognizer_; | 76 | sherpa_onnx::OnlineRecognizer recognizer_; |
| 72 | std::unique_ptr<sherpa_onnx::OnlineStream> stream_; | 77 | std::unique_ptr<sherpa_onnx::OnlineStream> stream_; |
| 78 | + int32_t input_sample_rate_ = -1; | ||
| 73 | }; | 79 | }; |
| 74 | 80 | ||
| 75 | static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | 81 | static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { |
| @@ -276,17 +282,24 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( | @@ -276,17 +282,24 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( | ||
| 276 | return env->NewStringUTF(text.c_str()); | 282 | return env->NewStringUTF(text.c_str()); |
| 277 | } | 283 | } |
| 278 | 284 | ||
| 285 | +// see | ||
| 286 | +// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables | ||
| 287 | +static jobject NewInteger(JNIEnv *env, int32_t value) { | ||
| 288 | + jclass cls = env->FindClass("java/lang/Integer"); | ||
| 289 | + jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V"); | ||
| 290 | + return env->NewObject(cls, constructor, value); | ||
| 291 | +} | ||
| 292 | + | ||
| 279 | SHERPA_ONNX_EXTERN_C | 293 | SHERPA_ONNX_EXTERN_C |
| 280 | -JNIEXPORT jfloatArray JNICALL | 294 | +JNIEXPORT jobjectArray JNICALL |
| 281 | Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( | 295 | Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( |
| 282 | - JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename, | ||
| 283 | - jfloat expected_sample_rate) { | 296 | + JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) { |
| 284 | const char *p_filename = env->GetStringUTFChars(filename, nullptr); | 297 | const char *p_filename = env->GetStringUTFChars(filename, nullptr); |
| 285 | #if __ANDROID_API__ >= 9 | 298 | #if __ANDROID_API__ >= 9 |
| 286 | AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); | 299 | AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); |
| 287 | if (!mgr) { | 300 | if (!mgr) { |
| 288 | SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); | 301 | SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); |
| 289 | - return nullptr; | 302 | + exit(-1); |
| 290 | } | 303 | } |
| 291 | 304 | ||
| 292 | std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename); | 305 | std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename); |
| @@ -297,16 +310,25 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( | @@ -297,16 +310,25 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave( | ||
| 297 | #endif | 310 | #endif |
| 298 | 311 | ||
| 299 | bool is_ok = false; | 312 | bool is_ok = false; |
| 313 | + int32_t sampling_rate = -1; | ||
| 300 | std::vector<float> samples = | 314 | std::vector<float> samples = |
| 301 | - sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok); | 315 | + sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok); |
| 302 | 316 | ||
| 303 | env->ReleaseStringUTFChars(filename, p_filename); | 317 | env->ReleaseStringUTFChars(filename, p_filename); |
| 304 | 318 | ||
| 305 | if (!is_ok) { | 319 | if (!is_ok) { |
| 306 | - return nullptr; | 320 | + SHERPA_ONNX_LOGE("Failed to read %s", p_filename); |
| 321 | + exit(-1); | ||
| 307 | } | 322 | } |
| 308 | 323 | ||
| 309 | jfloatArray ans = env->NewFloatArray(samples.size()); | 324 | jfloatArray ans = env->NewFloatArray(samples.size()); |
| 310 | env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data()); | 325 | env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data()); |
| 311 | - return ans; | 326 | + |
| 327 | + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( | ||
| 328 | + 2, env->FindClass("java/lang/Object"), nullptr); | ||
| 329 | + | ||
| 330 | + env->SetObjectArrayElement(obj_arr, 0, ans); | ||
| 331 | + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate)); | ||
| 332 | + | ||
| 333 | + return obj_arr; | ||
| 312 | } | 334 | } |
| @@ -11,12 +11,10 @@ namespace sherpa_onnx { | @@ -11,12 +11,10 @@ namespace sherpa_onnx { | ||
| 11 | static void PybindFeatureExtractorConfig(py::module *m) { | 11 | static void PybindFeatureExtractorConfig(py::module *m) { |
| 12 | using PyClass = FeatureExtractorConfig; | 12 | using PyClass = FeatureExtractorConfig; |
| 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") | 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") |
| 14 | - .def(py::init<int32_t, int32_t, int32_t>(), | ||
| 15 | - py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80, | ||
| 16 | - py::arg("max_feature_vectors") = -1) | 14 | + .def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000, |
| 15 | + py::arg("feature_dim") = 80) | ||
| 17 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) | 16 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) |
| 18 | .def_readwrite("feature_dim", &PyClass::feature_dim) | 17 | .def_readwrite("feature_dim", &PyClass::feature_dim) |
| 19 | - .def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors) | ||
| 20 | .def("__str__", &PyClass::ToString); | 18 | .def("__str__", &PyClass::ToString); |
| 21 | } | 19 | } |
| 22 | 20 |
| @@ -34,7 +34,6 @@ class OnlineRecognizer(object): | @@ -34,7 +34,6 @@ class OnlineRecognizer(object): | ||
| 34 | rule3_min_utterance_length: int = 20, | 34 | rule3_min_utterance_length: int = 20, |
| 35 | decoding_method: str = "greedy_search", | 35 | decoding_method: str = "greedy_search", |
| 36 | max_active_paths: int = 4, | 36 | max_active_paths: int = 4, |
| 37 | - max_feature_vectors: int = -1, | ||
| 38 | ): | 37 | ): |
| 39 | """ | 38 | """ |
| 40 | Please refer to | 39 | Please refer to |
| @@ -82,9 +81,6 @@ class OnlineRecognizer(object): | @@ -82,9 +81,6 @@ class OnlineRecognizer(object): | ||
| 82 | max_active_paths: | 81 | max_active_paths: |
| 83 | Use only when decoding_method is modified_beam_search. It specifies | 82 | Use only when decoding_method is modified_beam_search. It specifies |
| 84 | the maximum number of active paths during beam search. | 83 | the maximum number of active paths during beam search. |
| 85 | - max_feature_vectors: | ||
| 86 | - Number of feature vectors to cache. -1 means to cache all feature | ||
| 87 | - frames that have been processed. | ||
| 88 | """ | 84 | """ |
| 89 | _assert_file_exists(tokens) | 85 | _assert_file_exists(tokens) |
| 90 | _assert_file_exists(encoder) | 86 | _assert_file_exists(encoder) |
| @@ -104,7 +100,6 @@ class OnlineRecognizer(object): | @@ -104,7 +100,6 @@ class OnlineRecognizer(object): | ||
| 104 | feat_config = FeatureExtractorConfig( | 100 | feat_config = FeatureExtractorConfig( |
| 105 | sampling_rate=sample_rate, | 101 | sampling_rate=sample_rate, |
| 106 | feature_dim=feature_dim, | 102 | feature_dim=feature_dim, |
| 107 | - max_feature_vectors=max_feature_vectors, | ||
| 108 | ) | 103 | ) |
| 109 | 104 | ||
| 110 | endpoint_config = EndpointConfig( | 105 | endpoint_config = EndpointConfig( |
| @@ -8,18 +8,18 @@ | @@ -8,18 +8,18 @@ | ||
| 8 | 8 | ||
| 9 | import unittest | 9 | import unittest |
| 10 | 10 | ||
| 11 | -import sherpa_onnx | 11 | +import _sherpa_onnx |
| 12 | 12 | ||
| 13 | 13 | ||
| 14 | class TestFeatureExtractorConfig(unittest.TestCase): | 14 | class TestFeatureExtractorConfig(unittest.TestCase): |
| 15 | def test_default_constructor(self): | 15 | def test_default_constructor(self): |
| 16 | - config = sherpa_onnx.FeatureExtractorConfig() | 16 | + config = _sherpa_onnx.FeatureExtractorConfig() |
| 17 | assert config.sampling_rate == 16000, config.sampling_rate | 17 | assert config.sampling_rate == 16000, config.sampling_rate |
| 18 | assert config.feature_dim == 80, config.feature_dim | 18 | assert config.feature_dim == 80, config.feature_dim |
| 19 | print(config) | 19 | print(config) |
| 20 | 20 | ||
| 21 | def test_constructor(self): | 21 | def test_constructor(self): |
| 22 | - config = sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40) | 22 | + config = _sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40) |
| 23 | assert config.sampling_rate == 8000, config.sampling_rate | 23 | assert config.sampling_rate == 8000, config.sampling_rate |
| 24 | assert config.feature_dim == 40, config.feature_dim | 24 | assert config.feature_dim == 40, config.feature_dim |
| 25 | print(config) | 25 | print(config) |
| @@ -8,21 +8,23 @@ | @@ -8,21 +8,23 @@ | ||
| 8 | 8 | ||
| 9 | import unittest | 9 | import unittest |
| 10 | 10 | ||
| 11 | -import sherpa_onnx | 11 | +import _sherpa_onnx |
| 12 | 12 | ||
| 13 | 13 | ||
| 14 | class TestOnlineTransducerModelConfig(unittest.TestCase): | 14 | class TestOnlineTransducerModelConfig(unittest.TestCase): |
| 15 | def test_constructor(self): | 15 | def test_constructor(self): |
| 16 | - config = sherpa_onnx.OnlineTransducerModelConfig( | 16 | + config = _sherpa_onnx.OnlineTransducerModelConfig( |
| 17 | encoder_filename="encoder.onnx", | 17 | encoder_filename="encoder.onnx", |
| 18 | decoder_filename="decoder.onnx", | 18 | decoder_filename="decoder.onnx", |
| 19 | joiner_filename="joiner.onnx", | 19 | joiner_filename="joiner.onnx", |
| 20 | + tokens="tokens.txt", | ||
| 20 | num_threads=8, | 21 | num_threads=8, |
| 21 | debug=True, | 22 | debug=True, |
| 22 | ) | 23 | ) |
| 23 | assert config.encoder_filename == "encoder.onnx", config.encoder_filename | 24 | assert config.encoder_filename == "encoder.onnx", config.encoder_filename |
| 24 | assert config.decoder_filename == "decoder.onnx", config.decoder_filename | 25 | assert config.decoder_filename == "decoder.onnx", config.decoder_filename |
| 25 | assert config.joiner_filename == "joiner.onnx", config.joiner_filename | 26 | assert config.joiner_filename == "joiner.onnx", config.joiner_filename |
| 27 | + assert config.tokens == "tokens.txt", config.tokens | ||
| 26 | assert config.num_threads == 8, config.num_threads | 28 | assert config.num_threads == 8, config.num_threads |
| 27 | assert config.debug is True, config.debug | 29 | assert config.debug is True, config.debug |
| 28 | print(config) | 30 | print(config) |
-
请 注册 或 登录 后发表评论