Committed by
GitHub
Support streaming zipformer CTC (#496)
* Support streaming zipformer CTC
* test online zipformer2 CTC
* Update doc of sherpa-onnx.cc
* Add Python APIs for streaming zipformer2 ctc
* Add Python API examples for streaming zipformer2 ctc
* Swift API for streaming zipformer2 CTC
* NodeJS API for streaming zipformer2 CTC
* Kotlin API for streaming zipformer2 CTC
* Golang API for streaming zipformer2 CTC
* C# API for streaming zipformer2 CTC
* Release v1.9.6
正在显示 70 个修改的文件，包含 1518 行增加和 212 行删除
| @@ -51,6 +51,13 @@ rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 | @@ -51,6 +51,13 @@ rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 | ||
| 51 | node ./test-online-transducer.js | 51 | node ./test-online-transducer.js |
| 52 | rm -rf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 | 52 | rm -rf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 |
| 53 | 53 | ||
| 54 | +curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 55 | +tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 56 | +rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 57 | + | ||
| 58 | +node ./test-online-zipformer2-ctc.js | ||
| 59 | +rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 | ||
| 60 | + | ||
| 54 | # offline tts | 61 | # offline tts |
| 55 | 62 | ||
| 56 | curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 | 63 | curl -LS -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 |
| @@ -14,6 +14,37 @@ echo "PATH: $PATH" | @@ -14,6 +14,37 @@ echo "PATH: $PATH" | ||
| 14 | which $EXE | 14 | which $EXE |
| 15 | 15 | ||
| 16 | log "------------------------------------------------------------" | 16 | log "------------------------------------------------------------" |
| 17 | +log "Run streaming Zipformer2 CTC " | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | + | ||
| 20 | +url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 21 | +repo=$(basename -s .tar.bz2 $url) | ||
| 22 | +curl -SL -O $url | ||
| 23 | +tar xvf $repo.tar.bz2 | ||
| 24 | +rm $repo.tar.bz2 | ||
| 25 | + | ||
| 26 | +log "test fp32" | ||
| 27 | + | ||
| 28 | +time $EXE \ | ||
| 29 | + --debug=1 \ | ||
| 30 | + --zipformer2-ctc-model=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ | ||
| 31 | + --tokens=$repo/tokens.txt \ | ||
| 32 | + $repo/test_wavs/DEV_T0000000000.wav \ | ||
| 33 | + $repo/test_wavs/DEV_T0000000001.wav \ | ||
| 34 | + $repo/test_wavs/DEV_T0000000002.wav | ||
| 35 | + | ||
| 36 | +log "test int8" | ||
| 37 | + | ||
| 38 | +time $EXE \ | ||
| 39 | + --debug=1 \ | ||
| 40 | + --zipformer2-ctc-model=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ | ||
| 41 | + --tokens=$repo/tokens.txt \ | ||
| 42 | + $repo/test_wavs/DEV_T0000000000.wav \ | ||
| 43 | + $repo/test_wavs/DEV_T0000000001.wav \ | ||
| 44 | + $repo/test_wavs/DEV_T0000000002.wav | ||
| 45 | + | ||
| 46 | + | ||
| 47 | +log "------------------------------------------------------------" | ||
| 17 | log "Run streaming Conformer CTC from WeNet" | 48 | log "Run streaming Conformer CTC from WeNet" |
| 18 | log "------------------------------------------------------------" | 49 | log "------------------------------------------------------------" |
| 19 | wenet_models=( | 50 | wenet_models=( |
| @@ -8,6 +8,27 @@ log() { | @@ -8,6 +8,27 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +mkdir -p /tmp/icefall-models | ||
| 12 | +dir=/tmp/icefall-models | ||
| 13 | + | ||
| 14 | +pushd $dir | ||
| 15 | +wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 16 | +tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 17 | +rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 18 | +popd | ||
| 19 | +repo=$dir/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 | ||
| 20 | + | ||
| 21 | +python3 ./python-api-examples/online-decode-files.py \ | ||
| 22 | + --tokens=$repo/tokens.txt \ | ||
| 23 | + --zipformer2-ctc=$repo/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ | ||
| 24 | + $repo/test_wavs/DEV_T0000000000.wav \ | ||
| 25 | + $repo/test_wavs/DEV_T0000000001.wav \ | ||
| 26 | + $repo/test_wavs/DEV_T0000000002.wav | ||
| 27 | + | ||
| 28 | +python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose | ||
| 29 | + | ||
| 30 | +rm -rf $dir/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 | ||
| 31 | + | ||
| 11 | wenet_models=( | 32 | wenet_models=( |
| 12 | sherpa-onnx-zh-wenet-aishell | 33 | sherpa-onnx-zh-wenet-aishell |
| 13 | sherpa-onnx-zh-wenet-aishell2 | 34 | sherpa-onnx-zh-wenet-aishell2 |
| @@ -17,8 +38,6 @@ sherpa-onnx-en-wenet-librispeech | @@ -17,8 +38,6 @@ sherpa-onnx-en-wenet-librispeech | ||
| 17 | sherpa-onnx-en-wenet-gigaspeech | 38 | sherpa-onnx-en-wenet-gigaspeech |
| 18 | ) | 39 | ) |
| 19 | 40 | ||
| 20 | -mkdir -p /tmp/icefall-models | ||
| 21 | -dir=/tmp/icefall-models | ||
| 22 | 41 | ||
| 23 | for name in ${wenet_models[@]}; do | 42 | for name in ${wenet_models[@]}; do |
| 24 | repo_url=https://huggingface.co/csukuangfj/$name | 43 | repo_url=https://huggingface.co/csukuangfj/$name |
| @@ -22,6 +22,9 @@ cat /Users/fangjun/Desktop/Obama.srt | @@ -22,6 +22,9 @@ cat /Users/fangjun/Desktop/Obama.srt | ||
| 22 | ls -lh | 22 | ls -lh |
| 23 | 23 | ||
| 24 | ./run-decode-file.sh | 24 | ./run-decode-file.sh |
| 25 | +rm decode-file | ||
| 26 | +sed -i.bak '20d' ./decode-file.swift | ||
| 27 | +./run-decode-file.sh | ||
| 25 | 28 | ||
| 26 | ./run-decode-file-non-streaming.sh | 29 | ./run-decode-file-non-streaming.sh |
| 27 | 30 |
| @@ -22,7 +22,7 @@ jobs: | @@ -22,7 +22,7 @@ jobs: | ||
| 22 | - uses: actions/checkout@v4 | 22 | - uses: actions/checkout@v4 |
| 23 | 23 | ||
| 24 | - name: Setup Python ${{ matrix.python-version }} | 24 | - name: Setup Python ${{ matrix.python-version }} |
| 25 | - uses: actions/setup-python@v2 | 25 | + uses: actions/setup-python@v4 |
| 26 | with: | 26 | with: |
| 27 | python-version: ${{ matrix.python-version }} | 27 | python-version: ${{ matrix.python-version }} |
| 28 | 28 |
| @@ -22,7 +22,7 @@ jobs: | @@ -22,7 +22,7 @@ jobs: | ||
| 22 | - uses: actions/checkout@v4 | 22 | - uses: actions/checkout@v4 |
| 23 | 23 | ||
| 24 | - name: Setup Python ${{ matrix.python-version }} | 24 | - name: Setup Python ${{ matrix.python-version }} |
| 25 | - uses: actions/setup-python@v2 | 25 | + uses: actions/setup-python@v4 |
| 26 | with: | 26 | with: |
| 27 | python-version: ${{ matrix.python-version }} | 27 | python-version: ${{ matrix.python-version }} |
| 28 | 28 |
| @@ -24,7 +24,7 @@ jobs: | @@ -24,7 +24,7 @@ jobs: | ||
| 24 | - uses: actions/checkout@v4 | 24 | - uses: actions/checkout@v4 |
| 25 | 25 | ||
| 26 | - name: Setup Python ${{ matrix.python-version }} | 26 | - name: Setup Python ${{ matrix.python-version }} |
| 27 | - uses: actions/setup-python@v2 | 27 | + uses: actions/setup-python@v4 |
| 28 | with: | 28 | with: |
| 29 | python-version: ${{ matrix.python-version }} | 29 | python-version: ${{ matrix.python-version }} |
| 30 | 30 |
| @@ -107,23 +107,23 @@ jobs: | @@ -107,23 +107,23 @@ jobs: | ||
| 107 | name: release-static | 107 | name: release-static |
| 108 | path: build/bin/* | 108 | path: build/bin/* |
| 109 | 109 | ||
| 110 | - - name: Test offline Whisper | 110 | + - name: Test online CTC |
| 111 | shell: bash | 111 | shell: bash |
| 112 | run: | | 112 | run: | |
| 113 | export PATH=$PWD/build/bin:$PATH | 113 | export PATH=$PWD/build/bin:$PATH |
| 114 | - export EXE=sherpa-onnx-offline | ||
| 115 | - | ||
| 116 | - readelf -d build/bin/sherpa-onnx-offline | 114 | + export EXE=sherpa-onnx |
| 117 | 115 | ||
| 118 | - .github/scripts/test-offline-whisper.sh | 116 | + .github/scripts/test-online-ctc.sh |
| 119 | 117 | ||
| 120 | - - name: Test online CTC | 118 | + - name: Test offline Whisper |
| 121 | shell: bash | 119 | shell: bash |
| 122 | run: | | 120 | run: | |
| 123 | export PATH=$PWD/build/bin:$PATH | 121 | export PATH=$PWD/build/bin:$PATH |
| 124 | - export EXE=sherpa-onnx | 122 | + export EXE=sherpa-onnx-offline |
| 125 | 123 | ||
| 126 | - .github/scripts/test-online-ctc.sh | 124 | + readelf -d build/bin/sherpa-onnx-offline |
| 125 | + | ||
| 126 | + .github/scripts/test-offline-whisper.sh | ||
| 127 | 127 | ||
| 128 | - name: Test offline CTC | 128 | - name: Test offline CTC |
| 129 | shell: bash | 129 | shell: bash |
| @@ -25,7 +25,7 @@ jobs: | @@ -25,7 +25,7 @@ jobs: | ||
| 25 | fetch-depth: 0 | 25 | fetch-depth: 0 |
| 26 | 26 | ||
| 27 | - name: Setup Python ${{ matrix.python-version }} | 27 | - name: Setup Python ${{ matrix.python-version }} |
| 28 | - uses: actions/setup-python@v2 | 28 | + uses: actions/setup-python@v4 |
| 29 | with: | 29 | with: |
| 30 | python-version: ${{ matrix.python-version }} | 30 | python-version: ${{ matrix.python-version }} |
| 31 | 31 |
| @@ -55,7 +55,7 @@ jobs: | @@ -55,7 +55,7 @@ jobs: | ||
| 55 | key: ${{ matrix.os }}-python-${{ matrix.python-version }} | 55 | key: ${{ matrix.os }}-python-${{ matrix.python-version }} |
| 56 | 56 | ||
| 57 | - name: Setup Python | 57 | - name: Setup Python |
| 58 | - uses: actions/setup-python@v2 | 58 | + uses: actions/setup-python@v4 |
| 59 | with: | 59 | with: |
| 60 | python-version: ${{ matrix.python-version }} | 60 | python-version: ${{ matrix.python-version }} |
| 61 | 61 |
| @@ -49,7 +49,7 @@ jobs: | @@ -49,7 +49,7 @@ jobs: | ||
| 49 | fetch-depth: 0 | 49 | fetch-depth: 0 |
| 50 | 50 | ||
| 51 | - name: Setup Python ${{ matrix.python-version }} | 51 | - name: Setup Python ${{ matrix.python-version }} |
| 52 | - uses: actions/setup-python@v1 | 52 | + uses: actions/setup-python@v4 |
| 53 | with: | 53 | with: |
| 54 | python-version: ${{ matrix.python-version }} | 54 | python-version: ${{ matrix.python-version }} |
| 55 | 55 |
| @@ -29,7 +29,7 @@ jobs: | @@ -29,7 +29,7 @@ jobs: | ||
| 29 | fetch-depth: 0 | 29 | fetch-depth: 0 |
| 30 | 30 | ||
| 31 | - name: Setup Python ${{ matrix.python-version }} | 31 | - name: Setup Python ${{ matrix.python-version }} |
| 32 | - uses: actions/setup-python@v2 | 32 | + uses: actions/setup-python@v4 |
| 33 | with: | 33 | with: |
| 34 | python-version: ${{ matrix.python-version }} | 34 | python-version: ${{ matrix.python-version }} |
| 35 | 35 |
| @@ -61,7 +61,7 @@ jobs: | @@ -61,7 +61,7 @@ jobs: | ||
| 61 | strategy: | 61 | strategy: |
| 62 | fail-fast: false | 62 | fail-fast: false |
| 63 | matrix: | 63 | matrix: |
| 64 | - os: [ubuntu-latest, macos-latest] | 64 | + os: [ubuntu-latest, macos-latest] #, windows-latest] |
| 65 | python-version: ["3.8"] | 65 | python-version: ["3.8"] |
| 66 | 66 | ||
| 67 | steps: | 67 | steps: |
| @@ -70,7 +70,7 @@ jobs: | @@ -70,7 +70,7 @@ jobs: | ||
| 70 | fetch-depth: 0 | 70 | fetch-depth: 0 |
| 71 | 71 | ||
| 72 | - name: Setup Python ${{ matrix.python-version }} | 72 | - name: Setup Python ${{ matrix.python-version }} |
| 73 | - uses: actions/setup-python@v2 | 73 | + uses: actions/setup-python@v4 |
| 74 | with: | 74 | with: |
| 75 | python-version: ${{ matrix.python-version }} | 75 | python-version: ${{ matrix.python-version }} |
| 76 | 76 | ||
| @@ -143,6 +143,7 @@ jobs: | @@ -143,6 +143,7 @@ jobs: | ||
| 143 | cd dotnet-examples/ | 143 | cd dotnet-examples/ |
| 144 | 144 | ||
| 145 | cd online-decode-files | 145 | cd online-decode-files |
| 146 | + ./run-zipformer2-ctc.sh | ||
| 146 | ./run-transducer.sh | 147 | ./run-transducer.sh |
| 147 | ./run-paraformer.sh | 148 | ./run-paraformer.sh |
| 148 | 149 |
| @@ -53,7 +53,7 @@ jobs: | @@ -53,7 +53,7 @@ jobs: | ||
| 53 | mkdir build | 53 | mkdir build |
| 54 | cd build | 54 | cd build |
| 55 | cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_SHARED_LIBS=ON -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF .. | 55 | cmake -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DBUILD_SHARED_LIBS=ON -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF .. |
| 56 | - make -j | 56 | + make -j1 |
| 57 | cp -v _deps/onnxruntime-src/lib/libonnxruntime*dylib ./lib/ | 57 | cp -v _deps/onnxruntime-src/lib/libonnxruntime*dylib ./lib/ |
| 58 | 58 | ||
| 59 | cd ../scripts/go/_internal/ | 59 | cd ../scripts/go/_internal/ |
| @@ -153,6 +153,14 @@ jobs: | @@ -153,6 +153,14 @@ jobs: | ||
| 153 | 153 | ||
| 154 | git lfs install | 154 | git lfs install |
| 155 | 155 | ||
| 156 | + echo "Test zipformer2 CTC" | ||
| 157 | + wget -qq https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 158 | + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 159 | + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 160 | + | ||
| 161 | + ./run-zipformer2-ctc.sh | ||
| 162 | + rm -rf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 | ||
| 163 | + | ||
| 156 | echo "Test transducer" | 164 | echo "Test transducer" |
| 157 | git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 | 165 | git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 |
| 158 | ./run-transducer.sh | 166 | ./run-transducer.sh |
| @@ -34,7 +34,7 @@ jobs: | @@ -34,7 +34,7 @@ jobs: | ||
| 34 | fetch-depth: 0 | 34 | fetch-depth: 0 |
| 35 | 35 | ||
| 36 | - name: Setup Python ${{ matrix.python-version }} | 36 | - name: Setup Python ${{ matrix.python-version }} |
| 37 | - uses: actions/setup-python@v2 | 37 | + uses: actions/setup-python@v4 |
| 38 | with: | 38 | with: |
| 39 | python-version: ${{ matrix.python-version }} | 39 | python-version: ${{ matrix.python-version }} |
| 40 | 40 |
| @@ -52,7 +52,7 @@ jobs: | @@ -52,7 +52,7 @@ jobs: | ||
| 52 | ls -lh install/lib | 52 | ls -lh install/lib |
| 53 | 53 | ||
| 54 | - name: Setup Python ${{ matrix.python-version }} | 54 | - name: Setup Python ${{ matrix.python-version }} |
| 55 | - uses: actions/setup-python@v2 | 55 | + uses: actions/setup-python@v4 |
| 56 | with: | 56 | with: |
| 57 | python-version: ${{ matrix.python-version }} | 57 | python-version: ${{ matrix.python-version }} |
| 58 | 58 |
| @@ -40,7 +40,7 @@ jobs: | @@ -40,7 +40,7 @@ jobs: | ||
| 40 | fetch-depth: 0 | 40 | fetch-depth: 0 |
| 41 | 41 | ||
| 42 | - name: Setup Python ${{ matrix.python-version }} | 42 | - name: Setup Python ${{ matrix.python-version }} |
| 43 | - uses: actions/setup-python@v2 | 43 | + uses: actions/setup-python@v4 |
| 44 | with: | 44 | with: |
| 45 | python-version: ${{ matrix.python-version }} | 45 | python-version: ${{ matrix.python-version }} |
| 46 | 46 |
| @@ -38,7 +38,7 @@ jobs: | @@ -38,7 +38,7 @@ jobs: | ||
| 38 | key: ${{ matrix.os }}-python-${{ matrix.python-version }} | 38 | key: ${{ matrix.os }}-python-${{ matrix.python-version }} |
| 39 | 39 | ||
| 40 | - name: Setup Python ${{ matrix.python-version }} | 40 | - name: Setup Python ${{ matrix.python-version }} |
| 41 | - uses: actions/setup-python@v2 | 41 | + uses: actions/setup-python@v4 |
| 42 | with: | 42 | with: |
| 43 | python-version: ${{ matrix.python-version }} | 43 | python-version: ${{ matrix.python-version }} |
| 44 | 44 |
| @@ -25,7 +25,7 @@ jobs: | @@ -25,7 +25,7 @@ jobs: | ||
| 25 | matrix: | 25 | matrix: |
| 26 | os: [ubuntu-latest, windows-latest, macos-latest] | 26 | os: [ubuntu-latest, windows-latest, macos-latest] |
| 27 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] | 27 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] |
| 28 | - model_type: ["transducer", "paraformer"] | 28 | + model_type: ["transducer", "paraformer", "zipformer2-ctc"] |
| 29 | 29 | ||
| 30 | steps: | 30 | steps: |
| 31 | - uses: actions/checkout@v4 | 31 | - uses: actions/checkout@v4 |
| @@ -38,7 +38,7 @@ jobs: | @@ -38,7 +38,7 @@ jobs: | ||
| 38 | key: ${{ matrix.os }}-python-${{ matrix.python-version }} | 38 | key: ${{ matrix.os }}-python-${{ matrix.python-version }} |
| 39 | 39 | ||
| 40 | - name: Setup Python ${{ matrix.python-version }} | 40 | - name: Setup Python ${{ matrix.python-version }} |
| 41 | - uses: actions/setup-python@v2 | 41 | + uses: actions/setup-python@v4 |
| 42 | with: | 42 | with: |
| 43 | python-version: ${{ matrix.python-version }} | 43 | python-version: ${{ matrix.python-version }} |
| 44 | 44 | ||
| @@ -57,6 +57,26 @@ jobs: | @@ -57,6 +57,26 @@ jobs: | ||
| 57 | python3 -m pip install --no-deps --verbose . | 57 | python3 -m pip install --no-deps --verbose . |
| 58 | python3 -m pip install websockets | 58 | python3 -m pip install websockets |
| 59 | 59 | ||
| 60 | + - name: Start server for zipformer2 CTC models | ||
| 61 | + if: matrix.model_type == 'zipformer2-ctc' | ||
| 62 | + shell: bash | ||
| 63 | + run: | | ||
| 64 | + curl -O -L https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 65 | + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 66 | + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 67 | + | ||
| 68 | + python3 ./python-api-examples/streaming_server.py \ | ||
| 69 | + --zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ | ||
| 70 | + --tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt & | ||
| 71 | + echo "sleep 10 seconds to wait the server start" | ||
| 72 | + sleep 10 | ||
| 73 | + | ||
| 74 | + - name: Start client for zipformer2 CTC models | ||
| 75 | + if: matrix.model_type == 'zipformer2-ctc' | ||
| 76 | + shell: bash | ||
| 77 | + run: | | ||
| 78 | + python3 ./python-api-examples/online-websocket-client-decode-file.py \ | ||
| 79 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav | ||
| 60 | 80 | ||
| 61 | - name: Start server for transducer models | 81 | - name: Start server for transducer models |
| 62 | if: matrix.model_type == 'transducer' | 82 | if: matrix.model_type == 'transducer' |
| @@ -26,9 +26,14 @@ data class OnlineParaformerModelConfig( | @@ -26,9 +26,14 @@ data class OnlineParaformerModelConfig( | ||
| 26 | var decoder: String = "", | 26 | var decoder: String = "", |
| 27 | ) | 27 | ) |
| 28 | 28 | ||
| 29 | +data class OnlineZipformer2CtcModelConfig( | ||
| 30 | + var model: String = "", | ||
| 31 | +) | ||
| 32 | + | ||
| 29 | data class OnlineModelConfig( | 33 | data class OnlineModelConfig( |
| 30 | var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(), | 34 | var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(), |
| 31 | var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(), | 35 | var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(), |
| 36 | + var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(), | ||
| 32 | var tokens: String, | 37 | var tokens: String, |
| 33 | var numThreads: Int = 1, | 38 | var numThreads: Int = 1, |
| 34 | var debug: Boolean = false, | 39 | var debug: Boolean = false, |
| @@ -38,6 +38,9 @@ class OnlineDecodeFiles | @@ -38,6 +38,9 @@ class OnlineDecodeFiles | ||
| 38 | [Option("paraformer-decoder", Required = false, HelpText = "Path to paraformer decoder.onnx")] | 38 | [Option("paraformer-decoder", Required = false, HelpText = "Path to paraformer decoder.onnx")] |
| 39 | public string ParaformerDecoder { get; set; } | 39 | public string ParaformerDecoder { get; set; } |
| 40 | 40 | ||
| 41 | + [Option("zipformer2-ctc", Required = false, HelpText = "Path to zipformer2 CTC onnx model")] | ||
| 42 | + public string Zipformer2Ctc { get; set; } | ||
| 43 | + | ||
| 41 | [Option("num-threads", Required = false, Default = 1, HelpText = "Number of threads for computation")] | 44 | [Option("num-threads", Required = false, Default = 1, HelpText = "Number of threads for computation")] |
| 42 | public int NumThreads { get; set; } | 45 | public int NumThreads { get; set; } |
| 43 | 46 | ||
| @@ -107,7 +110,19 @@ dotnet run \ | @@ -107,7 +110,19 @@ dotnet run \ | ||
| 107 | --files ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \ | 110 | --files ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav \ |
| 108 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav | 111 | ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav |
| 109 | 112 | ||
| 110 | -(2) Streaming Paraformer models | 113 | +(2) Streaming Zipformer2 Ctc models |
| 114 | + | ||
| 115 | +dotnet run -c Release \ | ||
| 116 | + --tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ | ||
| 117 | + --zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ | ||
| 118 | + --files ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \ | ||
| 119 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \ | ||
| 120 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav \ | ||
| 121 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000113.wav \ | ||
| 122 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000219.wav \ | ||
| 123 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000351.wav | ||
| 124 | + | ||
| 125 | +(3) Streaming Paraformer models | ||
| 111 | dotnet run \ | 126 | dotnet run \ |
| 112 | --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ | 127 | --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ |
| 113 | --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ | 128 | --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ |
| @@ -121,6 +136,7 @@ dotnet run \ | @@ -121,6 +136,7 @@ dotnet run \ | ||
| 121 | Please refer to | 136 | Please refer to |
| 122 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html | 137 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html |
| 123 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html | 138 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html |
| 139 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html | ||
| 124 | to download pre-trained streaming models. | 140 | to download pre-trained streaming models. |
| 125 | "; | 141 | "; |
| 126 | 142 | ||
| @@ -150,6 +166,8 @@ to download pre-trained streaming models. | @@ -150,6 +166,8 @@ to download pre-trained streaming models. | ||
| 150 | config.ModelConfig.Paraformer.Encoder = options.ParaformerEncoder; | 166 | config.ModelConfig.Paraformer.Encoder = options.ParaformerEncoder; |
| 151 | config.ModelConfig.Paraformer.Decoder = options.ParaformerDecoder; | 167 | config.ModelConfig.Paraformer.Decoder = options.ParaformerDecoder; |
| 152 | 168 | ||
| 169 | + config.ModelConfig.Zipformer2Ctc.Model = options.Zipformer2Ctc; | ||
| 170 | + | ||
| 153 | config.ModelConfig.Tokens = options.Tokens; | 171 | config.ModelConfig.Tokens = options.Tokens; |
| 154 | config.ModelConfig.Provider = options.Provider; | 172 | config.ModelConfig.Provider = options.Provider; |
| 155 | config.ModelConfig.NumThreads = options.NumThreads; | 173 | config.ModelConfig.NumThreads = options.NumThreads; |
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +# Please refer to | ||
| 4 | +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese | ||
| 5 | +# to download the model files | ||
| 6 | + | ||
| 7 | +if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then | ||
| 8 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 9 | + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 10 | + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 11 | +fi | ||
| 12 | + | ||
| 13 | +dotnet run -c Release \ | ||
| 14 | + --tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ | ||
| 15 | + --zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ | ||
| 16 | + --files ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \ | ||
| 17 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \ | ||
| 18 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav \ | ||
| 19 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000113.wav \ | ||
| 20 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000219.wav \ | ||
| 21 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/TEST_MEETING_T0000000351.wav |
go-api-examples/.gitignore
0 → 100644
| 1 | +!*.sh |
| @@ -22,6 +22,7 @@ func main() { | @@ -22,6 +22,7 @@ func main() { | ||
| 22 | flag.StringVar(&config.ModelConfig.Transducer.Joiner, "joiner", "", "Path to the transducer joiner model") | 22 | flag.StringVar(&config.ModelConfig.Transducer.Joiner, "joiner", "", "Path to the transducer joiner model") |
| 23 | flag.StringVar(&config.ModelConfig.Paraformer.Encoder, "paraformer-encoder", "", "Path to the paraformer encoder model") | 23 | flag.StringVar(&config.ModelConfig.Paraformer.Encoder, "paraformer-encoder", "", "Path to the paraformer encoder model") |
| 24 | flag.StringVar(&config.ModelConfig.Paraformer.Decoder, "paraformer-decoder", "", "Path to the paraformer decoder model") | 24 | flag.StringVar(&config.ModelConfig.Paraformer.Decoder, "paraformer-decoder", "", "Path to the paraformer decoder model") |
| 25 | + flag.StringVar(&config.ModelConfig.Zipformer2Ctc.Model, "zipformer2-ctc", "", "Path to the zipformer2 CTC model") | ||
| 25 | flag.StringVar(&config.ModelConfig.Tokens, "tokens", "", "Path to the tokens file") | 26 | flag.StringVar(&config.ModelConfig.Tokens, "tokens", "", "Path to the tokens file") |
| 26 | flag.IntVar(&config.ModelConfig.NumThreads, "num-threads", 1, "Number of threads for computing") | 27 | flag.IntVar(&config.ModelConfig.NumThreads, "num-threads", 1, "Number of threads for computing") |
| 27 | flag.IntVar(&config.ModelConfig.Debug, "debug", 0, "Whether to show debug message") | 28 | flag.IntVar(&config.ModelConfig.Debug, "debug", 0, "Whether to show debug message") |
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +# Please refer to | ||
| 4 | +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese | ||
| 5 | +# to download the model | ||
| 6 | +# before you run this script. | ||
| 7 | +# | ||
| 8 | +# You can switch to a different online model if you need | ||
| 9 | + | ||
| 10 | +./streaming-decode-files \ | ||
| 11 | + --zipformer2-ctc ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ | ||
| 12 | + --tokens ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ | ||
| 13 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav |
| @@ -8,7 +8,8 @@ fun callback(samples: FloatArray): Unit { | @@ -8,7 +8,8 @@ fun callback(samples: FloatArray): Unit { | ||
| 8 | 8 | ||
| 9 | fun main() { | 9 | fun main() { |
| 10 | testTts() | 10 | testTts() |
| 11 | - testAsr() | 11 | + testAsr("transducer") |
| 12 | + testAsr("zipformer2-ctc") | ||
| 12 | } | 13 | } |
| 13 | 14 | ||
| 14 | fun testTts() { | 15 | fun testTts() { |
| @@ -30,25 +31,43 @@ fun testTts() { | @@ -30,25 +31,43 @@ fun testTts() { | ||
| 30 | audio.save(filename="test-en.wav") | 31 | audio.save(filename="test-en.wav") |
| 31 | } | 32 | } |
| 32 | 33 | ||
| 33 | -fun testAsr() { | 34 | +fun testAsr(type: String) { |
| 34 | var featConfig = FeatureConfig( | 35 | var featConfig = FeatureConfig( |
| 35 | sampleRate = 16000, | 36 | sampleRate = 16000, |
| 36 | featureDim = 80, | 37 | featureDim = 80, |
| 37 | ) | 38 | ) |
| 38 | 39 | ||
| 39 | - // please refer to | ||
| 40 | - // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 41 | - // to dowload pre-trained models | ||
| 42 | - var modelConfig = OnlineModelConfig( | ||
| 43 | - transducer = OnlineTransducerModelConfig( | ||
| 44 | - encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", | ||
| 45 | - decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", | ||
| 46 | - joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", | ||
| 47 | - ), | ||
| 48 | - tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", | ||
| 49 | - numThreads = 1, | ||
| 50 | - debug = false, | ||
| 51 | - ) | 40 | + var waveFilename: String |
| 41 | + var modelConfig: OnlineModelConfig = when (type) { | ||
| 42 | + "transducer" -> { | ||
| 43 | + waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav" | ||
| 44 | + // please refer to | ||
| 45 | + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 46 | + // to dowload pre-trained models | ||
| 47 | + OnlineModelConfig( | ||
| 48 | + transducer = OnlineTransducerModelConfig( | ||
| 49 | + encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", | ||
| 50 | + decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", | ||
| 51 | + joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", | ||
| 52 | + ), | ||
| 53 | + tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", | ||
| 54 | + numThreads = 1, | ||
| 55 | + debug = false, | ||
| 56 | + ) | ||
| 57 | + } | ||
| 58 | + "zipformer2-ctc" -> { | ||
| 59 | + waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav" | ||
| 60 | + OnlineModelConfig( | ||
| 61 | + zipformer2Ctc = OnlineZipformer2CtcModelConfig( | ||
| 62 | + model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx", | ||
| 63 | + ), | ||
| 64 | + tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt", | ||
| 65 | + numThreads = 1, | ||
| 66 | + debug = false, | ||
| 67 | + ) | ||
| 68 | + } | ||
| 69 | + else -> throw IllegalArgumentException(type) | ||
| 70 | + } | ||
| 52 | 71 | ||
| 53 | var endpointConfig = EndpointConfig() | 72 | var endpointConfig = EndpointConfig() |
| 54 | 73 | ||
| @@ -69,7 +88,7 @@ fun testAsr() { | @@ -69,7 +88,7 @@ fun testAsr() { | ||
| 69 | ) | 88 | ) |
| 70 | 89 | ||
| 71 | var objArray = WaveReader.readWaveFromFile( | 90 | var objArray = WaveReader.readWaveFromFile( |
| 72 | - filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav", | 91 | + filename = waveFilename, |
| 73 | ) | 92 | ) |
| 74 | var samples: FloatArray = objArray[0] as FloatArray | 93 | var samples: FloatArray = objArray[0] as FloatArray |
| 75 | var sampleRate: Int = objArray[1] as Int | 94 | var sampleRate: Int = objArray[1] as Int |
| @@ -34,6 +34,12 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then | @@ -34,6 +34,12 @@ if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then | ||
| 34 | git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 | 34 | git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 |
| 35 | fi | 35 | fi |
| 36 | 36 | ||
| 37 | +if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then | ||
| 38 | + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 39 | + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 40 | + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 41 | +fi | ||
| 42 | + | ||
| 37 | if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then | 43 | if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then |
| 38 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 | 44 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 |
| 39 | tar xf vits-piper-en_US-amy-low.tar.bz2 | 45 | tar xf vits-piper-en_US-amy-low.tar.bz2 |
| @@ -85,7 +85,7 @@ npm install wav naudiodon2 | @@ -85,7 +85,7 @@ npm install wav naudiodon2 | ||
| 85 | how to decode a file with a NeMo CTC model. In the code we use | 85 | how to decode a file with a NeMo CTC model. In the code we use |
| 86 | [stt_en_conformer_ctc_small](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/english.html#stt-en-conformer-ctc-small). | 86 | [stt_en_conformer_ctc_small](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/english.html#stt-en-conformer-ctc-small). |
| 87 | 87 | ||
| 88 | -You can use the following command run it: | 88 | +You can use the following command to run it: |
| 89 | 89 | ||
| 90 | ```bash | 90 | ```bash |
| 91 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-conformer-small.tar.bz2 | 91 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-conformer-small.tar.bz2 |
| @@ -99,7 +99,7 @@ node ./test-offline-nemo-ctc.js | @@ -99,7 +99,7 @@ node ./test-offline-nemo-ctc.js | ||
| 99 | how to decode a file with a non-streaming Paraformer model. In the code we use | 99 | how to decode a file with a non-streaming Paraformer model. In the code we use |
| 100 | [sherpa-onnx-paraformer-zh-2023-03-28](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese). | 100 | [sherpa-onnx-paraformer-zh-2023-03-28](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese). |
| 101 | 101 | ||
| 102 | -You can use the following command run it: | 102 | +You can use the following command to run it: |
| 103 | 103 | ||
| 104 | ```bash | 104 | ```bash |
| 105 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 | 105 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2 |
| @@ -113,7 +113,7 @@ node ./test-offline-paraformer.js | @@ -113,7 +113,7 @@ node ./test-offline-paraformer.js | ||
| 113 | how to decode a file with a non-streaming transducer model. In the code we use | 113 | how to decode a file with a non-streaming transducer model. In the code we use |
| 114 | [sherpa-onnx-zipformer-en-2023-06-26](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-zipformer-en-2023-06-26-english). | 114 | [sherpa-onnx-zipformer-en-2023-06-26](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-zipformer-en-2023-06-26-english). |
| 115 | 115 | ||
| 116 | -You can use the following command run it: | 116 | +You can use the following command to run it: |
| 117 | 117 | ||
| 118 | ```bash | 118 | ```bash |
| 119 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-en-2023-06-26.tar.bz2 | 119 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-en-2023-06-26.tar.bz2 |
| @@ -126,7 +126,7 @@ node ./test-offline-transducer.js | @@ -126,7 +126,7 @@ node ./test-offline-transducer.js | ||
| 126 | how to decode a file with a Whisper model. In the code we use | 126 | how to decode a file with a Whisper model. In the code we use |
| 127 | [sherpa-onnx-whisper-tiny.en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html). | 127 | [sherpa-onnx-whisper-tiny.en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html). |
| 128 | 128 | ||
| 129 | -You can use the following command run it: | 129 | +You can use the following command to run it: |
| 130 | 130 | ||
| 131 | ```bash | 131 | ```bash |
| 132 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 | 132 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 |
| @@ -140,7 +140,7 @@ demonstrates how to do real-time speech recognition from microphone | @@ -140,7 +140,7 @@ demonstrates how to do real-time speech recognition from microphone | ||
| 140 | with a streaming Paraformer model. In the code we use | 140 | with a streaming Paraformer model. In the code we use |
| 141 | [sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english). | 141 | [sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english). |
| 142 | 142 | ||
| 143 | -You can use the following command run it: | 143 | +You can use the following command to run it: |
| 144 | 144 | ||
| 145 | ```bash | 145 | ```bash |
| 146 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | 146 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 |
| @@ -153,7 +153,7 @@ node ./test-online-paraformer-microphone.js | @@ -153,7 +153,7 @@ node ./test-online-paraformer-microphone.js | ||
| 153 | how to decode a file using a streaming Paraformer model. In the code we use | 153 | how to decode a file using a streaming Paraformer model. In the code we use |
| 154 | [sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english). | 154 | [sherpa-onnx-streaming-paraformer-bilingual-zh-en](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-streaming-paraformer-bilingual-zh-en-chinese-english). |
| 155 | 155 | ||
| 156 | -You can use the following command run it: | 156 | +You can use the following command to run it: |
| 157 | 157 | ||
| 158 | ```bash | 158 | ```bash |
| 159 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | 159 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 |
| @@ -167,7 +167,7 @@ demonstrates how to do real-time speech recognition with microphone using a stre | @@ -167,7 +167,7 @@ demonstrates how to do real-time speech recognition with microphone using a stre | ||
| 167 | we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english). | 167 | we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english). |
| 168 | 168 | ||
| 169 | 169 | ||
| 170 | -You can use the following command run it: | 170 | +You can use the following command to run it: |
| 171 | 171 | ||
| 172 | ```bash | 172 | ```bash |
| 173 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 | 173 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 |
| @@ -180,7 +180,7 @@ node ./test-online-transducer-microphone.js | @@ -180,7 +180,7 @@ node ./test-online-transducer-microphone.js | ||
| 180 | how to decode a file using a streaming transducer model. In the code | 180 | how to decode a file using a streaming transducer model. In the code |
| 181 | we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english). | 181 | we use [sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english). |
| 182 | 182 | ||
| 183 | -You can use the following command run it: | 183 | +You can use the following command to run it: |
| 184 | 184 | ||
| 185 | ```bash | 185 | ```bash |
| 186 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 | 186 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 |
| @@ -188,13 +188,26 @@ tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 | @@ -188,13 +188,26 @@ tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 | ||
| 188 | node ./test-online-transducer.js | 188 | node ./test-online-transducer.js |
| 189 | ``` | 189 | ``` |
| 190 | 190 | ||
| 191 | +## ./test-online-zipformer2-ctc.js | ||
| 192 | +[./test-online-zipformer2-ctc.js](./test-online-zipformer2-ctc.js) demonstrates | ||
| 193 | +how to decode a file using a streaming zipformer2 CTC model. In the code | ||
| 194 | +we use [sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html#sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13-chinese). | ||
| 195 | + | ||
| 196 | +You can use the following command to run it: | ||
| 197 | + | ||
| 198 | +```bash | ||
| 199 | +wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 200 | +tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 201 | +node ./test-online-zipformer2-ctc.js | ||
| 202 | +``` | ||
| 203 | + | ||
| 191 | ## ./test-vad-microphone-offline-paraformer.js | 204 | ## ./test-vad-microphone-offline-paraformer.js |
| 192 | 205 | ||
| 193 | [./test-vad-microphone-offline-paraformer.js](./test-vad-microphone-offline-paraformer.js) | 206 | [./test-vad-microphone-offline-paraformer.js](./test-vad-microphone-offline-paraformer.js) |
| 194 | demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) | 207 | demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) |
| 195 | with non-streaming Paraformer for speech recognition from microphone. | 208 | with non-streaming Paraformer for speech recognition from microphone. |
| 196 | 209 | ||
| 197 | -You can use the following command run it: | 210 | +You can use the following command to run it: |
| 198 | 211 | ||
| 199 | ```bash | 212 | ```bash |
| 200 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx | 213 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx |
| @@ -209,7 +222,7 @@ node ./test-vad-microphone-offline-paraformer.js | @@ -209,7 +222,7 @@ node ./test-vad-microphone-offline-paraformer.js | ||
| 209 | demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) | 222 | demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) |
| 210 | with a non-streaming transducer model for speech recognition from microphone. | 223 | with a non-streaming transducer model for speech recognition from microphone. |
| 211 | 224 | ||
| 212 | -You can use the following command run it: | 225 | +You can use the following command to run it: |
| 213 | 226 | ||
| 214 | ```bash | 227 | ```bash |
| 215 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx | 228 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx |
| @@ -224,7 +237,7 @@ node ./test-vad-microphone-offline-transducer.js | @@ -224,7 +237,7 @@ node ./test-vad-microphone-offline-transducer.js | ||
| 224 | demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) | 237 | demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad) |
| 225 | with whisper for speech recognition from microphone. | 238 | with whisper for speech recognition from microphone. |
| 226 | 239 | ||
| 227 | -You can use the following command run it: | 240 | +You can use the following command to run it: |
| 228 | 241 | ||
| 229 | ```bash | 242 | ```bash |
| 230 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx | 243 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx |
| @@ -238,7 +251,7 @@ node ./test-vad-microphone-offline-whisper.js | @@ -238,7 +251,7 @@ node ./test-vad-microphone-offline-whisper.js | ||
| 238 | [./test-vad-microphone.js](./test-vad-microphone.js) | 251 | [./test-vad-microphone.js](./test-vad-microphone.js) |
| 239 | demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad). | 252 | demonstrates how to use [silero-vad](https://github.com/snakers4/silero-vad). |
| 240 | 253 | ||
| 241 | -You can use the following command run it: | 254 | +You can use the following command to run it: |
| 242 | 255 | ||
| 243 | ```bash | 256 | ```bash |
| 244 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx | 257 | wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx |
| 1 | +// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 2 | +// | ||
| 3 | +const fs = require('fs'); | ||
| 4 | +const {Readable} = require('stream'); | ||
| 5 | +const wav = require('wav'); | ||
| 6 | + | ||
| 7 | +const sherpa_onnx = require('sherpa-onnx'); | ||
| 8 | + | ||
| 9 | +function createRecognizer() { | ||
| 10 | + const featConfig = new sherpa_onnx.FeatureConfig(); | ||
| 11 | + featConfig.sampleRate = 16000; | ||
| 12 | + featConfig.featureDim = 80; | ||
| 13 | + | ||
| 14 | + // test online recognizer | ||
| 15 | + const zipformer2Ctc = new sherpa_onnx.OnlineZipformer2CtcModelConfig(); | ||
| 16 | + zipformer2Ctc.model = | ||
| 17 | + './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx'; | ||
| 18 | + const tokens = | ||
| 19 | + './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt'; | ||
| 20 | + | ||
| 21 | + const modelConfig = new sherpa_onnx.OnlineModelConfig(); | ||
| 22 | + modelConfig.zipformer2Ctc = zipformer2Ctc; | ||
| 23 | + modelConfig.tokens = tokens; | ||
| 24 | + | ||
| 25 | + const recognizerConfig = new sherpa_onnx.OnlineRecognizerConfig(); | ||
| 26 | + recognizerConfig.featConfig = featConfig; | ||
| 27 | + recognizerConfig.modelConfig = modelConfig; | ||
| 28 | + recognizerConfig.decodingMethod = 'greedy_search'; | ||
| 29 | + | ||
| 30 | + recognizer = new sherpa_onnx.OnlineRecognizer(recognizerConfig); | ||
| 31 | + return recognizer; | ||
| 32 | +} | ||
| 33 | +recognizer = createRecognizer(); | ||
| 34 | +stream = recognizer.createStream(); | ||
| 35 | + | ||
| 36 | +const waveFilename = | ||
| 37 | + './sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav'; | ||
| 38 | + | ||
| 39 | +const reader = new wav.Reader(); | ||
| 40 | +const readable = new Readable().wrap(reader); | ||
| 41 | + | ||
| 42 | +function decode(samples) { | ||
| 43 | + stream.acceptWaveform(recognizer.config.featConfig.sampleRate, samples); | ||
| 44 | + | ||
| 45 | + while (recognizer.isReady(stream)) { | ||
| 46 | + recognizer.decode(stream); | ||
| 47 | + } | ||
| 48 | + const r = recognizer.getResult(stream); | ||
| 49 | + console.log(r.text); | ||
| 50 | +} | ||
| 51 | + | ||
| 52 | +reader.on('format', ({audioFormat, bitDepth, channels, sampleRate}) => { | ||
| 53 | + if (sampleRate != recognizer.config.featConfig.sampleRate) { | ||
| 54 | + throw new Error(`Only support sampleRate ${ | ||
| 55 | + recognizer.config.featConfig.sampleRate}. Given ${sampleRate}`); | ||
| 56 | + } | ||
| 57 | + | ||
| 58 | + if (audioFormat != 1) { | ||
| 59 | + throw new Error(`Only support PCM format. Given ${audioFormat}`); | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + if (channels != 1) { | ||
| 63 | + throw new Error(`Only a single channel. Given ${channels}`); | ||
| 64 | + } | ||
| 65 | + | ||
| 66 | + if (bitDepth != 16) { | ||
| 67 | + throw new Error(`Only support 16-bit samples. Given ${bitDepth}`); | ||
| 68 | + } | ||
| 69 | +}); | ||
| 70 | + | ||
| 71 | +fs.createReadStream(waveFilename, {'highWaterMark': 4096}) | ||
| 72 | + .pipe(reader) | ||
| 73 | + .on('finish', function(err) { | ||
| 74 | + // tail padding | ||
| 75 | + const floatSamples = | ||
| 76 | + new Float32Array(recognizer.config.featConfig.sampleRate * 0.5); | ||
| 77 | + decode(floatSamples); | ||
| 78 | + stream.free(); | ||
| 79 | + recognizer.free(); | ||
| 80 | + }); | ||
| 81 | + | ||
| 82 | +readable.on('readable', function() { | ||
| 83 | + let chunk; | ||
| 84 | + while ((chunk = readable.read()) != null) { | ||
| 85 | + const int16Samples = new Int16Array( | ||
| 86 | + chunk.buffer, chunk.byteOffset, | ||
| 87 | + chunk.length / Int16Array.BYTES_PER_ELEMENT); | ||
| 88 | + | ||
| 89 | + const floatSamples = new Float32Array(int16Samples.length); | ||
| 90 | + | ||
| 91 | + for (let i = 0; i < floatSamples.length; i++) { | ||
| 92 | + floatSamples[i] = int16Samples[i] / 32768.0; | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + decode(floatSamples); | ||
| 96 | + } | ||
| 97 | +}); |
| @@ -37,7 +37,20 @@ git lfs pull --include "*.onnx" | @@ -37,7 +37,20 @@ git lfs pull --include "*.onnx" | ||
| 37 | ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \ | 37 | ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \ |
| 38 | ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav | 38 | ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav |
| 39 | 39 | ||
| 40 | -(3) Streaming Conformer CTC from WeNet | 40 | +(3) Streaming Zipformer2 CTC |
| 41 | + | ||
| 42 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 43 | +tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 44 | +rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 45 | +ls -lh sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 | ||
| 46 | + | ||
| 47 | +./python-api-examples/online-decode-files.py \ | ||
| 48 | + --zipformer2-ctc=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ | ||
| 49 | + --tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ | ||
| 50 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \ | ||
| 51 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav | ||
| 52 | + | ||
| 53 | +(4) Streaming Conformer CTC from WeNet | ||
| 41 | 54 | ||
| 42 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech | 55 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech |
| 43 | cd sherpa-onnx-zh-wenet-wenetspeech | 56 | cd sherpa-onnx-zh-wenet-wenetspeech |
| @@ -51,12 +64,9 @@ git lfs pull --include "*.onnx" | @@ -51,12 +64,9 @@ git lfs pull --include "*.onnx" | ||
| 51 | ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav | 64 | ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav |
| 52 | 65 | ||
| 53 | 66 | ||
| 54 | - | ||
| 55 | Please refer to | 67 | Please refer to |
| 56 | -https://k2-fsa.github.io/sherpa/onnx/index.html | ||
| 57 | -and | ||
| 58 | -https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html | ||
| 59 | -to install sherpa-onnx and to download streaming pre-trained models. | 68 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| 69 | +to download streaming pre-trained models. | ||
| 60 | """ | 70 | """ |
| 61 | import argparse | 71 | import argparse |
| 62 | import time | 72 | import time |
| @@ -98,6 +108,12 @@ def get_args(): | @@ -98,6 +108,12 @@ def get_args(): | ||
| 98 | ) | 108 | ) |
| 99 | 109 | ||
| 100 | parser.add_argument( | 110 | parser.add_argument( |
| 111 | + "--zipformer2-ctc", | ||
| 112 | + type=str, | ||
| 113 | + help="Path to the zipformer2 ctc model", | ||
| 114 | + ) | ||
| 115 | + | ||
| 116 | + parser.add_argument( | ||
| 101 | "--paraformer-encoder", | 117 | "--paraformer-encoder", |
| 102 | type=str, | 118 | type=str, |
| 103 | help="Path to the paraformer encoder model", | 119 | help="Path to the paraformer encoder model", |
| @@ -112,7 +128,7 @@ def get_args(): | @@ -112,7 +128,7 @@ def get_args(): | ||
| 112 | parser.add_argument( | 128 | parser.add_argument( |
| 113 | "--wenet-ctc", | 129 | "--wenet-ctc", |
| 114 | type=str, | 130 | type=str, |
| 115 | - help="Path to the wenet ctc model model", | 131 | + help="Path to the wenet ctc model", |
| 116 | ) | 132 | ) |
| 117 | 133 | ||
| 118 | parser.add_argument( | 134 | parser.add_argument( |
| @@ -275,6 +291,16 @@ def main(): | @@ -275,6 +291,16 @@ def main(): | ||
| 275 | hotwords_file=args.hotwords_file, | 291 | hotwords_file=args.hotwords_file, |
| 276 | hotwords_score=args.hotwords_score, | 292 | hotwords_score=args.hotwords_score, |
| 277 | ) | 293 | ) |
| 294 | + elif args.zipformer2_ctc: | ||
| 295 | + recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( | ||
| 296 | + tokens=args.tokens, | ||
| 297 | + model=args.zipformer2_ctc, | ||
| 298 | + num_threads=args.num_threads, | ||
| 299 | + provider=args.provider, | ||
| 300 | + sample_rate=16000, | ||
| 301 | + feature_dim=80, | ||
| 302 | + decoding_method="greedy_search", | ||
| 303 | + ) | ||
| 278 | elif args.paraformer_encoder: | 304 | elif args.paraformer_encoder: |
| 279 | recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( | 305 | recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( |
| 280 | tokens=args.tokens, | 306 | tokens=args.tokens, |
| @@ -25,6 +25,7 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc | @@ -25,6 +25,7 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc | ||
| 25 | 25 | ||
| 26 | import argparse | 26 | import argparse |
| 27 | import asyncio | 27 | import asyncio |
| 28 | +import json | ||
| 28 | import logging | 29 | import logging |
| 29 | import wave | 30 | import wave |
| 30 | 31 | ||
| @@ -112,7 +113,7 @@ async def receive_results(socket: websockets.WebSocketServerProtocol): | @@ -112,7 +113,7 @@ async def receive_results(socket: websockets.WebSocketServerProtocol): | ||
| 112 | async for message in socket: | 113 | async for message in socket: |
| 113 | if message != "Done!": | 114 | if message != "Done!": |
| 114 | last_message = message | 115 | last_message = message |
| 115 | - logging.info(message) | 116 | + logging.info(json.loads(message)) |
| 116 | else: | 117 | else: |
| 117 | break | 118 | break |
| 118 | return last_message | 119 | return last_message |
| @@ -151,7 +152,7 @@ async def run( | @@ -151,7 +152,7 @@ async def run( | ||
| 151 | await websocket.send("Done") | 152 | await websocket.send("Done") |
| 152 | 153 | ||
| 153 | decoding_results = await receive_task | 154 | decoding_results = await receive_task |
| 154 | - logging.info(f"\nFinal result is:\n{decoding_results}") | 155 | + logging.info(f"\nFinal result is:\n{json.loads(decoding_results)}") |
| 155 | 156 | ||
| 156 | 157 | ||
| 157 | async def main(): | 158 | async def main(): |
| @@ -138,6 +138,12 @@ def add_model_args(parser: argparse.ArgumentParser): | @@ -138,6 +138,12 @@ def add_model_args(parser: argparse.ArgumentParser): | ||
| 138 | ) | 138 | ) |
| 139 | 139 | ||
| 140 | parser.add_argument( | 140 | parser.add_argument( |
| 141 | + "--zipformer2-ctc", | ||
| 142 | + type=str, | ||
| 143 | + help="Path to the model file from zipformer2 ctc", | ||
| 144 | + ) | ||
| 145 | + | ||
| 146 | + parser.add_argument( | ||
| 141 | "--wenet-ctc", | 147 | "--wenet-ctc", |
| 142 | type=str, | 148 | type=str, |
| 143 | help="Path to the model.onnx from WeNet", | 149 | help="Path to the model.onnx from WeNet", |
| @@ -405,6 +411,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | @@ -405,6 +411,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | ||
| 405 | rule3_min_utterance_length=args.rule3_min_utterance_length, | 411 | rule3_min_utterance_length=args.rule3_min_utterance_length, |
| 406 | provider=args.provider, | 412 | provider=args.provider, |
| 407 | ) | 413 | ) |
| 414 | + elif args.zipformer2_ctc: | ||
| 415 | + recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( | ||
| 416 | + tokens=args.tokens, | ||
| 417 | + model=args.zipformer2_ctc, | ||
| 418 | + num_threads=args.num_threads, | ||
| 419 | + sample_rate=args.sample_rate, | ||
| 420 | + feature_dim=args.feat_dim, | ||
| 421 | + decoding_method=args.decoding_method, | ||
| 422 | + enable_endpoint_detection=args.use_endpoint != 0, | ||
| 423 | + rule1_min_trailing_silence=args.rule1_min_trailing_silence, | ||
| 424 | + rule2_min_trailing_silence=args.rule2_min_trailing_silence, | ||
| 425 | + rule3_min_utterance_length=args.rule3_min_utterance_length, | ||
| 426 | + provider=args.provider, | ||
| 427 | + ) | ||
| 408 | elif args.wenet_ctc: | 428 | elif args.wenet_ctc: |
| 409 | recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc( | 429 | recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc( |
| 410 | tokens=args.tokens, | 430 | tokens=args.tokens, |
| @@ -748,6 +768,8 @@ def check_args(args): | @@ -748,6 +768,8 @@ def check_args(args): | ||
| 748 | 768 | ||
| 749 | assert args.paraformer_encoder is None, args.paraformer_encoder | 769 | assert args.paraformer_encoder is None, args.paraformer_encoder |
| 750 | assert args.paraformer_decoder is None, args.paraformer_decoder | 770 | assert args.paraformer_decoder is None, args.paraformer_decoder |
| 771 | + assert args.zipformer2_ctc is None, args.zipformer2_ctc | ||
| 772 | + assert args.wenet_ctc is None, args.wenet_ctc | ||
| 751 | elif args.paraformer_encoder: | 773 | elif args.paraformer_encoder: |
| 752 | assert Path( | 774 | assert Path( |
| 753 | args.paraformer_encoder | 775 | args.paraformer_encoder |
| @@ -756,6 +778,10 @@ def check_args(args): | @@ -756,6 +778,10 @@ def check_args(args): | ||
| 756 | assert Path( | 778 | assert Path( |
| 757 | args.paraformer_decoder | 779 | args.paraformer_decoder |
| 758 | ).is_file(), f"{args.paraformer_decoder} does not exist" | 780 | ).is_file(), f"{args.paraformer_decoder} does not exist" |
| 781 | + elif args.zipformer2_ctc: | ||
| 782 | + assert Path( | ||
| 783 | + args.zipformer2_ctc | ||
| 784 | + ).is_file(), f"{args.zipformer2_ctc} does not exist" | ||
| 759 | elif args.wenet_ctc: | 785 | elif args.wenet_ctc: |
| 760 | assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist" | 786 | assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist" |
| 761 | else: | 787 | else: |
| @@ -51,12 +51,25 @@ namespace SherpaOnnx | @@ -51,12 +51,25 @@ namespace SherpaOnnx | ||
| 51 | } | 51 | } |
| 52 | 52 | ||
| 53 | [StructLayout(LayoutKind.Sequential)] | 53 | [StructLayout(LayoutKind.Sequential)] |
| 54 | + public struct OnlineZipformer2CtcModelConfig | ||
| 55 | + { | ||
| 56 | + public OnlineZipformer2CtcModelConfig() | ||
| 57 | + { | ||
| 58 | + Model = ""; | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | + [MarshalAs(UnmanagedType.LPStr)] | ||
| 62 | + public string Model; | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + [StructLayout(LayoutKind.Sequential)] | ||
| 54 | public struct OnlineModelConfig | 66 | public struct OnlineModelConfig |
| 55 | { | 67 | { |
| 56 | public OnlineModelConfig() | 68 | public OnlineModelConfig() |
| 57 | { | 69 | { |
| 58 | Transducer = new OnlineTransducerModelConfig(); | 70 | Transducer = new OnlineTransducerModelConfig(); |
| 59 | Paraformer = new OnlineParaformerModelConfig(); | 71 | Paraformer = new OnlineParaformerModelConfig(); |
| 72 | + Zipformer2Ctc = new OnlineZipformer2CtcModelConfig(); | ||
| 60 | Tokens = ""; | 73 | Tokens = ""; |
| 61 | NumThreads = 1; | 74 | NumThreads = 1; |
| 62 | Provider = "cpu"; | 75 | Provider = "cpu"; |
| @@ -66,6 +79,7 @@ namespace SherpaOnnx | @@ -66,6 +79,7 @@ namespace SherpaOnnx | ||
| 66 | 79 | ||
| 67 | public OnlineTransducerModelConfig Transducer; | 80 | public OnlineTransducerModelConfig Transducer; |
| 68 | public OnlineParaformerModelConfig Paraformer; | 81 | public OnlineParaformerModelConfig Paraformer; |
| 82 | + public OnlineZipformer2CtcModelConfig Zipformer2Ctc; | ||
| 69 | 83 | ||
| 70 | [MarshalAs(UnmanagedType.LPStr)] | 84 | [MarshalAs(UnmanagedType.LPStr)] |
| 71 | public string Tokens; | 85 | public string Tokens; |
| 1 | +../../../../go-api-examples/streaming-decode-files/run-zipformer2-ctc.sh |
| @@ -65,6 +65,13 @@ type OnlineParaformerModelConfig struct { | @@ -65,6 +65,13 @@ type OnlineParaformerModelConfig struct { | ||
| 65 | Decoder string // Path to the decoder model. | 65 | Decoder string // Path to the decoder model. |
| 66 | } | 66 | } |
| 67 | 67 | ||
| 68 | +// Please refer to | ||
| 69 | +// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html | ||
| 70 | +// to download pre-trained models | ||
| 71 | +type OnlineZipformer2CtcModelConfig struct { | ||
| 72 | + Model string // Path to the onnx model | ||
| 73 | +} | ||
| 74 | + | ||
| 68 | // Configuration for online/streaming models | 75 | // Configuration for online/streaming models |
| 69 | // | 76 | // |
| 70 | // Please refer to | 77 | // Please refer to |
| @@ -72,13 +79,14 @@ type OnlineParaformerModelConfig struct { | @@ -72,13 +79,14 @@ type OnlineParaformerModelConfig struct { | ||
| 72 | // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html | 79 | // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html |
| 73 | // to download pre-trained models | 80 | // to download pre-trained models |
| 74 | type OnlineModelConfig struct { | 81 | type OnlineModelConfig struct { |
| 75 | - Transducer OnlineTransducerModelConfig | ||
| 76 | - Paraformer OnlineParaformerModelConfig | ||
| 77 | - Tokens string // Path to tokens.txt | ||
| 78 | - NumThreads int // Number of threads to use for neural network computation | ||
| 79 | - Provider string // Optional. Valid values are: cpu, cuda, coreml | ||
| 80 | - Debug int // 1 to show model meta information while loading it. | ||
| 81 | - ModelType string // Optional. You can specify it for faster model initialization | 82 | + Transducer OnlineTransducerModelConfig |
| 83 | + Paraformer OnlineParaformerModelConfig | ||
| 84 | + Zipformer2Ctc OnlineZipformer2CtcModelConfig | ||
| 85 | + Tokens string // Path to tokens.txt | ||
| 86 | + NumThreads int // Number of threads to use for neural network computation | ||
| 87 | + Provider string // Optional. Valid values are: cpu, cuda, coreml | ||
| 88 | + Debug int // 1 to show model meta information while loading it. | ||
| 89 | + ModelType string // Optional. You can specify it for faster model initialization | ||
| 82 | } | 90 | } |
| 83 | 91 | ||
| 84 | // Configuration for the feature extractor | 92 | // Configuration for the feature extractor |
| @@ -157,6 +165,9 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { | @@ -157,6 +165,9 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { | ||
| 157 | c.model_config.paraformer.decoder = C.CString(config.ModelConfig.Paraformer.Decoder) | 165 | c.model_config.paraformer.decoder = C.CString(config.ModelConfig.Paraformer.Decoder) |
| 158 | defer C.free(unsafe.Pointer(c.model_config.paraformer.decoder)) | 166 | defer C.free(unsafe.Pointer(c.model_config.paraformer.decoder)) |
| 159 | 167 | ||
| 168 | + c.model_config.zipformer2_ctc.model = C.CString(config.ModelConfig.Zipformer2Ctc.Model) | ||
| 169 | + defer C.free(unsafe.Pointer(c.model_config.zipformer2_ctc.model)) | ||
| 170 | + | ||
| 160 | c.model_config.tokens = C.CString(config.ModelConfig.Tokens) | 171 | c.model_config.tokens = C.CString(config.ModelConfig.Tokens) |
| 161 | defer C.free(unsafe.Pointer(c.model_config.tokens)) | 172 | defer C.free(unsafe.Pointer(c.model_config.tokens)) |
| 162 | 173 |
| @@ -41,9 +41,14 @@ const SherpaOnnxOnlineParaformerModelConfig = StructType({ | @@ -41,9 +41,14 @@ const SherpaOnnxOnlineParaformerModelConfig = StructType({ | ||
| 41 | "decoder" : cstring, | 41 | "decoder" : cstring, |
| 42 | }); | 42 | }); |
| 43 | 43 | ||
| 44 | +const SherpaOnnxOnlineZipformer2CtcModelConfig = StructType({ | ||
| 45 | + "model" : cstring, | ||
| 46 | +}); | ||
| 47 | + | ||
| 44 | const SherpaOnnxOnlineModelConfig = StructType({ | 48 | const SherpaOnnxOnlineModelConfig = StructType({ |
| 45 | "transducer" : SherpaOnnxOnlineTransducerModelConfig, | 49 | "transducer" : SherpaOnnxOnlineTransducerModelConfig, |
| 46 | "paraformer" : SherpaOnnxOnlineParaformerModelConfig, | 50 | "paraformer" : SherpaOnnxOnlineParaformerModelConfig, |
| 51 | + "zipformer2Ctc" : SherpaOnnxOnlineZipformer2CtcModelConfig, | ||
| 47 | "tokens" : cstring, | 52 | "tokens" : cstring, |
| 48 | "numThreads" : int32_t, | 53 | "numThreads" : int32_t, |
| 49 | "provider" : cstring, | 54 | "provider" : cstring, |
| @@ -663,6 +668,7 @@ const OnlineModelConfig = SherpaOnnxOnlineModelConfig; | @@ -663,6 +668,7 @@ const OnlineModelConfig = SherpaOnnxOnlineModelConfig; | ||
| 663 | const FeatureConfig = SherpaOnnxFeatureConfig; | 668 | const FeatureConfig = SherpaOnnxFeatureConfig; |
| 664 | const OnlineRecognizerConfig = SherpaOnnxOnlineRecognizerConfig; | 669 | const OnlineRecognizerConfig = SherpaOnnxOnlineRecognizerConfig; |
| 665 | const OnlineParaformerModelConfig = SherpaOnnxOnlineParaformerModelConfig; | 670 | const OnlineParaformerModelConfig = SherpaOnnxOnlineParaformerModelConfig; |
| 671 | +const OnlineZipformer2CtcModelConfig = SherpaOnnxOnlineZipformer2CtcModelConfig; | ||
| 666 | 672 | ||
| 667 | // offline asr | 673 | // offline asr |
| 668 | const OfflineTransducerModelConfig = SherpaOnnxOfflineTransducerModelConfig; | 674 | const OfflineTransducerModelConfig = SherpaOnnxOfflineTransducerModelConfig; |
| @@ -692,6 +698,7 @@ module.exports = { | @@ -692,6 +698,7 @@ module.exports = { | ||
| 692 | OnlineRecognizer, | 698 | OnlineRecognizer, |
| 693 | OnlineStream, | 699 | OnlineStream, |
| 694 | OnlineParaformerModelConfig, | 700 | OnlineParaformerModelConfig, |
| 701 | + OnlineZipformer2CtcModelConfig, | ||
| 695 | 702 | ||
| 696 | // offline asr | 703 | // offline asr |
| 697 | OfflineRecognizer, | 704 | OfflineRecognizer, |
| @@ -54,6 +54,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( | @@ -54,6 +54,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( | ||
| 54 | recognizer_config.model_config.paraformer.decoder = | 54 | recognizer_config.model_config.paraformer.decoder = |
| 55 | SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); | 55 | SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); |
| 56 | 56 | ||
| 57 | + recognizer_config.model_config.zipformer2_ctc.model = | ||
| 58 | + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); | ||
| 59 | + | ||
| 57 | recognizer_config.model_config.tokens = | 60 | recognizer_config.model_config.tokens = |
| 58 | SHERPA_ONNX_OR(config->model_config.tokens, ""); | 61 | SHERPA_ONNX_OR(config->model_config.tokens, ""); |
| 59 | recognizer_config.model_config.num_threads = | 62 | recognizer_config.model_config.num_threads = |
| @@ -66,9 +66,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineParaformerModelConfig { | @@ -66,9 +66,17 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineParaformerModelConfig { | ||
| 66 | const char *decoder; | 66 | const char *decoder; |
| 67 | } SherpaOnnxOnlineParaformerModelConfig; | 67 | } SherpaOnnxOnlineParaformerModelConfig; |
| 68 | 68 | ||
| 69 | -SHERPA_ONNX_API typedef struct SherpaOnnxModelConfig { | 69 | +// Please visit |
| 70 | +// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html# | ||
| 71 | +// to download pre-trained streaming zipformer2 ctc models | ||
| 72 | +SHERPA_ONNX_API typedef struct SherpaOnnxOnlineZipformer2CtcModelConfig { | ||
| 73 | + const char *model; | ||
| 74 | +} SherpaOnnxOnlineZipformer2CtcModelConfig; | ||
| 75 | + | ||
| 76 | +SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { | ||
| 70 | SherpaOnnxOnlineTransducerModelConfig transducer; | 77 | SherpaOnnxOnlineTransducerModelConfig transducer; |
| 71 | SherpaOnnxOnlineParaformerModelConfig paraformer; | 78 | SherpaOnnxOnlineParaformerModelConfig paraformer; |
| 79 | + SherpaOnnxOnlineZipformer2CtcModelConfig zipformer2_ctc; | ||
| 72 | const char *tokens; | 80 | const char *tokens; |
| 73 | int32_t num_threads; | 81 | int32_t num_threads; |
| 74 | const char *provider; | 82 | const char *provider; |
| @@ -70,6 +70,8 @@ set(sources | @@ -70,6 +70,8 @@ set(sources | ||
| 70 | online-wenet-ctc-model-config.cc | 70 | online-wenet-ctc-model-config.cc |
| 71 | online-wenet-ctc-model.cc | 71 | online-wenet-ctc-model.cc |
| 72 | online-zipformer-transducer-model.cc | 72 | online-zipformer-transducer-model.cc |
| 73 | + online-zipformer2-ctc-model-config.cc | ||
| 74 | + online-zipformer2-ctc-model.cc | ||
| 73 | online-zipformer2-transducer-model.cc | 75 | online-zipformer2-transducer-model.cc |
| 74 | onnx-utils.cc | 76 | onnx-utils.cc |
| 75 | packed-sequence.cc | 77 | packed-sequence.cc |
| @@ -12,6 +12,9 @@ | @@ -12,6 +12,9 @@ | ||
| 12 | namespace sherpa_onnx { | 12 | namespace sherpa_onnx { |
| 13 | 13 | ||
| 14 | struct OnlineCtcDecoderResult { | 14 | struct OnlineCtcDecoderResult { |
| 15 | + /// Number of frames after subsampling we have decoded so far | ||
| 16 | + int32_t frame_offset = 0; | ||
| 17 | + | ||
| 15 | /// The decoded token IDs | 18 | /// The decoded token IDs |
| 16 | std::vector<int64_t> tokens; | 19 | std::vector<int64_t> tokens; |
| 17 | 20 |
| @@ -49,12 +49,17 @@ void OnlineCtcGreedySearchDecoder::Decode( | @@ -49,12 +49,17 @@ void OnlineCtcGreedySearchDecoder::Decode( | ||
| 49 | 49 | ||
| 50 | if (y != blank_id_ && y != prev_id) { | 50 | if (y != blank_id_ && y != prev_id) { |
| 51 | r.tokens.push_back(y); | 51 | r.tokens.push_back(y); |
| 52 | - r.timestamps.push_back(t); | 52 | + r.timestamps.push_back(t + r.frame_offset); |
| 53 | } | 53 | } |
| 54 | 54 | ||
| 55 | prev_id = y; | 55 | prev_id = y; |
| 56 | } // for (int32_t t = 0; t != num_frames; ++t) { | 56 | } // for (int32_t t = 0; t != num_frames; ++t) { |
| 57 | } // for (int32_t b = 0; b != batch_size; ++b) | 57 | } // for (int32_t b = 0; b != batch_size; ++b) |
| 58 | + | ||
| 59 | + // Update frame_offset | ||
| 60 | + for (auto &r : *results) { | ||
| 61 | + r.frame_offset += num_frames; | ||
| 62 | + } | ||
| 58 | } | 63 | } |
| 59 | 64 | ||
| 60 | } // namespace sherpa_onnx | 65 | } // namespace sherpa_onnx |
| @@ -11,127 +11,35 @@ | @@ -11,127 +11,35 @@ | ||
| 11 | 11 | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| 13 | #include "sherpa-onnx/csrc/online-wenet-ctc-model.h" | 13 | #include "sherpa-onnx/csrc/online-wenet-ctc-model.h" |
| 14 | +#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h" | ||
| 14 | #include "sherpa-onnx/csrc/onnx-utils.h" | 15 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 15 | 16 | ||
| 16 | -namespace { | ||
| 17 | - | ||
| 18 | -enum class ModelType { | ||
| 19 | - kZipformerCtc, | ||
| 20 | - kWenetCtc, | ||
| 21 | - kUnkown, | ||
| 22 | -}; | ||
| 23 | - | ||
| 24 | -} // namespace | ||
| 25 | - | ||
| 26 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| 27 | 18 | ||
| 28 | -static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 29 | - bool debug) { | ||
| 30 | - Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | ||
| 31 | - Ort::SessionOptions sess_opts; | ||
| 32 | - | ||
| 33 | - auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length, | ||
| 34 | - sess_opts); | ||
| 35 | - | ||
| 36 | - Ort::ModelMetadata meta_data = sess->GetModelMetadata(); | ||
| 37 | - if (debug) { | ||
| 38 | - std::ostringstream os; | ||
| 39 | - PrintModelMetadata(os, meta_data); | ||
| 40 | - SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 41 | - } | ||
| 42 | - | ||
| 43 | - Ort::AllocatorWithDefaultOptions allocator; | ||
| 44 | - auto model_type = | ||
| 45 | - meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); | ||
| 46 | - if (!model_type) { | ||
| 47 | - SHERPA_ONNX_LOGE( | ||
| 48 | - "No model_type in the metadata!\n" | ||
| 49 | - "If you are using models from WeNet, please refer to\n" | ||
| 50 | - "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/" | ||
| 51 | - "run.sh\n" | ||
| 52 | - "\n" | ||
| 53 | - "for how to add metadta to model.onnx\n"); | ||
| 54 | - return ModelType::kUnkown; | ||
| 55 | - } | ||
| 56 | - | ||
| 57 | - if (model_type.get() == std::string("zipformer2")) { | ||
| 58 | - return ModelType::kZipformerCtc; | ||
| 59 | - } else if (model_type.get() == std::string("wenet_ctc")) { | ||
| 60 | - return ModelType::kWenetCtc; | ||
| 61 | - } else { | ||
| 62 | - SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | ||
| 63 | - return ModelType::kUnkown; | ||
| 64 | - } | ||
| 65 | -} | ||
| 66 | - | ||
| 67 | std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create( | 19 | std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create( |
| 68 | const OnlineModelConfig &config) { | 20 | const OnlineModelConfig &config) { |
| 69 | - ModelType model_type = ModelType::kUnkown; | ||
| 70 | - | ||
| 71 | - std::string filename; | ||
| 72 | if (!config.wenet_ctc.model.empty()) { | 21 | if (!config.wenet_ctc.model.empty()) { |
| 73 | - filename = config.wenet_ctc.model; | 22 | + return std::make_unique<OnlineWenetCtcModel>(config); |
| 23 | + } else if (!config.zipformer2_ctc.model.empty()) { | ||
| 24 | + return std::make_unique<OnlineZipformer2CtcModel>(config); | ||
| 74 | } else { | 25 | } else { |
| 75 | SHERPA_ONNX_LOGE("Please specify a CTC model"); | 26 | SHERPA_ONNX_LOGE("Please specify a CTC model"); |
| 76 | exit(-1); | 27 | exit(-1); |
| 77 | } | 28 | } |
| 78 | - | ||
| 79 | - { | ||
| 80 | - auto buffer = ReadFile(filename); | ||
| 81 | - | ||
| 82 | - model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | ||
| 83 | - } | ||
| 84 | - | ||
| 85 | - switch (model_type) { | ||
| 86 | - case ModelType::kZipformerCtc: | ||
| 87 | - return nullptr; | ||
| 88 | - // return std::make_unique<OnlineZipformerCtcModel>(config); | ||
| 89 | - break; | ||
| 90 | - case ModelType::kWenetCtc: | ||
| 91 | - return std::make_unique<OnlineWenetCtcModel>(config); | ||
| 92 | - break; | ||
| 93 | - case ModelType::kUnkown: | ||
| 94 | - SHERPA_ONNX_LOGE("Unknown model type in online CTC!"); | ||
| 95 | - return nullptr; | ||
| 96 | - } | ||
| 97 | - | ||
| 98 | - return nullptr; | ||
| 99 | } | 29 | } |
| 100 | 30 | ||
| 101 | #if __ANDROID_API__ >= 9 | 31 | #if __ANDROID_API__ >= 9 |
| 102 | 32 | ||
| 103 | std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create( | 33 | std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create( |
| 104 | AAssetManager *mgr, const OnlineModelConfig &config) { | 34 | AAssetManager *mgr, const OnlineModelConfig &config) { |
| 105 | - ModelType model_type = ModelType::kUnkown; | ||
| 106 | - | ||
| 107 | - std::string filename; | ||
| 108 | if (!config.wenet_ctc.model.empty()) { | 35 | if (!config.wenet_ctc.model.empty()) { |
| 109 | - filename = config.wenet_ctc.model; | 36 | + return std::make_unique<OnlineWenetCtcModel>(mgr, config); |
| 37 | + } else if (!config.zipformer2_ctc.model.empty()) { | ||
| 38 | + return std::make_unique<OnlineZipformer2CtcModel>(mgr, config); | ||
| 110 | } else { | 39 | } else { |
| 111 | SHERPA_ONNX_LOGE("Please specify a CTC model"); | 40 | SHERPA_ONNX_LOGE("Please specify a CTC model"); |
| 112 | exit(-1); | 41 | exit(-1); |
| 113 | } | 42 | } |
| 114 | - | ||
| 115 | - { | ||
| 116 | - auto buffer = ReadFile(mgr, filename); | ||
| 117 | - | ||
| 118 | - model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | ||
| 119 | - } | ||
| 120 | - | ||
| 121 | - switch (model_type) { | ||
| 122 | - case ModelType::kZipformerCtc: | ||
| 123 | - return nullptr; | ||
| 124 | - // return std::make_unique<OnlineZipformerCtcModel>(mgr, config); | ||
| 125 | - break; | ||
| 126 | - case ModelType::kWenetCtc: | ||
| 127 | - return std::make_unique<OnlineWenetCtcModel>(mgr, config); | ||
| 128 | - break; | ||
| 129 | - case ModelType::kUnkown: | ||
| 130 | - SHERPA_ONNX_LOGE("Unknown model type in online CTC!"); | ||
| 131 | - return nullptr; | ||
| 132 | - } | ||
| 133 | - | ||
| 134 | - return nullptr; | ||
| 135 | } | 43 | } |
| 136 | #endif | 44 | #endif |
| 137 | 45 |
| @@ -33,6 +33,26 @@ class OnlineCtcModel { | @@ -33,6 +33,26 @@ class OnlineCtcModel { | ||
| 33 | // Return a list of tensors containing the initial states | 33 | // Return a list of tensors containing the initial states |
| 34 | virtual std::vector<Ort::Value> GetInitStates() const = 0; | 34 | virtual std::vector<Ort::Value> GetInitStates() const = 0; |
| 35 | 35 | ||
| 36 | + /** Stack a list of individual states into a batch. | ||
| 37 | + * | ||
| 38 | + * It is the inverse operation of `UnStackStates`. | ||
| 39 | + * | ||
| 40 | + * @param states states[i] contains the state for the i-th utterance. | ||
| 41 | + * @return Return a single value representing the batched state. | ||
| 42 | + */ | ||
| 43 | + virtual std::vector<Ort::Value> StackStates( | ||
| 44 | + std::vector<std::vector<Ort::Value>> states) const = 0; | ||
| 45 | + | ||
| 46 | + /** Unstack a batch state into a list of individual states. | ||
| 47 | + * | ||
| 48 | + * It is the inverse operation of `StackStates`. | ||
| 49 | + * | ||
| 50 | + * @param states A batched state. | ||
| 51 | + * @return ans[i] contains the state for the i-th utterance. | ||
| 52 | + */ | ||
| 53 | + virtual std::vector<std::vector<Ort::Value>> UnStackStates( | ||
| 54 | + std::vector<Ort::Value> states) const = 0; | ||
| 55 | + | ||
| 36 | /** | 56 | /** |
| 37 | * | 57 | * |
| 38 | * @param x A 3-D tensor of shape (N, T, C). N has to be 1. | 58 | * @param x A 3-D tensor of shape (N, T, C). N has to be 1. |
| @@ -60,6 +80,9 @@ class OnlineCtcModel { | @@ -60,6 +80,9 @@ class OnlineCtcModel { | ||
| 60 | // ChunkLength() frames, we advance by ChunkShift() frames | 80 | // ChunkLength() frames, we advance by ChunkShift() frames |
| 61 | // before we process the next chunk. | 81 | // before we process the next chunk. |
| 62 | virtual int32_t ChunkShift() const = 0; | 82 | virtual int32_t ChunkShift() const = 0; |
| 83 | + | ||
| 84 | + // Return true if the model supports batch size > 1 | ||
| 85 | + virtual bool SupportBatchProcessing() const { return true; } | ||
| 63 | }; | 86 | }; |
| 64 | 87 | ||
| 65 | } // namespace sherpa_onnx | 88 | } // namespace sherpa_onnx |
| @@ -14,6 +14,7 @@ void OnlineModelConfig::Register(ParseOptions *po) { | @@ -14,6 +14,7 @@ void OnlineModelConfig::Register(ParseOptions *po) { | ||
| 14 | transducer.Register(po); | 14 | transducer.Register(po); |
| 15 | paraformer.Register(po); | 15 | paraformer.Register(po); |
| 16 | wenet_ctc.Register(po); | 16 | wenet_ctc.Register(po); |
| 17 | + zipformer2_ctc.Register(po); | ||
| 17 | 18 | ||
| 18 | po->Register("tokens", &tokens, "Path to tokens.txt"); | 19 | po->Register("tokens", &tokens, "Path to tokens.txt"); |
| 19 | 20 | ||
| @@ -26,10 +27,11 @@ void OnlineModelConfig::Register(ParseOptions *po) { | @@ -26,10 +27,11 @@ void OnlineModelConfig::Register(ParseOptions *po) { | ||
| 26 | po->Register("provider", &provider, | 27 | po->Register("provider", &provider, |
| 27 | "Specify a provider to use: cpu, cuda, coreml"); | 28 | "Specify a provider to use: cpu, cuda, coreml"); |
| 28 | 29 | ||
| 29 | - po->Register("model-type", &model_type, | ||
| 30 | - "Specify it to reduce model initialization time. " | ||
| 31 | - "Valid values are: conformer, lstm, zipformer, zipformer2." | ||
| 32 | - "All other values lead to loading the model twice."); | 30 | + po->Register( |
| 31 | + "model-type", &model_type, | ||
| 32 | + "Specify it to reduce model initialization time. " | ||
| 33 | + "Valid values are: conformer, lstm, zipformer, zipformer2, wenet_ctc" | ||
| 34 | + "All other values lead to loading the model twice."); | ||
| 33 | } | 35 | } |
| 34 | 36 | ||
| 35 | bool OnlineModelConfig::Validate() const { | 37 | bool OnlineModelConfig::Validate() const { |
| @@ -51,6 +53,10 @@ bool OnlineModelConfig::Validate() const { | @@ -51,6 +53,10 @@ bool OnlineModelConfig::Validate() const { | ||
| 51 | return wenet_ctc.Validate(); | 53 | return wenet_ctc.Validate(); |
| 52 | } | 54 | } |
| 53 | 55 | ||
| 56 | + if (!zipformer2_ctc.model.empty()) { | ||
| 57 | + return zipformer2_ctc.Validate(); | ||
| 58 | + } | ||
| 59 | + | ||
| 54 | return transducer.Validate(); | 60 | return transducer.Validate(); |
| 55 | } | 61 | } |
| 56 | 62 | ||
| @@ -61,6 +67,7 @@ std::string OnlineModelConfig::ToString() const { | @@ -61,6 +67,7 @@ std::string OnlineModelConfig::ToString() const { | ||
| 61 | os << "transducer=" << transducer.ToString() << ", "; | 67 | os << "transducer=" << transducer.ToString() << ", "; |
| 62 | os << "paraformer=" << paraformer.ToString() << ", "; | 68 | os << "paraformer=" << paraformer.ToString() << ", "; |
| 63 | os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; | 69 | os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; |
| 70 | + os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", "; | ||
| 64 | os << "tokens=\"" << tokens << "\", "; | 71 | os << "tokens=\"" << tokens << "\", "; |
| 65 | os << "num_threads=" << num_threads << ", "; | 72 | os << "num_threads=" << num_threads << ", "; |
| 66 | os << "debug=" << (debug ? "True" : "False") << ", "; | 73 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | #include "sherpa-onnx/csrc/online-paraformer-model-config.h" | 9 | #include "sherpa-onnx/csrc/online-paraformer-model-config.h" |
| 10 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 10 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 11 | #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" | 11 | #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" |
| 12 | +#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h" | ||
| 12 | 13 | ||
| 13 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 14 | 15 | ||
| @@ -16,6 +17,7 @@ struct OnlineModelConfig { | @@ -16,6 +17,7 @@ struct OnlineModelConfig { | ||
| 16 | OnlineTransducerModelConfig transducer; | 17 | OnlineTransducerModelConfig transducer; |
| 17 | OnlineParaformerModelConfig paraformer; | 18 | OnlineParaformerModelConfig paraformer; |
| 18 | OnlineWenetCtcModelConfig wenet_ctc; | 19 | OnlineWenetCtcModelConfig wenet_ctc; |
| 20 | + OnlineZipformer2CtcModelConfig zipformer2_ctc; | ||
| 19 | std::string tokens; | 21 | std::string tokens; |
| 20 | int32_t num_threads = 1; | 22 | int32_t num_threads = 1; |
| 21 | bool debug = false; | 23 | bool debug = false; |
| @@ -25,7 +27,8 @@ struct OnlineModelConfig { | @@ -25,7 +27,8 @@ struct OnlineModelConfig { | ||
| 25 | // - conformer, conformer transducer from icefall | 27 | // - conformer, conformer transducer from icefall |
| 26 | // - lstm, lstm transducer from icefall | 28 | // - lstm, lstm transducer from icefall |
| 27 | // - zipformer, zipformer transducer from icefall | 29 | // - zipformer, zipformer transducer from icefall |
| 28 | - // - zipformer2, zipformer2 transducer from icefall | 30 | + // - zipformer2, zipformer2 transducer or CTC from icefall |
| 31 | + // - wenet_ctc, wenet CTC model | ||
| 29 | // | 32 | // |
| 30 | // All other values are invalid and lead to loading the model twice. | 33 | // All other values are invalid and lead to loading the model twice. |
| 31 | std::string model_type; | 34 | std::string model_type; |
| @@ -34,11 +37,13 @@ struct OnlineModelConfig { | @@ -34,11 +37,13 @@ struct OnlineModelConfig { | ||
| 34 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, | 37 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, |
| 35 | const OnlineParaformerModelConfig &paraformer, | 38 | const OnlineParaformerModelConfig &paraformer, |
| 36 | const OnlineWenetCtcModelConfig &wenet_ctc, | 39 | const OnlineWenetCtcModelConfig &wenet_ctc, |
| 40 | + const OnlineZipformer2CtcModelConfig &zipformer2_ctc, | ||
| 37 | const std::string &tokens, int32_t num_threads, bool debug, | 41 | const std::string &tokens, int32_t num_threads, bool debug, |
| 38 | const std::string &provider, const std::string &model_type) | 42 | const std::string &provider, const std::string &model_type) |
| 39 | : transducer(transducer), | 43 | : transducer(transducer), |
| 40 | paraformer(paraformer), | 44 | paraformer(paraformer), |
| 41 | wenet_ctc(wenet_ctc), | 45 | wenet_ctc(wenet_ctc), |
| 46 | + zipformer2_ctc(zipformer2_ctc), | ||
| 42 | tokens(tokens), | 47 | tokens(tokens), |
| 43 | num_threads(num_threads), | 48 | num_threads(num_threads), |
| 44 | debug(debug), | 49 | debug(debug), |
| @@ -96,8 +96,67 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -96,8 +96,67 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 96 | } | 96 | } |
| 97 | 97 | ||
| 98 | void DecodeStreams(OnlineStream **ss, int32_t n) const override { | 98 | void DecodeStreams(OnlineStream **ss, int32_t n) const override { |
| 99 | + if (n == 1 || !model_->SupportBatchProcessing()) { | ||
| 100 | + for (int32_t i = 0; i != n; ++i) { | ||
| 101 | + DecodeStream(ss[i]); | ||
| 102 | + } | ||
| 103 | + return; | ||
| 104 | + } | ||
| 105 | + | ||
| 106 | + // batch processing | ||
| 107 | + int32_t chunk_length = model_->ChunkLength(); | ||
| 108 | + int32_t chunk_shift = model_->ChunkShift(); | ||
| 109 | + | ||
| 110 | + int32_t feat_dim = ss[0]->FeatureDim(); | ||
| 111 | + | ||
| 112 | + std::vector<OnlineCtcDecoderResult> results(n); | ||
| 113 | + std::vector<float> features_vec(n * chunk_length * feat_dim); | ||
| 114 | + std::vector<std::vector<Ort::Value>> states_vec(n); | ||
| 115 | + std::vector<int64_t> all_processed_frames(n); | ||
| 116 | + | ||
| 99 | for (int32_t i = 0; i != n; ++i) { | 117 | for (int32_t i = 0; i != n; ++i) { |
| 100 | - DecodeStream(ss[i]); | 118 | + const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); |
| 119 | + std::vector<float> features = | ||
| 120 | + ss[i]->GetFrames(num_processed_frames, chunk_length); | ||
| 121 | + | ||
| 122 | + // Question: should num_processed_frames include chunk_shift? | ||
| 123 | + ss[i]->GetNumProcessedFrames() += chunk_shift; | ||
| 124 | + | ||
| 125 | + std::copy(features.begin(), features.end(), | ||
| 126 | + features_vec.data() + i * chunk_length * feat_dim); | ||
| 127 | + | ||
| 128 | + results[i] = std::move(ss[i]->GetCtcResult()); | ||
| 129 | + states_vec[i] = std::move(ss[i]->GetStates()); | ||
| 130 | + all_processed_frames[i] = num_processed_frames; | ||
| 131 | + } | ||
| 132 | + | ||
| 133 | + auto memory_info = | ||
| 134 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 135 | + | ||
| 136 | + std::array<int64_t, 3> x_shape{n, chunk_length, feat_dim}; | ||
| 137 | + | ||
| 138 | + Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), | ||
| 139 | + features_vec.size(), x_shape.data(), | ||
| 140 | + x_shape.size()); | ||
| 141 | + | ||
| 142 | + auto states = model_->StackStates(std::move(states_vec)); | ||
| 143 | + int32_t num_states = states.size(); | ||
| 144 | + auto out = model_->Forward(std::move(x), std::move(states)); | ||
| 145 | + std::vector<Ort::Value> out_states; | ||
| 146 | + out_states.reserve(num_states); | ||
| 147 | + | ||
| 148 | + for (int32_t k = 1; k != num_states + 1; ++k) { | ||
| 149 | + out_states.push_back(std::move(out[k])); | ||
| 150 | + } | ||
| 151 | + | ||
| 152 | + std::vector<std::vector<Ort::Value>> next_states = | ||
| 153 | + model_->UnStackStates(std::move(out_states)); | ||
| 154 | + | ||
| 155 | + decoder_->Decode(std::move(out[0]), &results); | ||
| 156 | + | ||
| 157 | + for (int32_t k = 0; k != n; ++k) { | ||
| 158 | + ss[k]->SetCtcResult(results[k]); | ||
| 159 | + ss[k]->SetStates(std::move(next_states[k])); | ||
| 101 | } | 160 | } |
| 102 | } | 161 | } |
| 103 | 162 |
| @@ -20,7 +20,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -20,7 +20,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 20 | return std::make_unique<OnlineRecognizerParaformerImpl>(config); | 20 | return std::make_unique<OnlineRecognizerParaformerImpl>(config); |
| 21 | } | 21 | } |
| 22 | 22 | ||
| 23 | - if (!config.model_config.wenet_ctc.model.empty()) { | 23 | + if (!config.model_config.wenet_ctc.model.empty() || |
| 24 | + !config.model_config.zipformer2_ctc.model.empty()) { | ||
| 24 | return std::make_unique<OnlineRecognizerCtcImpl>(config); | 25 | return std::make_unique<OnlineRecognizerCtcImpl>(config); |
| 25 | } | 26 | } |
| 26 | 27 | ||
| @@ -39,7 +40,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -39,7 +40,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 39 | return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config); | 40 | return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config); |
| 40 | } | 41 | } |
| 41 | 42 | ||
| 42 | - if (!config.model_config.wenet_ctc.model.empty()) { | 43 | + if (!config.model_config.wenet_ctc.model.empty() || |
| 44 | + !config.model_config.zipformer2_ctc.model.empty()) { | ||
| 43 | return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config); | 45 | return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config); |
| 44 | } | 46 | } |
| 45 | 47 |
| 1 | -// sherpa-onnx/csrc/online-paraformer-model.cc | 1 | +// sherpa-onnx/csrc/online-wenet-ctc-model.cc |
| 2 | // | 2 | // |
| 3 | // Copyright (c) 2023 Xiaomi Corporation | 3 | // Copyright (c) 2023 Xiaomi Corporation |
| 4 | 4 | ||
| @@ -239,4 +239,21 @@ std::vector<Ort::Value> OnlineWenetCtcModel::GetInitStates() const { | @@ -239,4 +239,21 @@ std::vector<Ort::Value> OnlineWenetCtcModel::GetInitStates() const { | ||
| 239 | return impl_->GetInitStates(); | 239 | return impl_->GetInitStates(); |
| 240 | } | 240 | } |
| 241 | 241 | ||
| 242 | +std::vector<Ort::Value> OnlineWenetCtcModel::StackStates( | ||
| 243 | + std::vector<std::vector<Ort::Value>> states) const { | ||
| 244 | + if (states.size() != 1) { | ||
| 245 | + SHERPA_ONNX_LOGE("wenet CTC model supports only batch_size==1. Given: %d", | ||
| 246 | + static_cast<int32_t>(states.size())); | ||
| 247 | + } | ||
| 248 | + | ||
| 249 | + return std::move(states[0]); | ||
| 250 | +} | ||
| 251 | + | ||
| 252 | +std::vector<std::vector<Ort::Value>> OnlineWenetCtcModel::UnStackStates( | ||
| 253 | + std::vector<Ort::Value> states) const { | ||
| 254 | + std::vector<std::vector<Ort::Value>> ans(1); | ||
| 255 | + ans[0] = std::move(states); | ||
| 256 | + return ans; | ||
| 257 | +} | ||
| 258 | + | ||
| 242 | } // namespace sherpa_onnx | 259 | } // namespace sherpa_onnx |
| @@ -35,6 +35,12 @@ class OnlineWenetCtcModel : public OnlineCtcModel { | @@ -35,6 +35,12 @@ class OnlineWenetCtcModel : public OnlineCtcModel { | ||
| 35 | // - offset | 35 | // - offset |
| 36 | std::vector<Ort::Value> GetInitStates() const override; | 36 | std::vector<Ort::Value> GetInitStates() const override; |
| 37 | 37 | ||
| 38 | + std::vector<Ort::Value> StackStates( | ||
| 39 | + std::vector<std::vector<Ort::Value>> states) const override; | ||
| 40 | + | ||
| 41 | + std::vector<std::vector<Ort::Value>> UnStackStates( | ||
| 42 | + std::vector<Ort::Value> states) const override; | ||
| 43 | + | ||
| 38 | /** | 44 | /** |
| 39 | * | 45 | * |
| 40 | * @param x A 3-D tensor of shape (N, T, C). N has to be 1. | 46 | * @param x A 3-D tensor of shape (N, T, C). N has to be 1. |
| @@ -63,6 +69,8 @@ class OnlineWenetCtcModel : public OnlineCtcModel { | @@ -63,6 +69,8 @@ class OnlineWenetCtcModel : public OnlineCtcModel { | ||
| 63 | // before we process the next chunk. | 69 | // before we process the next chunk. |
| 64 | int32_t ChunkShift() const override; | 70 | int32_t ChunkShift() const override; |
| 65 | 71 | ||
| 72 | + bool SupportBatchProcessing() const override { return false; } | ||
| 73 | + | ||
| 66 | private: | 74 | private: |
| 67 | class Impl; | 75 | class Impl; |
| 68 | std::unique_ptr<Impl> impl_; | 76 | std::unique_ptr<Impl> impl_; |
| 1 | +// sherpa-onnx/csrc/online-zipformer2-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OnlineZipformer2CtcModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("zipformer2-ctc-model", &model, | ||
| 14 | + "Path to CTC model.onnx. See also " | ||
| 15 | + "https://github.com/k2-fsa/icefall/pull/1413"); | ||
| 16 | +} | ||
| 17 | + | ||
| 18 | +bool OnlineZipformer2CtcModelConfig::Validate() const { | ||
| 19 | + if (model.empty()) { | ||
| 20 | + SHERPA_ONNX_LOGE("--zipformer2-ctc-model is empty!"); | ||
| 21 | + return false; | ||
| 22 | + } | ||
| 23 | + | ||
| 24 | + if (!FileExists(model)) { | ||
| 25 | + SHERPA_ONNX_LOGE("--zipformer2-ctc-model %s does not exist", model.c_str()); | ||
| 26 | + return false; | ||
| 27 | + } | ||
| 28 | + | ||
| 29 | + return true; | ||
| 30 | +} | ||
| 31 | + | ||
| 32 | +std::string OnlineZipformer2CtcModelConfig::ToString() const { | ||
| 33 | + std::ostringstream os; | ||
| 34 | + | ||
| 35 | + os << "OnlineZipformer2CtcModelConfig("; | ||
| 36 | + os << "model=\"" << model << "\")"; | ||
| 37 | + | ||
| 38 | + return os.str(); | ||
| 39 | +} | ||
| 40 | + | ||
| 41 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OnlineZipformer2CtcModelConfig { | ||
| 14 | + std::string model; | ||
| 15 | + | ||
| 16 | + OnlineZipformer2CtcModelConfig() = default; | ||
| 17 | + | ||
| 18 | + explicit OnlineZipformer2CtcModelConfig(const std::string &model) | ||
| 19 | + : model(model) {} | ||
| 20 | + | ||
| 21 | + void Register(ParseOptions *po); | ||
| 22 | + bool Validate() const; | ||
| 23 | + | ||
| 24 | + std::string ToString() const; | ||
| 25 | +}; | ||
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | + | ||
| 29 | +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/online-zipformer2-ctc-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h" | ||
| 6 | + | ||
| 7 | +#include <assert.h> | ||
| 8 | +#include <math.h> | ||
| 9 | + | ||
| 10 | +#include <algorithm> | ||
| 11 | +#include <cmath> | ||
| 12 | +#include <numeric> | ||
| 13 | +#include <string> | ||
| 14 | + | ||
| 15 | +#if __ANDROID_API__ >= 9 | ||
| 16 | +#include "android/asset_manager.h" | ||
| 17 | +#include "android/asset_manager_jni.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 20 | +#include "sherpa-onnx/csrc/cat.h" | ||
| 21 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 22 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 23 | +#include "sherpa-onnx/csrc/session.h" | ||
| 24 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 25 | +#include "sherpa-onnx/csrc/unbind.h" | ||
| 26 | + | ||
| 27 | +namespace sherpa_onnx { | ||
| 28 | + | ||
| 29 | +class OnlineZipformer2CtcModel::Impl { | ||
| 30 | + public: | ||
| 31 | + explicit Impl(const OnlineModelConfig &config) | ||
| 32 | + : config_(config), | ||
| 33 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 34 | + sess_opts_(GetSessionOptions(config)), | ||
| 35 | + allocator_{} { | ||
| 36 | + { | ||
| 37 | + auto buf = ReadFile(config.zipformer2_ctc.model); | ||
| 38 | + Init(buf.data(), buf.size()); | ||
| 39 | + } | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | +#if __ANDROID_API__ >= 9 | ||
| 43 | + Impl(AAssetManager *mgr, const OnlineModelConfig &config) | ||
| 44 | + : config_(config), | ||
| 45 | + env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 46 | + sess_opts_(GetSessionOptions(config)), | ||
| 47 | + allocator_{} { | ||
| 48 | + { | ||
| 49 | + auto buf = ReadFile(mgr, config.zipformer2_ctc.model); | ||
| 50 | + Init(buf.data(), buf.size()); | ||
| 51 | + } | ||
| 52 | + } | ||
| 53 | +#endif | ||
| 54 | + | ||
| 55 | + std::vector<Ort::Value> Forward(Ort::Value features, | ||
| 56 | + std::vector<Ort::Value> states) { | ||
| 57 | + std::vector<Ort::Value> inputs; | ||
| 58 | + inputs.reserve(1 + states.size()); | ||
| 59 | + | ||
| 60 | + inputs.push_back(std::move(features)); | ||
| 61 | + for (auto &v : states) { | ||
| 62 | + inputs.push_back(std::move(v)); | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 66 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + int32_t VocabSize() const { return vocab_size_; } | ||
| 70 | + | ||
| 71 | + int32_t ChunkLength() const { return T_; } | ||
| 72 | + | ||
| 73 | + int32_t ChunkShift() const { return decode_chunk_len_; } | ||
| 74 | + | ||
| 75 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 76 | + | ||
| 77 | + // Return a View() of every tensor in initial_states_, i.e., | ||
| 78 | + // - 6 float tensors for each encoder layer, followed by | ||
| 79 | + // - 1 float tensor of shape (1, 128, 3, 19), and | ||
| 80 | + // - 1 int64 tensor of shape (1,); see InitStates() | ||
| 81 | + std::vector<Ort::Value> GetInitStates() { | ||
| 82 | + std::vector<Ort::Value> ans; | ||
| 83 | + ans.reserve(initial_states_.size()); | ||
| 84 | + for (auto &s : initial_states_) { | ||
| 85 | + ans.push_back(View(&s)); | ||
| 86 | + } | ||
| 87 | + return ans; | ||
| 88 | + } | ||
| 89 | + | ||
| 90 | + std::vector<Ort::Value> StackStates( | ||
| 91 | + std::vector<std::vector<Ort::Value>> states) const { | ||
| 92 | + int32_t batch_size = static_cast<int32_t>(states.size()); | ||
| 93 | + int32_t num_encoders = static_cast<int32_t>(num_encoder_layers_.size()); | ||
| 94 | + | ||
| 95 | + std::vector<const Ort::Value *> buf(batch_size); | ||
| 96 | + | ||
| 97 | + std::vector<Ort::Value> ans; | ||
| 98 | + int32_t num_states = static_cast<int32_t>(states[0].size()); | ||
| 99 | + ans.reserve(num_states); | ||
| 100 | + | ||
| 101 | + for (int32_t i = 0; i != (num_states - 2) / 6; ++i) { | ||
| 102 | + { | ||
| 103 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 104 | + buf[n] = &states[n][6 * i]; | ||
| 105 | + } | ||
| 106 | + auto v = Cat(allocator_, buf, 1); | ||
| 107 | + ans.push_back(std::move(v)); | ||
| 108 | + } | ||
| 109 | + { | ||
| 110 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 111 | + buf[n] = &states[n][6 * i + 1]; | ||
| 112 | + } | ||
| 113 | + auto v = Cat(allocator_, buf, 1); | ||
| 114 | + ans.push_back(std::move(v)); | ||
| 115 | + } | ||
| 116 | + { | ||
| 117 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 118 | + buf[n] = &states[n][6 * i + 2]; | ||
| 119 | + } | ||
| 120 | + auto v = Cat(allocator_, buf, 1); | ||
| 121 | + ans.push_back(std::move(v)); | ||
| 122 | + } | ||
| 123 | + { | ||
| 124 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 125 | + buf[n] = &states[n][6 * i + 3]; | ||
| 126 | + } | ||
| 127 | + auto v = Cat(allocator_, buf, 1); | ||
| 128 | + ans.push_back(std::move(v)); | ||
| 129 | + } | ||
| 130 | + { | ||
| 131 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 132 | + buf[n] = &states[n][6 * i + 4]; | ||
| 133 | + } | ||
| 134 | + auto v = Cat(allocator_, buf, 0); | ||
| 135 | + ans.push_back(std::move(v)); | ||
| 136 | + } | ||
| 137 | + { | ||
| 138 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 139 | + buf[n] = &states[n][6 * i + 5]; | ||
| 140 | + } | ||
| 141 | + auto v = Cat(allocator_, buf, 0); | ||
| 142 | + ans.push_back(std::move(v)); | ||
| 143 | + } | ||
| 144 | + } | ||
| 145 | + | ||
| 146 | + { | ||
| 147 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 148 | + buf[n] = &states[n][num_states - 2]; | ||
| 149 | + } | ||
| 150 | + auto v = Cat(allocator_, buf, 0); | ||
| 151 | + ans.push_back(std::move(v)); | ||
| 152 | + } | ||
| 153 | + | ||
| 154 | + { | ||
| 155 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 156 | + buf[n] = &states[n][num_states - 1]; | ||
| 157 | + } | ||
| 158 | + auto v = Cat<int64_t>(allocator_, buf, 0); | ||
| 159 | + ans.push_back(std::move(v)); | ||
| 160 | + } | ||
| 161 | + return ans; | ||
| 162 | + } | ||
| 163 | + | ||
| 164 | + std::vector<std::vector<Ort::Value>> UnStackStates( | ||
| 165 | + std::vector<Ort::Value> states) const { | ||
| 166 | + int32_t m = std::accumulate(num_encoder_layers_.begin(), | ||
| 167 | + num_encoder_layers_.end(), 0); | ||
| 168 | + assert(states.size() == m * 6 + 2); | ||
| 169 | + | ||
| 170 | + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 171 | + int32_t num_encoders = num_encoder_layers_.size(); | ||
| 172 | + | ||
| 173 | + std::vector<std::vector<Ort::Value>> ans; | ||
| 174 | + ans.resize(batch_size); | ||
| 175 | + | ||
| 176 | + for (int32_t i = 0; i != m; ++i) { | ||
| 177 | + { | ||
| 178 | + auto v = Unbind(allocator_, &states[i * 6], 1); | ||
| 179 | + assert(v.size() == batch_size); | ||
| 180 | + | ||
| 181 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 182 | + ans[n].push_back(std::move(v[n])); | ||
| 183 | + } | ||
| 184 | + } | ||
| 185 | + { | ||
| 186 | + auto v = Unbind(allocator_, &states[i * 6 + 1], 1); | ||
| 187 | + assert(v.size() == batch_size); | ||
| 188 | + | ||
| 189 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 190 | + ans[n].push_back(std::move(v[n])); | ||
| 191 | + } | ||
| 192 | + } | ||
| 193 | + { | ||
| 194 | + auto v = Unbind(allocator_, &states[i * 6 + 2], 1); | ||
| 195 | + assert(v.size() == batch_size); | ||
| 196 | + | ||
| 197 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 198 | + ans[n].push_back(std::move(v[n])); | ||
| 199 | + } | ||
| 200 | + } | ||
| 201 | + { | ||
| 202 | + auto v = Unbind(allocator_, &states[i * 6 + 3], 1); | ||
| 203 | + assert(v.size() == batch_size); | ||
| 204 | + | ||
| 205 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 206 | + ans[n].push_back(std::move(v[n])); | ||
| 207 | + } | ||
| 208 | + } | ||
| 209 | + { | ||
| 210 | + auto v = Unbind(allocator_, &states[i * 6 + 4], 0); | ||
| 211 | + assert(v.size() == batch_size); | ||
| 212 | + | ||
| 213 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 214 | + ans[n].push_back(std::move(v[n])); | ||
| 215 | + } | ||
| 216 | + } | ||
| 217 | + { | ||
| 218 | + auto v = Unbind(allocator_, &states[i * 6 + 5], 0); | ||
| 219 | + assert(v.size() == batch_size); | ||
| 220 | + | ||
| 221 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 222 | + ans[n].push_back(std::move(v[n])); | ||
| 223 | + } | ||
| 224 | + } | ||
| 225 | + } | ||
| 226 | + | ||
| 227 | + { | ||
| 228 | + auto v = Unbind(allocator_, &states[m * 6], 0); | ||
| 229 | + assert(v.size() == batch_size); | ||
| 230 | + | ||
| 231 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 232 | + ans[n].push_back(std::move(v[n])); | ||
| 233 | + } | ||
| 234 | + } | ||
| 235 | + { | ||
| 236 | + auto v = Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0); | ||
| 237 | + assert(v.size() == batch_size); | ||
| 238 | + | ||
| 239 | + for (int32_t n = 0; n != batch_size; ++n) { | ||
| 240 | + ans[n].push_back(std::move(v[n])); | ||
| 241 | + } | ||
| 242 | + } | ||
| 243 | + | ||
| 244 | + return ans; | ||
| 245 | + } | ||
| 246 | + | ||
| 247 | + private: | ||
| 248 | + void Init(void *model_data, size_t model_data_length) { | ||
| 249 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 250 | + sess_opts_); | ||
| 251 | + | ||
| 252 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 253 | + | ||
| 254 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 255 | + | ||
| 256 | + // get meta data | ||
| 257 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 258 | + if (config_.debug) { | ||
| 259 | + std::ostringstream os; | ||
| 260 | + os << "---zipformer2_ctc---\n"; | ||
| 261 | + PrintModelMetadata(os, meta_data); | ||
| 262 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 263 | + } | ||
| 264 | + | ||
| 265 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 266 | + SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims"); | ||
| 267 | + SHERPA_ONNX_READ_META_DATA_VEC(query_head_dims_, "query_head_dims"); | ||
| 268 | + SHERPA_ONNX_READ_META_DATA_VEC(value_head_dims_, "value_head_dims"); | ||
| 269 | + SHERPA_ONNX_READ_META_DATA_VEC(num_heads_, "num_heads"); | ||
| 270 | + SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers"); | ||
| 271 | + SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels"); | ||
| 272 | + SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len"); | ||
| 273 | + | ||
| 274 | + SHERPA_ONNX_READ_META_DATA(T_, "T"); | ||
| 275 | + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); | ||
| 276 | + | ||
| 277 | + { | ||
| 278 | + auto shape = | ||
| 279 | + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape(); | ||
| 280 | + vocab_size_ = shape[2]; | ||
| 281 | + } | ||
| 282 | + | ||
| 283 | + if (config_.debug) { | ||
| 284 | + auto print = [](const std::vector<int32_t> &v, const char *name) { | ||
| 285 | + fprintf(stderr, "%s: ", name); | ||
| 286 | + for (auto i : v) { | ||
| 287 | + fprintf(stderr, "%d ", i); | ||
| 288 | + } | ||
| 289 | + fprintf(stderr, "\n"); | ||
| 290 | + }; | ||
| 291 | + print(encoder_dims_, "encoder_dims"); | ||
| 292 | + print(query_head_dims_, "query_head_dims"); | ||
| 293 | + print(value_head_dims_, "value_head_dims"); | ||
| 294 | + print(num_heads_, "num_heads"); | ||
| 295 | + print(num_encoder_layers_, "num_encoder_layers"); | ||
| 296 | + print(cnn_module_kernels_, "cnn_module_kernels"); | ||
| 297 | + print(left_context_len_, "left_context_len"); | ||
| 298 | + SHERPA_ONNX_LOGE("T: %d", T_); | ||
| 299 | + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); | ||
| 300 | + SHERPA_ONNX_LOGE("vocab_size_: %d", vocab_size_); | ||
| 301 | + } | ||
| 302 | + | ||
| 303 | + InitStates(); | ||
| 304 | + } | ||
| 305 | + | ||
| 306 | + void InitStates() { | ||
| 307 | + int32_t n = static_cast<int32_t>(encoder_dims_.size()); | ||
| 308 | + int32_t m = std::accumulate(num_encoder_layers_.begin(), | ||
| 309 | + num_encoder_layers_.end(), 0); | ||
| 310 | + initial_states_.reserve(m * 6 + 2); | ||
| 311 | + | ||
| 312 | + for (int32_t i = 0; i != n; ++i) { | ||
| 313 | + int32_t num_layers = num_encoder_layers_[i]; | ||
| 314 | + int32_t key_dim = query_head_dims_[i] * num_heads_[i]; | ||
| 315 | + int32_t value_dim = value_head_dims_[i] * num_heads_[i]; | ||
| 316 | + int32_t nonlin_attn_head_dim = 3 * encoder_dims_[i] / 4; | ||
| 317 | + | ||
| 318 | + for (int32_t j = 0; j != num_layers; ++j) { | ||
| 319 | + { | ||
| 320 | + std::array<int64_t, 3> s{left_context_len_[i], 1, key_dim}; | ||
| 321 | + auto v = | ||
| 322 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 323 | + Fill(&v, 0); | ||
| 324 | + initial_states_.push_back(std::move(v)); | ||
| 325 | + } | ||
| 326 | + | ||
| 327 | + { | ||
| 328 | + std::array<int64_t, 4> s{1, 1, left_context_len_[i], | ||
| 329 | + nonlin_attn_head_dim}; | ||
| 330 | + auto v = | ||
| 331 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 332 | + Fill(&v, 0); | ||
| 333 | + initial_states_.push_back(std::move(v)); | ||
| 334 | + } | ||
| 335 | + | ||
| 336 | + { | ||
| 337 | + std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim}; | ||
| 338 | + auto v = | ||
| 339 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 340 | + Fill(&v, 0); | ||
| 341 | + initial_states_.push_back(std::move(v)); | ||
| 342 | + } | ||
| 343 | + | ||
| 344 | + { | ||
| 345 | + std::array<int64_t, 3> s{left_context_len_[i], 1, value_dim}; | ||
| 346 | + auto v = | ||
| 347 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 348 | + Fill(&v, 0); | ||
| 349 | + initial_states_.push_back(std::move(v)); | ||
| 350 | + } | ||
| 351 | + | ||
| 352 | + { | ||
| 353 | + std::array<int64_t, 3> s{1, encoder_dims_[i], | ||
| 354 | + cnn_module_kernels_[i] / 2}; | ||
| 355 | + auto v = | ||
| 356 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 357 | + Fill(&v, 0); | ||
| 358 | + initial_states_.push_back(std::move(v)); | ||
| 359 | + } | ||
| 360 | + | ||
| 361 | + { | ||
| 362 | + std::array<int64_t, 3> s{1, encoder_dims_[i], | ||
| 363 | + cnn_module_kernels_[i] / 2}; | ||
| 364 | + auto v = | ||
| 365 | + Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 366 | + Fill(&v, 0); | ||
| 367 | + initial_states_.push_back(std::move(v)); | ||
| 368 | + } | ||
| 369 | + } | ||
| 370 | + } | ||
| 371 | + | ||
| 372 | + { | ||
| 373 | + std::array<int64_t, 4> s{1, 128, 3, 19}; | ||
| 374 | + auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | ||
| 375 | + Fill(&v, 0); | ||
| 376 | + initial_states_.push_back(std::move(v)); | ||
| 377 | + } | ||
| 378 | + | ||
| 379 | + { | ||
| 380 | + std::array<int64_t, 1> s{1}; | ||
| 381 | + auto v = | ||
| 382 | + Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size()); | ||
| 383 | + Fill<int64_t>(&v, 0); | ||
| 384 | + initial_states_.push_back(std::move(v)); | ||
| 385 | + } | ||
| 386 | + } | ||
| 387 | + | ||
| 388 | + private: | ||
| 389 | + OnlineModelConfig config_; | ||
| 390 | + Ort::Env env_; | ||
| 391 | + Ort::SessionOptions sess_opts_; | ||
| 392 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 393 | + | ||
| 394 | + std::unique_ptr<Ort::Session> sess_; | ||
| 395 | + | ||
| 396 | + std::vector<std::string> input_names_; | ||
| 397 | + std::vector<const char *> input_names_ptr_; | ||
| 398 | + | ||
| 399 | + std::vector<std::string> output_names_; | ||
| 400 | + std::vector<const char *> output_names_ptr_; | ||
| 401 | + | ||
| 402 | + std::vector<Ort::Value> initial_states_; | ||
| 403 | + | ||
| 404 | + std::vector<int32_t> encoder_dims_; | ||
| 405 | + std::vector<int32_t> query_head_dims_; | ||
| 406 | + std::vector<int32_t> value_head_dims_; | ||
| 407 | + std::vector<int32_t> num_heads_; | ||
| 408 | + std::vector<int32_t> num_encoder_layers_; | ||
| 409 | + std::vector<int32_t> cnn_module_kernels_; | ||
| 410 | + std::vector<int32_t> left_context_len_; | ||
| 411 | + | ||
| 412 | + int32_t T_ = 0; | ||
| 413 | + int32_t decode_chunk_len_ = 0; | ||
| 414 | + int32_t vocab_size_ = 0; | ||
| 415 | +}; | ||
| 416 | + | ||
| 417 | +OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( | ||
| 418 | + const OnlineModelConfig &config) | ||
| 419 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 420 | + | ||
| 421 | +#if __ANDROID_API__ >= 9 | ||
| 422 | +OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( | ||
| 423 | + AAssetManager *mgr, const OnlineModelConfig &config) | ||
| 424 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 425 | +#endif | ||
| 426 | + | ||
| 427 | +OnlineZipformer2CtcModel::~OnlineZipformer2CtcModel() = default; | ||
| 428 | + | ||
| 429 | +std::vector<Ort::Value> OnlineZipformer2CtcModel::Forward( | ||
| 430 | + Ort::Value x, std::vector<Ort::Value> states) const { | ||
| 431 | + return impl_->Forward(std::move(x), std::move(states)); | ||
| 432 | +} | ||
| 433 | + | ||
| 434 | +int32_t OnlineZipformer2CtcModel::VocabSize() const { | ||
| 435 | + return impl_->VocabSize(); | ||
| 436 | +} | ||
| 437 | + | ||
| 438 | +int32_t OnlineZipformer2CtcModel::ChunkLength() const { | ||
| 439 | + return impl_->ChunkLength(); | ||
| 440 | +} | ||
| 441 | + | ||
| 442 | +int32_t OnlineZipformer2CtcModel::ChunkShift() const { | ||
| 443 | + return impl_->ChunkShift(); | ||
| 444 | +} | ||
| 445 | + | ||
| 446 | +OrtAllocator *OnlineZipformer2CtcModel::Allocator() const { | ||
| 447 | + return impl_->Allocator(); | ||
| 448 | +} | ||
| 449 | + | ||
| 450 | +std::vector<Ort::Value> OnlineZipformer2CtcModel::GetInitStates() const { | ||
| 451 | + return impl_->GetInitStates(); | ||
| 452 | +} | ||
| 453 | + | ||
| 454 | +std::vector<Ort::Value> OnlineZipformer2CtcModel::StackStates( | ||
| 455 | + std::vector<std::vector<Ort::Value>> states) const { | ||
| 456 | + return impl_->StackStates(std::move(states)); | ||
| 457 | +} | ||
| 458 | + | ||
| 459 | +std::vector<std::vector<Ort::Value>> OnlineZipformer2CtcModel::UnStackStates( | ||
| 460 | + std::vector<Ort::Value> states) const { | ||
| 461 | + return impl_->UnStackStates(std::move(states)); | ||
| 462 | +} | ||
| 463 | + | ||
| 464 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/online-zipformer2-ctc-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 17 | +#include "sherpa-onnx/csrc/online-ctc-model.h" | ||
| 18 | +#include "sherpa-onnx/csrc/online-model-config.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +class OnlineZipformer2CtcModel : public OnlineCtcModel { | ||
| 23 | + public: | ||
| 24 | + explicit OnlineZipformer2CtcModel(const OnlineModelConfig &config); | ||
| 25 | + | ||
| 26 | +#if __ANDROID_API__ >= 9 | ||
| 27 | + OnlineZipformer2CtcModel(AAssetManager *mgr, const OnlineModelConfig &config); | ||
| 28 | +#endif | ||
| 29 | + | ||
| 30 | + ~OnlineZipformer2CtcModel() override; | ||
| 31 | + | ||
| 32 | + // Return the initial states of the model as a list of tensors. | ||
| 33 | + // See also | ||
| 34 | + // https://github.com/k2-fsa/icefall/pull/1413 | ||
| 35 | + // and | ||
| 36 | + // https://github.com/k2-fsa/icefall/pull/1415 | ||
| 37 | + std::vector<Ort::Value> GetInitStates() const override; | ||
| 38 | + | ||
| 39 | + std::vector<Ort::Value> StackStates( | ||
| 40 | + std::vector<std::vector<Ort::Value>> states) const override; | ||
| 41 | + | ||
| 42 | + std::vector<std::vector<Ort::Value>> UnStackStates( | ||
| 43 | + std::vector<Ort::Value> states) const override; | ||
| 44 | + | ||
| 45 | + /** | ||
| 46 | + * | ||
| 47 | + * @param x A 3-D tensor of shape (N, T, C), where N is the batch size. | ||
| 48 | + * @param states It is from GetInitStates() or returned from this method. | ||
| 49 | + * | ||
| 50 | + * @return Return a list of tensors | ||
| 51 | + * - ans[0] contains log_probs, of shape (N, T, C) | ||
| 52 | + * - ans[1:] contains next_states | ||
| 53 | + */ | ||
| 54 | + std::vector<Ort::Value> Forward( | ||
| 55 | + Ort::Value x, std::vector<Ort::Value> states) const override; | ||
| 56 | + | ||
| 57 | + /** Return the vocabulary size of the model | ||
| 58 | + */ | ||
| 59 | + int32_t VocabSize() const override; | ||
| 60 | + | ||
| 61 | + /** Return an allocator for allocating memory | ||
| 62 | + */ | ||
| 63 | + OrtAllocator *Allocator() const override; | ||
| 64 | + | ||
| 65 | + // The model accepts this number of frames before subsampling as input | ||
| 66 | + int32_t ChunkLength() const override; | ||
| 67 | + | ||
| 68 | + // Similar to frame_shift in feature extractor, after processing | ||
| 69 | + // ChunkLength() frames, we advance by ChunkShift() frames | ||
| 70 | + // before we process the next chunk. | ||
| 71 | + int32_t ChunkShift() const override; | ||
| 72 | + | ||
| 73 | + private: | ||
| 74 | + class Impl; | ||
| 75 | + std::unique_ptr<Impl> impl_; | ||
| 76 | +}; | ||
| 77 | + | ||
| 78 | +} // namespace sherpa_onnx | ||
| 79 | + | ||
| 80 | +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_ |
| @@ -26,6 +26,8 @@ int main(int32_t argc, char *argv[]) { | @@ -26,6 +26,8 @@ int main(int32_t argc, char *argv[]) { | ||
| 26 | const char *kUsageMessage = R"usage( | 26 | const char *kUsageMessage = R"usage( |
| 27 | Usage: | 27 | Usage: |
| 28 | 28 | ||
| 29 | +(1) Streaming transducer | ||
| 30 | + | ||
| 29 | ./bin/sherpa-onnx \ | 31 | ./bin/sherpa-onnx \ |
| 30 | --tokens=/path/to/tokens.txt \ | 32 | --tokens=/path/to/tokens.txt \ |
| 31 | --encoder=/path/to/encoder.onnx \ | 33 | --encoder=/path/to/encoder.onnx \ |
| @@ -36,6 +38,30 @@ Usage: | @@ -36,6 +38,30 @@ Usage: | ||
| 36 | --decoding-method=greedy_search \ | 38 | --decoding-method=greedy_search \ |
| 37 | /path/to/foo.wav [bar.wav foobar.wav ...] | 39 | /path/to/foo.wav [bar.wav foobar.wav ...] |
| 38 | 40 | ||
| 41 | +(2) Streaming zipformer2 CTC | ||
| 42 | + | ||
| 43 | + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 44 | + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 45 | + | ||
| 46 | + ./bin/sherpa-onnx \ | ||
| 47 | + --debug=1 \ | ||
| 48 | + --zipformer2-ctc-model=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ | ||
| 49 | + --tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ | ||
| 50 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \ | ||
| 51 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \ | ||
| 52 | + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav | ||
| 53 | + | ||
| 54 | +(3) Streaming paraformer | ||
| 55 | + | ||
| 56 | + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | ||
| 57 | + tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | ||
| 58 | + | ||
| 59 | + ./bin/sherpa-onnx \ | ||
| 60 | + --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ | ||
| 61 | + --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \ | ||
| 62 | + --paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \ | ||
| 63 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav | ||
| 64 | + | ||
| 39 | Note: It supports decoding multiple files in batches | 65 | Note: It supports decoding multiple files in batches |
| 40 | 66 | ||
| 41 | Default value for num_threads is 2. | 67 | Default value for num_threads is 2. |
| @@ -8,9 +8,6 @@ | @@ -8,9 +8,6 @@ | ||
| 8 | #include <fstream> | 8 | #include <fstream> |
| 9 | #include <sstream> | 9 | #include <sstream> |
| 10 | 10 | ||
| 11 | -#include "sherpa-onnx/csrc/base64-decode.h" | ||
| 12 | -#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 13 | - | ||
| 14 | #if __ANDROID_API__ >= 9 | 11 | #if __ANDROID_API__ >= 9 |
| 15 | #include <strstream> | 12 | #include <strstream> |
| 16 | 13 | ||
| @@ -18,6 +15,9 @@ | @@ -18,6 +15,9 @@ | ||
| 18 | #include "android/asset_manager_jni.h" | 15 | #include "android/asset_manager_jni.h" |
| 19 | #endif | 16 | #endif |
| 20 | 17 | ||
| 18 | +#include "sherpa-onnx/csrc/base64-decode.h" | ||
| 19 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 20 | + | ||
| 21 | namespace sherpa_onnx { | 21 | namespace sherpa_onnx { |
| 22 | 22 | ||
| 23 | SymbolTable::SymbolTable(const std::string &filename) { | 23 | SymbolTable::SymbolTable(const std::string &filename) { |
| @@ -262,22 +262,34 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | @@ -262,22 +262,34 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | ||
| 262 | fid = env->GetFieldID(model_config_cls, "paraformer", | 262 | fid = env->GetFieldID(model_config_cls, "paraformer", |
| 263 | "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;"); | 263 | "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;"); |
| 264 | jobject paraformer_config = env->GetObjectField(model_config, fid); | 264 | jobject paraformer_config = env->GetObjectField(model_config, fid); |
| 265 | - jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config); | 265 | + jclass paraformer_config_cls = env->GetObjectClass(paraformer_config); |
| 266 | 266 | ||
| 267 | - fid = env->GetFieldID(paraformer_config_config_cls, "encoder", | ||
| 268 | - "Ljava/lang/String;"); | 267 | + fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;"); |
| 269 | s = (jstring)env->GetObjectField(paraformer_config, fid); | 268 | s = (jstring)env->GetObjectField(paraformer_config, fid); |
| 270 | p = env->GetStringUTFChars(s, nullptr); | 269 | p = env->GetStringUTFChars(s, nullptr); |
| 271 | ans.model_config.paraformer.encoder = p; | 270 | ans.model_config.paraformer.encoder = p; |
| 272 | env->ReleaseStringUTFChars(s, p); | 271 | env->ReleaseStringUTFChars(s, p); |
| 273 | 272 | ||
| 274 | - fid = env->GetFieldID(paraformer_config_config_cls, "decoder", | ||
| 275 | - "Ljava/lang/String;"); | 273 | + fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;"); |
| 276 | s = (jstring)env->GetObjectField(paraformer_config, fid); | 274 | s = (jstring)env->GetObjectField(paraformer_config, fid); |
| 277 | p = env->GetStringUTFChars(s, nullptr); | 275 | p = env->GetStringUTFChars(s, nullptr); |
| 278 | ans.model_config.paraformer.decoder = p; | 276 | ans.model_config.paraformer.decoder = p; |
| 279 | env->ReleaseStringUTFChars(s, p); | 277 | env->ReleaseStringUTFChars(s, p); |
| 280 | 278 | ||
| 279 | + // streaming zipformer2 CTC | ||
| 280 | + fid = | ||
| 281 | + env->GetFieldID(model_config_cls, "zipformer2Ctc", | ||
| 282 | + "Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;"); | ||
| 283 | + jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid); | ||
| 284 | + jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config); | ||
| 285 | + | ||
| 286 | + fid = | ||
| 287 | + env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;"); | ||
| 288 | + s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid); | ||
| 289 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 290 | + ans.model_config.zipformer2_ctc.model = p; | ||
| 291 | + env->ReleaseStringUTFChars(s, p); | ||
| 292 | + | ||
| 281 | fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); | 293 | fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); |
| 282 | s = (jstring)env->GetObjectField(model_config, fid); | 294 | s = (jstring)env->GetObjectField(model_config, fid); |
| 283 | p = env->GetStringUTFChars(s, nullptr); | 295 | p = env->GetStringUTFChars(s, nullptr); |
| @@ -27,6 +27,7 @@ pybind11_add_module(_sherpa_onnx | @@ -27,6 +27,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 27 | online-stream.cc | 27 | online-stream.cc |
| 28 | online-transducer-model-config.cc | 28 | online-transducer-model-config.cc |
| 29 | online-wenet-ctc-model-config.cc | 29 | online-wenet-ctc-model-config.cc |
| 30 | + online-zipformer2-ctc-model-config.cc | ||
| 30 | sherpa-onnx.cc | 31 | sherpa-onnx.cc |
| 31 | silero-vad-model-config.cc | 32 | silero-vad-model-config.cc |
| 32 | vad-model-config.cc | 33 | vad-model-config.cc |
| @@ -58,6 +58,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -58,6 +58,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 58 | .def_readwrite("debug", &PyClass::debug) | 58 | .def_readwrite("debug", &PyClass::debug) |
| 59 | .def_readwrite("provider", &PyClass::provider) | 59 | .def_readwrite("provider", &PyClass::provider) |
| 60 | .def_readwrite("model_type", &PyClass::model_type) | 60 | .def_readwrite("model_type", &PyClass::model_type) |
| 61 | + .def("validate", &PyClass::Validate) | ||
| 61 | .def("__str__", &PyClass::ToString); | 62 | .def("__str__", &PyClass::ToString); |
| 62 | } | 63 | } |
| 63 | 64 |
| @@ -12,6 +12,7 @@ | @@ -12,6 +12,7 @@ | ||
| 12 | #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" | 12 | #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" |
| 13 | #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" |
| 14 | #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" | 14 | #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" |
| 15 | +#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h" | ||
| 15 | 16 | ||
| 16 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| 17 | 18 | ||
| @@ -19,26 +20,31 @@ void PybindOnlineModelConfig(py::module *m) { | @@ -19,26 +20,31 @@ void PybindOnlineModelConfig(py::module *m) { | ||
| 19 | PybindOnlineTransducerModelConfig(m); | 20 | PybindOnlineTransducerModelConfig(m); |
| 20 | PybindOnlineParaformerModelConfig(m); | 21 | PybindOnlineParaformerModelConfig(m); |
| 21 | PybindOnlineWenetCtcModelConfig(m); | 22 | PybindOnlineWenetCtcModelConfig(m); |
| 23 | + PybindOnlineZipformer2CtcModelConfig(m); | ||
| 22 | 24 | ||
| 23 | using PyClass = OnlineModelConfig; | 25 | using PyClass = OnlineModelConfig; |
| 24 | py::class_<PyClass>(*m, "OnlineModelConfig") | 26 | py::class_<PyClass>(*m, "OnlineModelConfig") |
| 25 | .def(py::init<const OnlineTransducerModelConfig &, | 27 | .def(py::init<const OnlineTransducerModelConfig &, |
| 26 | const OnlineParaformerModelConfig &, | 28 | const OnlineParaformerModelConfig &, |
| 27 | - const OnlineWenetCtcModelConfig &, const std::string &, | 29 | + const OnlineWenetCtcModelConfig &, |
| 30 | + const OnlineZipformer2CtcModelConfig &, const std::string &, | ||
| 28 | int32_t, bool, const std::string &, const std::string &>(), | 31 | int32_t, bool, const std::string &, const std::string &>(), |
| 29 | py::arg("transducer") = OnlineTransducerModelConfig(), | 32 | py::arg("transducer") = OnlineTransducerModelConfig(), |
| 30 | py::arg("paraformer") = OnlineParaformerModelConfig(), | 33 | py::arg("paraformer") = OnlineParaformerModelConfig(), |
| 31 | py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), | 34 | py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), |
| 35 | + py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(), | ||
| 32 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, | 36 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, |
| 33 | py::arg("provider") = "cpu", py::arg("model_type") = "") | 37 | py::arg("provider") = "cpu", py::arg("model_type") = "") |
| 34 | .def_readwrite("transducer", &PyClass::transducer) | 38 | .def_readwrite("transducer", &PyClass::transducer) |
| 35 | .def_readwrite("paraformer", &PyClass::paraformer) | 39 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 36 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) | 40 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) |
| 41 | + .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc) | ||
| 37 | .def_readwrite("tokens", &PyClass::tokens) | 42 | .def_readwrite("tokens", &PyClass::tokens) |
| 38 | .def_readwrite("num_threads", &PyClass::num_threads) | 43 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 39 | .def_readwrite("debug", &PyClass::debug) | 44 | .def_readwrite("debug", &PyClass::debug) |
| 40 | .def_readwrite("provider", &PyClass::provider) | 45 | .def_readwrite("provider", &PyClass::provider) |
| 41 | .def_readwrite("model_type", &PyClass::model_type) | 46 | .def_readwrite("model_type", &PyClass::model_type) |
| 47 | + .def("validate", &PyClass::Validate) | ||
| 42 | .def("__str__", &PyClass::ToString); | 48 | .def("__str__", &PyClass::ToString); |
| 43 | } | 49 | } |
| 44 | 50 |
| 1 | +// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindOnlineZipformer2CtcModelConfig(py::module *m) { | ||
| 15 | + using PyClass = OnlineZipformer2CtcModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "OnlineZipformer2CtcModelConfig") | ||
| 17 | + .def(py::init<const std::string &>(), py::arg("model")) | ||
| 18 | + .def_readwrite("model", &PyClass::model) | ||
| 19 | + .def("__str__", &PyClass::ToString); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOnlineZipformer2CtcModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ |
| @@ -8,11 +8,14 @@ from _sherpa_onnx import ( | @@ -8,11 +8,14 @@ from _sherpa_onnx import ( | ||
| 8 | OnlineLMConfig, | 8 | OnlineLMConfig, |
| 9 | OnlineModelConfig, | 9 | OnlineModelConfig, |
| 10 | OnlineParaformerModelConfig, | 10 | OnlineParaformerModelConfig, |
| 11 | - OnlineRecognizer as _Recognizer, | 11 | +) |
| 12 | +from _sherpa_onnx import OnlineRecognizer as _Recognizer | ||
| 13 | +from _sherpa_onnx import ( | ||
| 12 | OnlineRecognizerConfig, | 14 | OnlineRecognizerConfig, |
| 13 | OnlineStream, | 15 | OnlineStream, |
| 14 | OnlineTransducerModelConfig, | 16 | OnlineTransducerModelConfig, |
| 15 | OnlineWenetCtcModelConfig, | 17 | OnlineWenetCtcModelConfig, |
| 18 | + OnlineZipformer2CtcModelConfig, | ||
| 16 | ) | 19 | ) |
| 17 | 20 | ||
| 18 | 21 | ||
| @@ -273,6 +276,101 @@ class OnlineRecognizer(object): | @@ -273,6 +276,101 @@ class OnlineRecognizer(object): | ||
| 273 | return self | 276 | return self |
| 274 | 277 | ||
| 275 | @classmethod | 278 | @classmethod |
| 279 | + def from_zipformer2_ctc( | ||
| 280 | + cls, | ||
| 281 | + tokens: str, | ||
| 282 | + model: str, | ||
| 283 | + num_threads: int = 2, | ||
| 284 | + sample_rate: float = 16000, | ||
| 285 | + feature_dim: int = 80, | ||
| 286 | + enable_endpoint_detection: bool = False, | ||
| 287 | + rule1_min_trailing_silence: float = 2.4, | ||
| 288 | + rule2_min_trailing_silence: float = 1.2, | ||
| 289 | + rule3_min_utterance_length: float = 20.0, | ||
| 290 | + decoding_method: str = "greedy_search", | ||
| 291 | + provider: str = "cpu", | ||
| 292 | + ): | ||
| 293 | + """ | ||
| 294 | + Please refer to | ||
| 295 | + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/index.html>`_ | ||
| 296 | + to download pre-trained models for different languages, e.g., Chinese, | ||
| 297 | + English, etc. | ||
| 298 | + | ||
| 299 | + Args: | ||
| 300 | + tokens: | ||
| 301 | + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | ||
| 302 | + columns:: | ||
| 303 | + | ||
| 304 | + symbol integer_id | ||
| 305 | + | ||
| 306 | + model: | ||
| 307 | + Path to ``model.onnx``. | ||
| 308 | + num_threads: | ||
| 309 | + Number of threads for neural network computation. | ||
| 310 | + sample_rate: | ||
| 311 | + Sample rate of the training data used to train the model. | ||
| 312 | + feature_dim: | ||
| 313 | + Dimension of the feature used to train the model. | ||
| 314 | + enable_endpoint_detection: | ||
| 315 | + True to enable endpoint detection. False to disable endpoint | ||
| 316 | + detection. | ||
| 317 | + rule1_min_trailing_silence: | ||
| 318 | + Used only when enable_endpoint_detection is True. If the duration | ||
| 319 | + of trailing silence in seconds is larger than this value, we assume | ||
| 320 | + an endpoint is detected. | ||
| 321 | + rule2_min_trailing_silence: | ||
| 322 | + Used only when enable_endpoint_detection is True. If we have decoded | ||
| 323 | + something that is nonsilence and if the duration of trailing silence | ||
| 324 | + in seconds is larger than this value, we assume an endpoint is | ||
| 325 | + detected. | ||
| 326 | + rule3_min_utterance_length: | ||
| 327 | + Used only when enable_endpoint_detection is True. If the utterance | ||
| 328 | + length in seconds is larger than this value, we assume an endpoint | ||
| 329 | + is detected. | ||
| 330 | + decoding_method: | ||
| 331 | + The only valid value is greedy_search. | ||
| 332 | + provider: | ||
| 333 | + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | ||
| 334 | + """ | ||
| 335 | + self = cls.__new__(cls) | ||
| 336 | + _assert_file_exists(tokens) | ||
| 337 | + _assert_file_exists(model) | ||
| 338 | + | ||
| 339 | + assert num_threads > 0, num_threads | ||
| 340 | + | ||
| 341 | + zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) | ||
| 342 | + | ||
| 343 | + model_config = OnlineModelConfig( | ||
| 344 | + zipformer2_ctc=zipformer2_ctc_config, | ||
| 345 | + tokens=tokens, | ||
| 346 | + num_threads=num_threads, | ||
| 347 | + provider=provider, | ||
| 348 | + ) | ||
| 349 | + | ||
| 350 | + feat_config = FeatureExtractorConfig( | ||
| 351 | + sampling_rate=sample_rate, | ||
| 352 | + feature_dim=feature_dim, | ||
| 353 | + ) | ||
| 354 | + | ||
| 355 | + endpoint_config = EndpointConfig( | ||
| 356 | + rule1_min_trailing_silence=rule1_min_trailing_silence, | ||
| 357 | + rule2_min_trailing_silence=rule2_min_trailing_silence, | ||
| 358 | + rule3_min_utterance_length=rule3_min_utterance_length, | ||
| 359 | + ) | ||
| 360 | + | ||
| 361 | + recognizer_config = OnlineRecognizerConfig( | ||
| 362 | + feat_config=feat_config, | ||
| 363 | + model_config=model_config, | ||
| 364 | + endpoint_config=endpoint_config, | ||
| 365 | + enable_endpoint=enable_endpoint_detection, | ||
| 366 | + decoding_method=decoding_method, | ||
| 367 | + ) | ||
| 368 | + | ||
| 369 | + self.recognizer = _Recognizer(recognizer_config) | ||
| 370 | + self.config = recognizer_config | ||
| 371 | + return self | ||
| 372 | + | ||
| 373 | + @classmethod | ||
| 276 | def from_wenet_ctc( | 374 | def from_wenet_ctc( |
| 277 | cls, | 375 | cls, |
| 278 | tokens: str, | 376 | tokens: str, |
| @@ -352,7 +450,6 @@ class OnlineRecognizer(object): | @@ -352,7 +450,6 @@ class OnlineRecognizer(object): | ||
| 352 | tokens=tokens, | 450 | tokens=tokens, |
| 353 | num_threads=num_threads, | 451 | num_threads=num_threads, |
| 354 | provider=provider, | 452 | provider=provider, |
| 355 | - model_type="wenet_ctc", | ||
| 356 | ) | 453 | ) |
| 357 | 454 | ||
| 358 | feat_config = FeatureExtractorConfig( | 455 | feat_config = FeatureExtractorConfig( |
| @@ -143,6 +143,57 @@ class TestOnlineRecognizer(unittest.TestCase): | @@ -143,6 +143,57 @@ class TestOnlineRecognizer(unittest.TestCase): | ||
| 143 | print(f"{wave_filename}\n{result}") | 143 | print(f"{wave_filename}\n{result}") |
| 144 | print("-" * 10) | 144 | print("-" * 10) |
| 145 | 145 | ||
| 146 | + def test_zipformer2_ctc(self): | ||
| 147 | + m = "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13" | ||
| 148 | + for use_int8 in [True, False]: | ||
| 149 | + name = ( | ||
| 150 | + "ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx" | ||
| 151 | + if use_int8 | ||
| 152 | + else "ctc-epoch-20-avg-1-chunk-16-left-128.onnx" | ||
| 153 | + ) | ||
| 154 | + model = f"{d}/{m}/{name}" | ||
| 155 | + tokens = f"{d}/{m}/tokens.txt" | ||
| 156 | + wave0 = f"{d}/{m}/test_wavs/DEV_T0000000000.wav" | ||
| 157 | + wave1 = f"{d}/{m}/test_wavs/DEV_T0000000001.wav" | ||
| 158 | + wave2 = f"{d}/{m}/test_wavs/DEV_T0000000002.wav" | ||
| 159 | + if not Path(model).is_file(): | ||
| 160 | + print("skipping test_zipformer2_ctc()") | ||
| 161 | + return | ||
| 162 | + print(f"testing {model}") | ||
| 163 | + | ||
| 164 | + recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( | ||
| 165 | + model=model, | ||
| 166 | + tokens=tokens, | ||
| 167 | + num_threads=1, | ||
| 168 | + provider="cpu", | ||
| 169 | + ) | ||
| 170 | + | ||
| 171 | + streams = [] | ||
| 172 | + waves = [wave0, wave1, wave2] | ||
| 173 | + for wave in waves: | ||
| 174 | + s = recognizer.create_stream() | ||
| 175 | + samples, sample_rate = read_wave(wave) | ||
| 176 | + s.accept_waveform(sample_rate, samples) | ||
| 177 | + | ||
| 178 | + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) | ||
| 179 | + s.accept_waveform(sample_rate, tail_paddings) | ||
| 180 | + s.input_finished() | ||
| 181 | + streams.append(s) | ||
| 182 | + | ||
| 183 | + while True: | ||
| 184 | + ready_list = [] | ||
| 185 | + for s in streams: | ||
| 186 | + if recognizer.is_ready(s): | ||
| 187 | + ready_list.append(s) | ||
| 188 | + if len(ready_list) == 0: | ||
| 189 | + break | ||
| 190 | + recognizer.decode_streams(ready_list) | ||
| 191 | + | ||
| 192 | + results = [recognizer.get_result(s) for s in streams] | ||
| 193 | + for wave_filename, result in zip(waves, results): | ||
| 194 | + print(f"{wave_filename}\n{result}") | ||
| 195 | + print("-" * 10) | ||
| 196 | + | ||
| 146 | def test_wenet_ctc(self): | 197 | def test_wenet_ctc(self): |
| 147 | models = [ | 198 | models = [ |
| 148 | "sherpa-onnx-zh-wenet-aishell", | 199 | "sherpa-onnx-zh-wenet-aishell", |
| @@ -60,6 +60,14 @@ func sherpaOnnxOnlineParaformerModelConfig( | @@ -60,6 +60,14 @@ func sherpaOnnxOnlineParaformerModelConfig( | ||
| 60 | ) | 60 | ) |
| 61 | } | 61 | } |
| 62 | 62 | ||
| 63 | +func sherpaOnnxOnlineZipformer2CtcModelConfig( | ||
| 64 | + model: String = "" | ||
| 65 | +) -> SherpaOnnxOnlineZipformer2CtcModelConfig { | ||
| 66 | + return SherpaOnnxOnlineZipformer2CtcModelConfig( | ||
| 67 | + model: toCPointer(model) | ||
| 68 | + ) | ||
| 69 | +} | ||
| 70 | + | ||
| 63 | /// Return an instance of SherpaOnnxOnlineModelConfig. | 71 | /// Return an instance of SherpaOnnxOnlineModelConfig. |
| 64 | /// | 72 | /// |
| 65 | /// Please refer to | 73 | /// Please refer to |
| @@ -75,6 +83,8 @@ func sherpaOnnxOnlineModelConfig( | @@ -75,6 +83,8 @@ func sherpaOnnxOnlineModelConfig( | ||
| 75 | tokens: String, | 83 | tokens: String, |
| 76 | transducer: SherpaOnnxOnlineTransducerModelConfig = sherpaOnnxOnlineTransducerModelConfig(), | 84 | transducer: SherpaOnnxOnlineTransducerModelConfig = sherpaOnnxOnlineTransducerModelConfig(), |
| 77 | paraformer: SherpaOnnxOnlineParaformerModelConfig = sherpaOnnxOnlineParaformerModelConfig(), | 85 | paraformer: SherpaOnnxOnlineParaformerModelConfig = sherpaOnnxOnlineParaformerModelConfig(), |
| 86 | + zipformer2Ctc: SherpaOnnxOnlineZipformer2CtcModelConfig = | ||
| 87 | + sherpaOnnxOnlineZipformer2CtcModelConfig(), | ||
| 78 | numThreads: Int = 1, | 88 | numThreads: Int = 1, |
| 79 | provider: String = "cpu", | 89 | provider: String = "cpu", |
| 80 | debug: Int = 0, | 90 | debug: Int = 0, |
| @@ -83,6 +93,7 @@ func sherpaOnnxOnlineModelConfig( | @@ -83,6 +93,7 @@ func sherpaOnnxOnlineModelConfig( | ||
| 83 | return SherpaOnnxOnlineModelConfig( | 93 | return SherpaOnnxOnlineModelConfig( |
| 84 | transducer: transducer, | 94 | transducer: transducer, |
| 85 | paraformer: paraformer, | 95 | paraformer: paraformer, |
| 96 | + zipformer2_ctc: zipformer2Ctc, | ||
| 86 | tokens: toCPointer(tokens), | 97 | tokens: toCPointer(tokens), |
| 87 | num_threads: Int32(numThreads), | 98 | num_threads: Int32(numThreads), |
| 88 | provider: toCPointer(provider), | 99 | provider: toCPointer(provider), |
| @@ -13,24 +13,47 @@ extension AVAudioPCMBuffer { | @@ -13,24 +13,47 @@ extension AVAudioPCMBuffer { | ||
| 13 | } | 13 | } |
| 14 | 14 | ||
| 15 | func run() { | 15 | func run() { |
| 16 | - let encoder = | ||
| 17 | - "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" | ||
| 18 | - let decoder = | ||
| 19 | - "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" | ||
| 20 | - let joiner = | ||
| 21 | - "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx" | ||
| 22 | - let tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" | ||
| 23 | - | ||
| 24 | - let transducerConfig = sherpaOnnxOnlineTransducerModelConfig( | ||
| 25 | - encoder: encoder, | ||
| 26 | - decoder: decoder, | ||
| 27 | - joiner: joiner | ||
| 28 | - ) | 16 | + var modelConfig: SherpaOnnxOnlineModelConfig |
| 17 | + var modelType = "zipformer2-ctc" | ||
| 18 | + var filePath: String | ||
| 29 | 19 | ||
| 30 | - let modelConfig = sherpaOnnxOnlineModelConfig( | ||
| 31 | - tokens: tokens, | ||
| 32 | - transducer: transducerConfig | ||
| 33 | - ) | 20 | + modelType = "transducer" |
| 21 | + | ||
| 22 | + if modelType == "transducer" { | ||
| 23 | + filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav" | ||
| 24 | + let encoder = | ||
| 25 | + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" | ||
| 26 | + let decoder = | ||
| 27 | + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" | ||
| 28 | + let joiner = | ||
| 29 | + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx" | ||
| 30 | + let tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" | ||
| 31 | + | ||
| 32 | + let transducerConfig = sherpaOnnxOnlineTransducerModelConfig( | ||
| 33 | + encoder: encoder, | ||
| 34 | + decoder: decoder, | ||
| 35 | + joiner: joiner | ||
| 36 | + ) | ||
| 37 | + | ||
| 38 | + modelConfig = sherpaOnnxOnlineModelConfig( | ||
| 39 | + tokens: tokens, | ||
| 40 | + transducer: transducerConfig | ||
| 41 | + ) | ||
| 42 | + } else { | ||
| 43 | + filePath = | ||
| 44 | + "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav" | ||
| 45 | + let model = | ||
| 46 | + "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx" | ||
| 47 | + let tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt" | ||
| 48 | + let zipfomer2CtcModelConfig = sherpaOnnxOnlineZipformer2CtcModelConfig( | ||
| 49 | + model: model | ||
| 50 | + ) | ||
| 51 | + | ||
| 52 | + modelConfig = sherpaOnnxOnlineModelConfig( | ||
| 53 | + tokens: tokens, | ||
| 54 | + zipformer2Ctc: zipfomer2CtcModelConfig | ||
| 55 | + ) | ||
| 56 | + } | ||
| 34 | 57 | ||
| 35 | let featConfig = sherpaOnnxFeatureConfig( | 58 | let featConfig = sherpaOnnxFeatureConfig( |
| 36 | sampleRate: 16000, | 59 | sampleRate: 16000, |
| @@ -43,7 +66,6 @@ func run() { | @@ -43,7 +66,6 @@ func run() { | ||
| 43 | 66 | ||
| 44 | let recognizer = SherpaOnnxRecognizer(config: &config) | 67 | let recognizer = SherpaOnnxRecognizer(config: &config) |
| 45 | 68 | ||
| 46 | - let filePath = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav" | ||
| 47 | let fileURL: NSURL = NSURL(fileURLWithPath: filePath) | 69 | let fileURL: NSURL = NSURL(fileURLWithPath: filePath) |
| 48 | let audioFile = try! AVAudioFile(forReading: fileURL as URL) | 70 | let audioFile = try! AVAudioFile(forReading: fileURL as URL) |
| 49 | 71 |
| @@ -20,6 +20,12 @@ if [ ! -d ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 ]; then | @@ -20,6 +20,12 @@ if [ ! -d ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 ]; then | ||
| 20 | rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 | 20 | rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 |
| 21 | fi | 21 | fi |
| 22 | 22 | ||
| 23 | +if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then | ||
| 24 | + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 25 | + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 26 | + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 27 | +fi | ||
| 28 | + | ||
| 23 | if [ ! -e ./decode-file ]; then | 29 | if [ ! -e ./decode-file ]; then |
| 24 | # Note: We use -lc++ to link against libc++ instead of libstdc++ | 30 | # Note: We use -lc++ to link against libc++ instead of libstdc++ |
| 25 | swiftc \ | 31 | swiftc \ |
| @@ -22,7 +22,7 @@ if [ ! -d ./sherpa-onnx-whisper-tiny.en ]; then | @@ -22,7 +22,7 @@ if [ ! -d ./sherpa-onnx-whisper-tiny.en ]; then | ||
| 22 | fi | 22 | fi |
| 23 | if [ ! -f ./silero_vad.onnx ]; then | 23 | if [ ! -f ./silero_vad.onnx ]; then |
| 24 | echo "downloading silero_vad" | 24 | echo "downloading silero_vad" |
| 25 | - wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx | 25 | + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx |
| 26 | fi | 26 | fi |
| 27 | 27 | ||
| 28 | if [ ! -e ./generate-subtitles ]; then | 28 | if [ ! -e ./generate-subtitles ]; then |
-
请 注册 或 登录 后发表评论