正在显示
38 个修改的文件
包含
1488 行增加
和
112 行删除
.github/scripts/test-online-paraformer.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -e | ||
| 4 | + | ||
| 5 | +log() { | ||
| 6 | + # This function is from espnet | ||
| 7 | + local fname=${BASH_SOURCE[1]##*/} | ||
| 8 | + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | ||
| 9 | +} | ||
| 10 | + | ||
| 11 | +echo "EXE is $EXE" | ||
| 12 | +echo "PATH: $PATH" | ||
| 13 | + | ||
| 14 | +which $EXE | ||
| 15 | + | ||
| 16 | +log "------------------------------------------------------------" | ||
| 17 | +log "Run streaming Paraformer" | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | + | ||
| 20 | +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en | ||
| 21 | +log "Start testing ${repo_url}" | ||
| 22 | +repo=$(basename $repo_url) | ||
| 23 | +log "Download pretrained model and test-data from $repo_url" | ||
| 24 | + | ||
| 25 | +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 26 | +pushd $repo | ||
| 27 | +git lfs pull --include "*.onnx" | ||
| 28 | +ls -lh *.onnx | ||
| 29 | +popd | ||
| 30 | + | ||
| 31 | +time $EXE \ | ||
| 32 | + --tokens=$repo/tokens.txt \ | ||
| 33 | + --paraformer-encoder=$repo/encoder.onnx \ | ||
| 34 | + --paraformer-decoder=$repo/decoder.onnx \ | ||
| 35 | + --num-threads=2 \ | ||
| 36 | + $repo/test_wavs/0.wav \ | ||
| 37 | + $repo/test_wavs/1.wav \ | ||
| 38 | + $repo/test_wavs/2.wav \ | ||
| 39 | + $repo/test_wavs/3.wav \ | ||
| 40 | + $repo/test_wavs/8k.wav | ||
| 41 | + | ||
| 42 | +time $EXE \ | ||
| 43 | + --tokens=$repo/tokens.txt \ | ||
| 44 | + --paraformer-encoder=$repo/encoder.int8.onnx \ | ||
| 45 | + --paraformer-decoder=$repo/decoder.int8.onnx \ | ||
| 46 | + --num-threads=2 \ | ||
| 47 | + $repo/test_wavs/0.wav \ | ||
| 48 | + $repo/test_wavs/1.wav \ | ||
| 49 | + $repo/test_wavs/2.wav \ | ||
| 50 | + $repo/test_wavs/3.wav \ | ||
| 51 | + $repo/test_wavs/8k.wav | ||
| 52 | + | ||
| 53 | +rm -rf $repo |
| @@ -9,6 +9,7 @@ on: | @@ -9,6 +9,7 @@ on: | ||
| 9 | paths: | 9 | paths: |
| 10 | - '.github/workflows/linux-gpu.yaml' | 10 | - '.github/workflows/linux-gpu.yaml' |
| 11 | - '.github/scripts/test-online-transducer.sh' | 11 | - '.github/scripts/test-online-transducer.sh' |
| 12 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 12 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 13 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 14 | - 'CMakeLists.txt' | 15 | - 'CMakeLists.txt' |
| @@ -22,6 +23,7 @@ on: | @@ -22,6 +23,7 @@ on: | ||
| 22 | paths: | 23 | paths: |
| 23 | - '.github/workflows/linux-gpu.yaml' | 24 | - '.github/workflows/linux-gpu.yaml' |
| 24 | - '.github/scripts/test-online-transducer.sh' | 25 | - '.github/scripts/test-online-transducer.sh' |
| 26 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 25 | - '.github/scripts/test-offline-transducer.sh' | 27 | - '.github/scripts/test-offline-transducer.sh' |
| 26 | - '.github/scripts/test-offline-ctc.sh' | 28 | - '.github/scripts/test-offline-ctc.sh' |
| 27 | - 'CMakeLists.txt' | 29 | - 'CMakeLists.txt' |
| @@ -85,6 +87,14 @@ jobs: | @@ -85,6 +87,14 @@ jobs: | ||
| 85 | file build/bin/sherpa-onnx | 87 | file build/bin/sherpa-onnx |
| 86 | readelf -d build/bin/sherpa-onnx | 88 | readelf -d build/bin/sherpa-onnx |
| 87 | 89 | ||
| 90 | + - name: Test online paraformer | ||
| 91 | + shell: bash | ||
| 92 | + run: | | ||
| 93 | + export PATH=$PWD/build/bin:$PATH | ||
| 94 | + export EXE=sherpa-onnx | ||
| 95 | + | ||
| 96 | + .github/scripts/test-online-paraformer.sh | ||
| 97 | + | ||
| 88 | - name: Test offline Whisper | 98 | - name: Test offline Whisper |
| 89 | shell: bash | 99 | shell: bash |
| 90 | run: | | 100 | run: | |
| @@ -9,6 +9,7 @@ on: | @@ -9,6 +9,7 @@ on: | ||
| 9 | paths: | 9 | paths: |
| 10 | - '.github/workflows/linux.yaml' | 10 | - '.github/workflows/linux.yaml' |
| 11 | - '.github/scripts/test-online-transducer.sh' | 11 | - '.github/scripts/test-online-transducer.sh' |
| 12 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 12 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 13 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 14 | - 'CMakeLists.txt' | 15 | - 'CMakeLists.txt' |
| @@ -22,6 +23,7 @@ on: | @@ -22,6 +23,7 @@ on: | ||
| 22 | paths: | 23 | paths: |
| 23 | - '.github/workflows/linux.yaml' | 24 | - '.github/workflows/linux.yaml' |
| 24 | - '.github/scripts/test-online-transducer.sh' | 25 | - '.github/scripts/test-online-transducer.sh' |
| 26 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 25 | - '.github/scripts/test-offline-transducer.sh' | 27 | - '.github/scripts/test-offline-transducer.sh' |
| 26 | - '.github/scripts/test-offline-ctc.sh' | 28 | - '.github/scripts/test-offline-ctc.sh' |
| 27 | - 'CMakeLists.txt' | 29 | - 'CMakeLists.txt' |
| @@ -84,6 +86,14 @@ jobs: | @@ -84,6 +86,14 @@ jobs: | ||
| 84 | file build/bin/sherpa-onnx | 86 | file build/bin/sherpa-onnx |
| 85 | readelf -d build/bin/sherpa-onnx | 87 | readelf -d build/bin/sherpa-onnx |
| 86 | 88 | ||
| 89 | + - name: Test online paraformer | ||
| 90 | + shell: bash | ||
| 91 | + run: | | ||
| 92 | + export PATH=$PWD/build/bin:$PATH | ||
| 93 | + export EXE=sherpa-onnx | ||
| 94 | + | ||
| 95 | + .github/scripts/test-online-paraformer.sh | ||
| 96 | + | ||
| 87 | - name: Test offline Whisper | 97 | - name: Test offline Whisper |
| 88 | shell: bash | 98 | shell: bash |
| 89 | run: | | 99 | run: | |
| @@ -7,6 +7,7 @@ on: | @@ -7,6 +7,7 @@ on: | ||
| 7 | paths: | 7 | paths: |
| 8 | - '.github/workflows/macos.yaml' | 8 | - '.github/workflows/macos.yaml' |
| 9 | - '.github/scripts/test-online-transducer.sh' | 9 | - '.github/scripts/test-online-transducer.sh' |
| 10 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 10 | - '.github/scripts/test-offline-transducer.sh' | 11 | - '.github/scripts/test-offline-transducer.sh' |
| 11 | - '.github/scripts/test-offline-ctc.sh' | 12 | - '.github/scripts/test-offline-ctc.sh' |
| 12 | - 'CMakeLists.txt' | 13 | - 'CMakeLists.txt' |
| @@ -18,6 +19,7 @@ on: | @@ -18,6 +19,7 @@ on: | ||
| 18 | paths: | 19 | paths: |
| 19 | - '.github/workflows/macos.yaml' | 20 | - '.github/workflows/macos.yaml' |
| 20 | - '.github/scripts/test-online-transducer.sh' | 21 | - '.github/scripts/test-online-transducer.sh' |
| 22 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 21 | - '.github/scripts/test-offline-transducer.sh' | 23 | - '.github/scripts/test-offline-transducer.sh' |
| 22 | - '.github/scripts/test-offline-ctc.sh' | 24 | - '.github/scripts/test-offline-ctc.sh' |
| 23 | - 'CMakeLists.txt' | 25 | - 'CMakeLists.txt' |
| @@ -82,6 +84,14 @@ jobs: | @@ -82,6 +84,14 @@ jobs: | ||
| 82 | otool -L build/bin/sherpa-onnx | 84 | otool -L build/bin/sherpa-onnx |
| 83 | otool -l build/bin/sherpa-onnx | 85 | otool -l build/bin/sherpa-onnx |
| 84 | 86 | ||
| 87 | + - name: Test online paraformer | ||
| 88 | + shell: bash | ||
| 89 | + run: | | ||
| 90 | + export PATH=$PWD/build/bin:$PATH | ||
| 91 | + export EXE=sherpa-onnx | ||
| 92 | + | ||
| 93 | + .github/scripts/test-online-paraformer.sh | ||
| 94 | + | ||
| 85 | - name: Test offline Whisper | 95 | - name: Test offline Whisper |
| 86 | shell: bash | 96 | shell: bash |
| 87 | run: | | 97 | run: | |
| @@ -58,7 +58,6 @@ jobs: | @@ -58,7 +58,6 @@ jobs: | ||
| 58 | sherpa-onnx-microphone-offline --help | 58 | sherpa-onnx-microphone-offline --help |
| 59 | 59 | ||
| 60 | sherpa-onnx-offline-websocket-server --help | 60 | sherpa-onnx-offline-websocket-server --help |
| 61 | - sherpa-onnx-offline-websocket-client --help | ||
| 62 | 61 | ||
| 63 | sherpa-onnx-online-websocket-server --help | 62 | sherpa-onnx-online-websocket-server --help |
| 64 | sherpa-onnx-online-websocket-client --help | 63 | sherpa-onnx-online-websocket-client --help |
| @@ -84,14 +84,14 @@ jobs: | @@ -84,14 +84,14 @@ jobs: | ||
| 84 | if: matrix.model_type == 'paraformer' | 84 | if: matrix.model_type == 'paraformer' |
| 85 | shell: bash | 85 | shell: bash |
| 86 | run: | | 86 | run: | |
| 87 | - GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 | ||
| 88 | - cd sherpa-onnx-paraformer-zh-2023-03-28 | 87 | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en |
| 88 | + cd sherpa-onnx-paraformer-bilingual-zh-en | ||
| 89 | git lfs pull --include "*.onnx" | 89 | git lfs pull --include "*.onnx" |
| 90 | cd .. | 90 | cd .. |
| 91 | 91 | ||
| 92 | python3 ./python-api-examples/non_streaming_server.py \ | 92 | python3 ./python-api-examples/non_streaming_server.py \ |
| 93 | - --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \ | ||
| 94 | - --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt & | 93 | + --paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \ |
| 94 | + --tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt & | ||
| 95 | 95 | ||
| 96 | echo "sleep 10 seconds to wait the server start" | 96 | echo "sleep 10 seconds to wait the server start" |
| 97 | sleep 10 | 97 | sleep 10 |
| @@ -101,16 +101,16 @@ jobs: | @@ -101,16 +101,16 @@ jobs: | ||
| 101 | shell: bash | 101 | shell: bash |
| 102 | run: | | 102 | run: | |
| 103 | python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ | 103 | python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \ |
| 104 | - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \ | ||
| 105 | - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \ | ||
| 106 | - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \ | ||
| 107 | - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav | 104 | + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \ |
| 105 | + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \ | ||
| 106 | + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \ | ||
| 107 | + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav | ||
| 108 | 108 | ||
| 109 | python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \ | 109 | python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \ |
| 110 | - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \ | ||
| 111 | - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \ | ||
| 112 | - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \ | ||
| 113 | - ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav | 110 | + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \ |
| 111 | + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \ | ||
| 112 | + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \ | ||
| 113 | + ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav | ||
| 114 | 114 | ||
| 115 | - name: Start server for nemo_ctc models | 115 | - name: Start server for nemo_ctc models |
| 116 | if: matrix.model_type == 'nemo_ctc' | 116 | if: matrix.model_type == 'nemo_ctc' |
| @@ -24,7 +24,7 @@ jobs: | @@ -24,7 +24,7 @@ jobs: | ||
| 24 | matrix: | 24 | matrix: |
| 25 | os: [ubuntu-latest, windows-latest, macos-latest] | 25 | os: [ubuntu-latest, windows-latest, macos-latest] |
| 26 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] | 26 | python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] |
| 27 | - model_type: ["transducer"] | 27 | + model_type: ["transducer", "paraformer"] |
| 28 | 28 | ||
| 29 | steps: | 29 | steps: |
| 30 | - uses: actions/checkout@v2 | 30 | - uses: actions/checkout@v2 |
| @@ -71,3 +71,36 @@ jobs: | @@ -71,3 +71,36 @@ jobs: | ||
| 71 | run: | | 71 | run: | |
| 72 | python3 ./python-api-examples/online-websocket-client-decode-file.py \ | 72 | python3 ./python-api-examples/online-websocket-client-decode-file.py \ |
| 73 | ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav | 73 | ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav |
| 74 | + | ||
| 75 | + - name: Start server for paraformer models | ||
| 76 | + if: matrix.model_type == 'paraformer' | ||
| 77 | + shell: bash | ||
| 78 | + run: | | ||
| 79 | + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en | ||
| 80 | + cd sherpa-onnx-streaming-paraformer-bilingual-zh-en | ||
| 81 | + git lfs pull --include "*.onnx" | ||
| 82 | + cd .. | ||
| 83 | + | ||
| 84 | + python3 ./python-api-examples/streaming_server.py \ | ||
| 85 | + --tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ | ||
| 86 | + --paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ | ||
| 87 | + --paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx & | ||
| 88 | + | ||
| 89 | + echo "sleep 10 seconds to wait the server start" | ||
| 90 | + sleep 10 | ||
| 91 | + | ||
| 92 | + - name: Start client for paraformer models | ||
| 93 | + if: matrix.model_type == 'paraformer' | ||
| 94 | + shell: bash | ||
| 95 | + run: | | ||
| 96 | + python3 ./python-api-examples/online-websocket-client-decode-file.py \ | ||
| 97 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav | ||
| 98 | + | ||
| 99 | + python3 ./python-api-examples/online-websocket-client-decode-file.py \ | ||
| 100 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav | ||
| 101 | + | ||
| 102 | + python3 ./python-api-examples/online-websocket-client-decode-file.py \ | ||
| 103 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav | ||
| 104 | + | ||
| 105 | + python3 ./python-api-examples/online-websocket-client-decode-file.py \ | ||
| 106 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav |
| @@ -9,6 +9,7 @@ on: | @@ -9,6 +9,7 @@ on: | ||
| 9 | paths: | 9 | paths: |
| 10 | - '.github/workflows/windows-x64-cuda.yaml' | 10 | - '.github/workflows/windows-x64-cuda.yaml' |
| 11 | - '.github/scripts/test-online-transducer.sh' | 11 | - '.github/scripts/test-online-transducer.sh' |
| 12 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 12 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 13 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 14 | - 'CMakeLists.txt' | 15 | - 'CMakeLists.txt' |
| @@ -20,6 +21,7 @@ on: | @@ -20,6 +21,7 @@ on: | ||
| 20 | paths: | 21 | paths: |
| 21 | - '.github/workflows/windows-x64-cuda.yaml' | 22 | - '.github/workflows/windows-x64-cuda.yaml' |
| 22 | - '.github/scripts/test-online-transducer.sh' | 23 | - '.github/scripts/test-online-transducer.sh' |
| 24 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 23 | - '.github/scripts/test-offline-transducer.sh' | 25 | - '.github/scripts/test-offline-transducer.sh' |
| 24 | - '.github/scripts/test-offline-ctc.sh' | 26 | - '.github/scripts/test-offline-ctc.sh' |
| 25 | - 'CMakeLists.txt' | 27 | - 'CMakeLists.txt' |
| @@ -74,6 +76,14 @@ jobs: | @@ -74,6 +76,14 @@ jobs: | ||
| 74 | 76 | ||
| 75 | ls -lh ./bin/Release/sherpa-onnx.exe | 77 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 76 | 78 | ||
| 79 | + - name: Test online paraformer for windows x64 | ||
| 80 | + shell: bash | ||
| 81 | + run: | | ||
| 82 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 83 | + export EXE=sherpa-onnx.exe | ||
| 84 | + | ||
| 85 | + .github/scripts/test-online-paraformer.sh | ||
| 86 | + | ||
| 77 | - name: Test offline Whisper for windows x64 | 87 | - name: Test offline Whisper for windows x64 |
| 78 | shell: bash | 88 | shell: bash |
| 79 | run: | | 89 | run: | |
| @@ -9,6 +9,7 @@ on: | @@ -9,6 +9,7 @@ on: | ||
| 9 | paths: | 9 | paths: |
| 10 | - '.github/workflows/windows-x64.yaml' | 10 | - '.github/workflows/windows-x64.yaml' |
| 11 | - '.github/scripts/test-online-transducer.sh' | 11 | - '.github/scripts/test-online-transducer.sh' |
| 12 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 12 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 13 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 14 | - 'CMakeLists.txt' | 15 | - 'CMakeLists.txt' |
| @@ -20,6 +21,7 @@ on: | @@ -20,6 +21,7 @@ on: | ||
| 20 | paths: | 21 | paths: |
| 21 | - '.github/workflows/windows-x64.yaml' | 22 | - '.github/workflows/windows-x64.yaml' |
| 22 | - '.github/scripts/test-online-transducer.sh' | 23 | - '.github/scripts/test-online-transducer.sh' |
| 24 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 23 | - '.github/scripts/test-offline-transducer.sh' | 25 | - '.github/scripts/test-offline-transducer.sh' |
| 24 | - '.github/scripts/test-offline-ctc.sh' | 26 | - '.github/scripts/test-offline-ctc.sh' |
| 25 | - 'CMakeLists.txt' | 27 | - 'CMakeLists.txt' |
| @@ -75,6 +77,14 @@ jobs: | @@ -75,6 +77,14 @@ jobs: | ||
| 75 | 77 | ||
| 76 | ls -lh ./bin/Release/sherpa-onnx.exe | 78 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 77 | 79 | ||
| 80 | + - name: Test online paraformer for windows x64 | ||
| 81 | + shell: bash | ||
| 82 | + run: | | ||
| 83 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 84 | + export EXE=sherpa-onnx.exe | ||
| 85 | + | ||
| 86 | + .github/scripts/test-online-paraformer.sh | ||
| 87 | + | ||
| 78 | - name: Test offline Whisper for windows x64 | 88 | - name: Test offline Whisper for windows x64 |
| 79 | shell: bash | 89 | shell: bash |
| 80 | run: | | 90 | run: | |
| @@ -7,6 +7,7 @@ on: | @@ -7,6 +7,7 @@ on: | ||
| 7 | paths: | 7 | paths: |
| 8 | - '.github/workflows/windows-x86.yaml' | 8 | - '.github/workflows/windows-x86.yaml' |
| 9 | - '.github/scripts/test-online-transducer.sh' | 9 | - '.github/scripts/test-online-transducer.sh' |
| 10 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 10 | - '.github/scripts/test-offline-transducer.sh' | 11 | - '.github/scripts/test-offline-transducer.sh' |
| 11 | - '.github/scripts/test-offline-ctc.sh' | 12 | - '.github/scripts/test-offline-ctc.sh' |
| 12 | - 'CMakeLists.txt' | 13 | - 'CMakeLists.txt' |
| @@ -18,6 +19,7 @@ on: | @@ -18,6 +19,7 @@ on: | ||
| 18 | paths: | 19 | paths: |
| 19 | - '.github/workflows/windows-x86.yaml' | 20 | - '.github/workflows/windows-x86.yaml' |
| 20 | - '.github/scripts/test-online-transducer.sh' | 21 | - '.github/scripts/test-online-transducer.sh' |
| 22 | + - '.github/scripts/test-online-paraformer.sh' | ||
| 21 | - '.github/scripts/test-offline-transducer.sh' | 23 | - '.github/scripts/test-offline-transducer.sh' |
| 22 | - '.github/scripts/test-offline-ctc.sh' | 24 | - '.github/scripts/test-offline-ctc.sh' |
| 23 | - 'CMakeLists.txt' | 25 | - 'CMakeLists.txt' |
| @@ -73,6 +75,14 @@ jobs: | @@ -73,6 +75,14 @@ jobs: | ||
| 73 | 75 | ||
| 74 | ls -lh ./bin/Release/sherpa-onnx.exe | 76 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 75 | 77 | ||
| 78 | + - name: Test online paraformer for windows x86 | ||
| 79 | + shell: bash | ||
| 80 | + run: | | ||
| 81 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 82 | + export EXE=sherpa-onnx.exe | ||
| 83 | + | ||
| 84 | + .github/scripts/test-online-paraformer.sh | ||
| 85 | + | ||
| 76 | - name: Test offline Whisper for windows x86 | 86 | - name: Test offline Whisper for windows x86 |
| 77 | shell: bash | 87 | shell: bash |
| 78 | run: | | 88 | run: | |
| @@ -37,14 +37,14 @@ python3 ./python-api-examples/non_streaming_server.py \ | @@ -37,14 +37,14 @@ python3 ./python-api-examples/non_streaming_server.py \ | ||
| 37 | (2) Use a non-streaming paraformer | 37 | (2) Use a non-streaming paraformer |
| 38 | 38 | ||
| 39 | cd /path/to/sherpa-onnx | 39 | cd /path/to/sherpa-onnx |
| 40 | -GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 | ||
| 41 | -cd sherpa-onnx-paraformer-zh-2023-03-28 | 40 | +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en |
| 41 | +cd sherpa-onnx-paraformer-bilingual-zh-en/ | ||
| 42 | git lfs pull --include "*.onnx" | 42 | git lfs pull --include "*.onnx" |
| 43 | cd .. | 43 | cd .. |
| 44 | 44 | ||
| 45 | python3 ./python-api-examples/non_streaming_server.py \ | 45 | python3 ./python-api-examples/non_streaming_server.py \ |
| 46 | - --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \ | ||
| 47 | - --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt | 46 | + --paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \ |
| 47 | + --tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt | ||
| 48 | 48 | ||
| 49 | (3) Use a non-streaming CTC model from NeMo | 49 | (3) Use a non-streaming CTC model from NeMo |
| 50 | 50 |
| @@ -5,16 +5,41 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe | @@ -5,16 +5,41 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe | ||
| 5 | file(s) with a streaming model. | 5 | file(s) with a streaming model. |
| 6 | 6 | ||
| 7 | Usage: | 7 | Usage: |
| 8 | - ./online-decode-files.py \ | ||
| 9 | - /path/to/foo.wav \ | ||
| 10 | - /path/to/bar.wav \ | ||
| 11 | - /path/to/16kHz.wav \ | ||
| 12 | - /path/to/8kHz.wav | 8 | + |
| 9 | +(1) Streaming transducer | ||
| 10 | + | ||
| 11 | +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 | ||
| 12 | +cd sherpa-onnx-streaming-zipformer-en-2023-06-26 | ||
| 13 | +git lfs pull --include "*.onnx" | ||
| 14 | + | ||
| 15 | +./python-api-examples/online-decode-files.py \ | ||
| 16 | + --tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \ | ||
| 17 | + --encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \ | ||
| 18 | + --decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \ | ||
| 19 | + --joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \ | ||
| 20 | + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \ | ||
| 21 | + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \ | ||
| 22 | + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav | ||
| 23 | + | ||
| 24 | +(2) Streaming paraformer | ||
| 25 | + | ||
| 26 | +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en | ||
| 27 | +cd sherpa-onnx-streaming-paraformer-bilingual-zh-en | ||
| 28 | +git lfs pull --include "*.onnx" | ||
| 29 | + | ||
| 30 | +./python-api-examples/online-decode-files.py \ | ||
| 31 | + --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ | ||
| 32 | + --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ | ||
| 33 | + --paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx \ | ||
| 34 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav \ | ||
| 35 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav \ | ||
| 36 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav \ | ||
| 37 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \ | ||
| 38 | + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav | ||
| 13 | 39 | ||
| 14 | Please refer to | 40 | Please refer to |
| 15 | https://k2-fsa.github.io/sherpa/onnx/index.html | 41 | https://k2-fsa.github.io/sherpa/onnx/index.html |
| 16 | -to install sherpa-onnx and to download the pre-trained models | ||
| 17 | -used in this file. | 42 | +to install sherpa-onnx and to download streaming pre-trained models. |
| 18 | """ | 43 | """ |
| 19 | import argparse | 44 | import argparse |
| 20 | import time | 45 | import time |
| @@ -41,19 +66,31 @@ def get_args(): | @@ -41,19 +66,31 @@ def get_args(): | ||
| 41 | parser.add_argument( | 66 | parser.add_argument( |
| 42 | "--encoder", | 67 | "--encoder", |
| 43 | type=str, | 68 | type=str, |
| 44 | - help="Path to the encoder model", | 69 | + help="Path to the transducer encoder model", |
| 45 | ) | 70 | ) |
| 46 | 71 | ||
| 47 | parser.add_argument( | 72 | parser.add_argument( |
| 48 | "--decoder", | 73 | "--decoder", |
| 49 | type=str, | 74 | type=str, |
| 50 | - help="Path to the decoder model", | 75 | + help="Path to the transducer decoder model", |
| 51 | ) | 76 | ) |
| 52 | 77 | ||
| 53 | parser.add_argument( | 78 | parser.add_argument( |
| 54 | "--joiner", | 79 | "--joiner", |
| 55 | type=str, | 80 | type=str, |
| 56 | - help="Path to the joiner model", | 81 | + help="Path to the transducer joiner model", |
| 82 | + ) | ||
| 83 | + | ||
| 84 | + parser.add_argument( | ||
| 85 | + "--paraformer-encoder", | ||
| 86 | + type=str, | ||
| 87 | + help="Path to the paraformer encoder model", | ||
| 88 | + ) | ||
| 89 | + | ||
| 90 | + parser.add_argument( | ||
| 91 | + "--paraformer-decoder", | ||
| 92 | + type=str, | ||
| 93 | + help="Path to the paraformer decoder model", | ||
| 57 | ) | 94 | ) |
| 58 | 95 | ||
| 59 | parser.add_argument( | 96 | parser.add_argument( |
| @@ -200,24 +237,42 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]: | @@ -200,24 +237,42 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]: | ||
| 200 | 237 | ||
| 201 | def main(): | 238 | def main(): |
| 202 | args = get_args() | 239 | args = get_args() |
| 203 | - assert_file_exists(args.encoder) | ||
| 204 | - assert_file_exists(args.decoder) | ||
| 205 | - assert_file_exists(args.joiner) | ||
| 206 | assert_file_exists(args.tokens) | 240 | assert_file_exists(args.tokens) |
| 207 | 241 | ||
| 208 | - recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( | ||
| 209 | - tokens=args.tokens, | ||
| 210 | - encoder=args.encoder, | ||
| 211 | - decoder=args.decoder, | ||
| 212 | - joiner=args.joiner, | ||
| 213 | - num_threads=args.num_threads, | ||
| 214 | - provider=args.provider, | ||
| 215 | - sample_rate=16000, | ||
| 216 | - feature_dim=80, | ||
| 217 | - decoding_method=args.decoding_method, | ||
| 218 | - max_active_paths=args.max_active_paths, | ||
| 219 | - context_score=args.context_score, | ||
| 220 | - ) | 242 | + if args.encoder: |
| 243 | + assert_file_exists(args.encoder) | ||
| 244 | + assert_file_exists(args.decoder) | ||
| 245 | + assert_file_exists(args.joiner) | ||
| 246 | + | ||
| 247 | + assert not args.paraformer_encoder, args.paraformer_encoder | ||
| 248 | + assert not args.paraformer_decoder, args.paraformer_decoder | ||
| 249 | + | ||
| 250 | + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer( | ||
| 251 | + tokens=args.tokens, | ||
| 252 | + encoder=args.encoder, | ||
| 253 | + decoder=args.decoder, | ||
| 254 | + joiner=args.joiner, | ||
| 255 | + num_threads=args.num_threads, | ||
| 256 | + provider=args.provider, | ||
| 257 | + sample_rate=16000, | ||
| 258 | + feature_dim=80, | ||
| 259 | + decoding_method=args.decoding_method, | ||
| 260 | + max_active_paths=args.max_active_paths, | ||
| 261 | + context_score=args.context_score, | ||
| 262 | + ) | ||
| 263 | + elif args.paraformer_encoder: | ||
| 264 | + recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( | ||
| 265 | + tokens=args.tokens, | ||
| 266 | + encoder=args.paraformer_encoder, | ||
| 267 | + decoder=args.paraformer_decoder, | ||
| 268 | + num_threads=args.num_threads, | ||
| 269 | + provider=args.provider, | ||
| 270 | + sample_rate=16000, | ||
| 271 | + feature_dim=80, | ||
| 272 | + decoding_method="greedy_search", | ||
| 273 | + ) | ||
| 274 | + else: | ||
| 275 | + raise ValueError("Please provide a model") | ||
| 221 | 276 | ||
| 222 | print("Started!") | 277 | print("Started!") |
| 223 | start_time = time.time() | 278 | start_time = time.time() |
| @@ -243,7 +298,7 @@ def main(): | @@ -243,7 +298,7 @@ def main(): | ||
| 243 | 298 | ||
| 244 | s.accept_waveform(sample_rate, samples) | 299 | s.accept_waveform(sample_rate, samples) |
| 245 | 300 | ||
| 246 | - tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) | 301 | + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) |
| 247 | s.accept_waveform(sample_rate, tail_paddings) | 302 | s.accept_waveform(sample_rate, tail_paddings) |
| 248 | 303 | ||
| 249 | s.input_finished() | 304 | s.input_finished() |
| @@ -16,9 +16,9 @@ Example: | @@ -16,9 +16,9 @@ Example: | ||
| 16 | (1) Without a certificate | 16 | (1) Without a certificate |
| 17 | 17 | ||
| 18 | python3 ./python-api-examples/streaming_server.py \ | 18 | python3 ./python-api-examples/streaming_server.py \ |
| 19 | - --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ | ||
| 20 | - --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ | ||
| 21 | - --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ | 19 | + --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ |
| 20 | + --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ | ||
| 21 | + --joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ | ||
| 22 | --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt | 22 | --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt |
| 23 | 23 | ||
| 24 | (2) With a certificate | 24 | (2) With a certificate |
| @@ -32,9 +32,9 @@ python3 ./python-api-examples/streaming_server.py \ | @@ -32,9 +32,9 @@ python3 ./python-api-examples/streaming_server.py \ | ||
| 32 | (b) Start the server | 32 | (b) Start the server |
| 33 | 33 | ||
| 34 | python3 ./python-api-examples/streaming_server.py \ | 34 | python3 ./python-api-examples/streaming_server.py \ |
| 35 | - --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ | ||
| 36 | - --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ | ||
| 37 | - --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ | 35 | + --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ |
| 36 | + --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ | ||
| 37 | + --joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ | ||
| 38 | --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ | 38 | --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ |
| 39 | --certificate ./python-api-examples/web/cert.pem | 39 | --certificate ./python-api-examples/web/cert.pem |
| 40 | 40 | ||
| @@ -113,24 +113,33 @@ def setup_logger( | @@ -113,24 +113,33 @@ def setup_logger( | ||
| 113 | 113 | ||
| 114 | def add_model_args(parser: argparse.ArgumentParser): | 114 | def add_model_args(parser: argparse.ArgumentParser): |
| 115 | parser.add_argument( | 115 | parser.add_argument( |
| 116 | - "--encoder-model", | 116 | + "--encoder", |
| 117 | type=str, | 117 | type=str, |
| 118 | - required=True, | ||
| 119 | - help="Path to the encoder model", | 118 | + help="Path to the transducer encoder model", |
| 120 | ) | 119 | ) |
| 121 | 120 | ||
| 122 | parser.add_argument( | 121 | parser.add_argument( |
| 123 | - "--decoder-model", | 122 | + "--decoder", |
| 124 | type=str, | 123 | type=str, |
| 125 | - required=True, | ||
| 126 | - help="Path to the decoder model.", | 124 | + help="Path to the transducer decoder model.", |
| 127 | ) | 125 | ) |
| 128 | 126 | ||
| 129 | parser.add_argument( | 127 | parser.add_argument( |
| 130 | - "--joiner-model", | 128 | + "--joiner", |
| 131 | type=str, | 129 | type=str, |
| 132 | - required=True, | ||
| 133 | - help="Path to the joiner model.", | 130 | + help="Path to the transducer joiner model.", |
| 131 | + ) | ||
| 132 | + | ||
| 133 | + parser.add_argument( | ||
| 134 | + "--paraformer-encoder", | ||
| 135 | + type=str, | ||
| 136 | + help="Path to the paraformer encoder model", | ||
| 137 | + ) | ||
| 138 | + | ||
| 139 | + parser.add_argument( | ||
| 140 | + "--paraformer-decoder", | ||
| 141 | + type=str, | ||
| 142 | + help="Path to the paraformer decoder model.", | ||
| 134 | ) | 143 | ) |
| 135 | 144 | ||
| 136 | parser.add_argument( | 145 | parser.add_argument( |
| @@ -323,22 +332,40 @@ def get_args(): | @@ -323,22 +332,40 @@ def get_args(): | ||
| 323 | 332 | ||
| 324 | 333 | ||
def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
    """Create an online recognizer from the parsed command-line arguments.

    Exactly one model family must have been supplied on the command line:
    either the transducer files (--encoder/--decoder/--joiner) or the
    streaming paraformer files (--paraformer-encoder/--paraformer-decoder).

    Raises:
      ValueError: if neither model family was provided.
    """
    # Options shared by both factory functions.
    common_kwargs = dict(
        tokens=args.tokens,
        num_threads=args.num_threads,
        sample_rate=args.sample_rate,
        feature_dim=args.feat_dim,
        decoding_method=args.decoding_method,
        enable_endpoint_detection=args.use_endpoint != 0,
        rule1_min_trailing_silence=args.rule1_min_trailing_silence,
        rule2_min_trailing_silence=args.rule2_min_trailing_silence,
        rule3_min_utterance_length=args.rule3_min_utterance_length,
        provider=args.provider,
    )

    if args.encoder:
        # Transducer models additionally support beam-search width.
        return sherpa_onnx.OnlineRecognizer.from_transducer(
            encoder=args.encoder,
            decoder=args.decoder,
            joiner=args.joiner,
            max_active_paths=args.num_active_paths,
            **common_kwargs,
        )

    if args.paraformer_encoder:
        return sherpa_onnx.OnlineRecognizer.from_paraformer(
            encoder=args.paraformer_encoder,
            decoder=args.paraformer_decoder,
            **common_kwargs,
        )

    raise ValueError("Please provide a model")
| 344 | 371 | ||
| @@ -654,11 +681,25 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a> | @@ -654,11 +681,25 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a> | ||
| 654 | 681 | ||
| 655 | 682 | ||
| 656 | def check_args(args): | 683 | def check_args(args): |
| 657 | - assert Path(args.encoder_model).is_file(), f"{args.encoder_model} does not exist" | 684 | + if args.encoder: |
| 685 | + assert Path(args.encoder).is_file(), f"{args.encoder} does not exist" | ||
| 686 | + | ||
| 687 | + assert Path(args.decoder).is_file(), f"{args.decoder} does not exist" | ||
| 688 | + | ||
| 689 | + assert Path(args.joiner).is_file(), f"{args.joiner} does not exist" | ||
| 658 | 690 | ||
| 659 | - assert Path(args.decoder_model).is_file(), f"{args.decoder_model} does not exist" | 691 | + assert args.paraformer_encoder is None, args.paraformer_encoder |
| 692 | + assert args.paraformer_decoder is None, args.paraformer_decoder | ||
| 693 | + elif args.paraformer_encoder: | ||
| 694 | + assert Path( | ||
| 695 | + args.paraformer_encoder | ||
| 696 | + ).is_file(), f"{args.paraformer_encoder} does not exist" | ||
| 660 | 697 | ||
| 661 | - assert Path(args.joiner_model).is_file(), f"{args.joiner_model} does not exist" | 698 | + assert Path( |
| 699 | + args.paraformer_decoder | ||
| 700 | + ).is_file(), f"{args.paraformer_decoder} does not exist" | ||
| 701 | + else: | ||
| 702 | + raise ValueError("Please provide a model") | ||
| 662 | 703 | ||
| 663 | if not Path(args.tokens).is_file(): | 704 | if not Path(args.tokens).is_file(): |
| 664 | raise ValueError(f"{args.tokens} does not exist") | 705 | raise ValueError(f"{args.tokens} does not exist") |
| @@ -46,6 +46,8 @@ set(sources | @@ -46,6 +46,8 @@ set(sources | ||
| 46 | online-lm.cc | 46 | online-lm.cc |
| 47 | online-lstm-transducer-model.cc | 47 | online-lstm-transducer-model.cc |
| 48 | online-model-config.cc | 48 | online-model-config.cc |
| 49 | + online-paraformer-model-config.cc | ||
| 50 | + online-paraformer-model.cc | ||
| 49 | online-recognizer-impl.cc | 51 | online-recognizer-impl.cc |
| 50 | online-recognizer.cc | 52 | online-recognizer.cc |
| 51 | online-rnn-lm.cc | 53 | online-rnn-lm.cc |
| @@ -39,7 +39,7 @@ std::string FeatureExtractorConfig::ToString() const { | @@ -39,7 +39,7 @@ std::string FeatureExtractorConfig::ToString() const { | ||
| 39 | 39 | ||
| 40 | class FeatureExtractor::Impl { | 40 | class FeatureExtractor::Impl { |
| 41 | public: | 41 | public: |
| 42 | - explicit Impl(const FeatureExtractorConfig &config) { | 42 | + explicit Impl(const FeatureExtractorConfig &config) : config_(config) { |
| 43 | opts_.frame_opts.dither = 0; | 43 | opts_.frame_opts.dither = 0; |
| 44 | opts_.frame_opts.snip_edges = false; | 44 | opts_.frame_opts.snip_edges = false; |
| 45 | opts_.frame_opts.samp_freq = config.sampling_rate; | 45 | opts_.frame_opts.samp_freq = config.sampling_rate; |
| @@ -50,6 +50,19 @@ class FeatureExtractor::Impl { | @@ -50,6 +50,19 @@ class FeatureExtractor::Impl { | ||
| 50 | } | 50 | } |
| 51 | 51 | ||
| 52 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { | 52 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { |
| 53 | + if (config_.normalize_samples) { | ||
| 54 | + AcceptWaveformImpl(sampling_rate, waveform, n); | ||
| 55 | + } else { | ||
| 56 | + std::vector<float> buf(n); | ||
| 57 | + for (int32_t i = 0; i != n; ++i) { | ||
| 58 | + buf[i] = waveform[i] * 32768; | ||
| 59 | + } | ||
| 60 | + AcceptWaveformImpl(sampling_rate, buf.data(), n); | ||
| 61 | + } | ||
| 62 | + } | ||
| 63 | + | ||
| 64 | + void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform, | ||
| 65 | + int32_t n) { | ||
| 53 | std::lock_guard<std::mutex> lock(mutex_); | 66 | std::lock_guard<std::mutex> lock(mutex_); |
| 54 | 67 | ||
| 55 | if (resampler_) { | 68 | if (resampler_) { |
| @@ -146,6 +159,7 @@ class FeatureExtractor::Impl { | @@ -146,6 +159,7 @@ class FeatureExtractor::Impl { | ||
| 146 | private: | 159 | private: |
| 147 | std::unique_ptr<knf::OnlineFbank> fbank_; | 160 | std::unique_ptr<knf::OnlineFbank> fbank_; |
| 148 | knf::FbankOptions opts_; | 161 | knf::FbankOptions opts_; |
| 162 | + FeatureExtractorConfig config_; | ||
| 149 | mutable std::mutex mutex_; | 163 | mutable std::mutex mutex_; |
| 150 | std::unique_ptr<LinearResample> resampler_; | 164 | std::unique_ptr<LinearResample> resampler_; |
| 151 | int32_t last_frame_index_ = 0; | 165 | int32_t last_frame_index_ = 0; |
| @@ -21,6 +21,13 @@ struct FeatureExtractorConfig { | @@ -21,6 +21,13 @@ struct FeatureExtractorConfig { | ||
| 21 | // Feature dimension | 21 | // Feature dimension |
| 22 | int32_t feature_dim = 80; | 22 | int32_t feature_dim = 80; |
| 23 | 23 | ||
| 24 | + // Set internally by some models, e.g., paraformer sets it to false. | ||
| 25 | + // This parameter is not exposed to users from the commandline | ||
| 26 | + // If true, the feature extractor expects inputs to be normalized to | ||
| 27 | + // the range [-1, 1]. | ||
| 28 | + // If false, we will multiply the inputs by 32768 | ||
| 29 | + bool normalize_samples = true; | ||
| 30 | + | ||
| 24 | std::string ToString() const; | 31 | std::string ToString() const; |
| 25 | 32 | ||
| 26 | void Register(ParseOptions *po); | 33 | void Register(ParseOptions *po); |
| @@ -12,6 +12,7 @@ namespace sherpa_onnx { | @@ -12,6 +12,7 @@ namespace sherpa_onnx { | ||
| 12 | 12 | ||
| 13 | void OnlineModelConfig::Register(ParseOptions *po) { | 13 | void OnlineModelConfig::Register(ParseOptions *po) { |
| 14 | transducer.Register(po); | 14 | transducer.Register(po); |
| 15 | + paraformer.Register(po); | ||
| 15 | 16 | ||
| 16 | po->Register("tokens", &tokens, "Path to tokens.txt"); | 17 | po->Register("tokens", &tokens, "Path to tokens.txt"); |
| 17 | 18 | ||
| @@ -41,6 +42,10 @@ bool OnlineModelConfig::Validate() const { | @@ -41,6 +42,10 @@ bool OnlineModelConfig::Validate() const { | ||
| 41 | return false; | 42 | return false; |
| 42 | } | 43 | } |
| 43 | 44 | ||
| 45 | + if (!paraformer.encoder.empty()) { | ||
| 46 | + return paraformer.Validate(); | ||
| 47 | + } | ||
| 48 | + | ||
| 44 | return transducer.Validate(); | 49 | return transducer.Validate(); |
| 45 | } | 50 | } |
| 46 | 51 | ||
| @@ -49,6 +54,7 @@ std::string OnlineModelConfig::ToString() const { | @@ -49,6 +54,7 @@ std::string OnlineModelConfig::ToString() const { | ||
| 49 | 54 | ||
| 50 | os << "OnlineModelConfig("; | 55 | os << "OnlineModelConfig("; |
| 51 | os << "transducer=" << transducer.ToString() << ", "; | 56 | os << "transducer=" << transducer.ToString() << ", "; |
| 57 | + os << "paraformer=" << paraformer.ToString() << ", "; | ||
| 52 | os << "tokens=\"" << tokens << "\", "; | 58 | os << "tokens=\"" << tokens << "\", "; |
| 53 | os << "num_threads=" << num_threads << ", "; | 59 | os << "num_threads=" << num_threads << ", "; |
| 54 | os << "debug=" << (debug ? "True" : "False") << ", "; | 60 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| @@ -6,12 +6,14 @@ | @@ -6,12 +6,14 @@ | ||
| 6 | 6 | ||
| 7 | #include <string> | 7 | #include <string> |
| 8 | 8 | ||
| 9 | +#include "sherpa-onnx/csrc/online-paraformer-model-config.h" | ||
| 9 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 10 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 10 | 11 | ||
| 11 | namespace sherpa_onnx { | 12 | namespace sherpa_onnx { |
| 12 | 13 | ||
| 13 | struct OnlineModelConfig { | 14 | struct OnlineModelConfig { |
| 14 | OnlineTransducerModelConfig transducer; | 15 | OnlineTransducerModelConfig transducer; |
| 16 | + OnlineParaformerModelConfig paraformer; | ||
| 15 | std::string tokens; | 17 | std::string tokens; |
| 16 | int32_t num_threads = 1; | 18 | int32_t num_threads = 1; |
| 17 | bool debug = false; | 19 | bool debug = false; |
| @@ -28,9 +30,11 @@ struct OnlineModelConfig { | @@ -28,9 +30,11 @@ struct OnlineModelConfig { | ||
| 28 | 30 | ||
| 29 | OnlineModelConfig() = default; | 31 | OnlineModelConfig() = default; |
| 30 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, | 32 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, |
| 33 | + const OnlineParaformerModelConfig ¶former, | ||
| 31 | const std::string &tokens, int32_t num_threads, bool debug, | 34 | const std::string &tokens, int32_t num_threads, bool debug, |
| 32 | const std::string &provider, const std::string &model_type) | 35 | const std::string &provider, const std::string &model_type) |
| 33 | : transducer(transducer), | 36 | : transducer(transducer), |
| 37 | + paraformer(paraformer), | ||
| 34 | tokens(tokens), | 38 | tokens(tokens), |
| 35 | num_threads(num_threads), | 39 | num_threads(num_threads), |
| 36 | debug(debug), | 40 | debug(debug), |
sherpa-onnx/csrc/online-paraformer-decoder.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-paraformer-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OnlineParaformerDecoderResult { | ||
| 15 | + /// The decoded token IDs | ||
| 16 | + std::vector<int32_t> tokens; | ||
| 17 | + | ||
| 18 | + int32_t last_non_blank_frame_index = 0; | ||
| 19 | +}; | ||
| 20 | + | ||
| 21 | +} // namespace sherpa_onnx | ||
| 22 | + | ||
| 23 | +#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/online-paraformer-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-paraformer-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OnlineParaformerModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("paraformer-encoder", &encoder, | ||
| 14 | + "Path to encoder.onnx of paraformer."); | ||
| 15 | + po->Register("paraformer-decoder", &decoder, | ||
| 16 | + "Path to decoder.onnx of paraformer."); | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +bool OnlineParaformerModelConfig::Validate() const { | ||
| 20 | + if (!FileExists(encoder)) { | ||
| 21 | + SHERPA_ONNX_LOGE("Paraformer encoder %s does not exist", encoder.c_str()); | ||
| 22 | + return false; | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + if (!FileExists(decoder)) { | ||
| 26 | + SHERPA_ONNX_LOGE("Paraformer decoder %s does not exist", decoder.c_str()); | ||
| 27 | + return false; | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + return true; | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +std::string OnlineParaformerModelConfig::ToString() const { | ||
| 34 | + std::ostringstream os; | ||
| 35 | + | ||
| 36 | + os << "OnlineParaformerModelConfig("; | ||
| 37 | + os << "encoder=\"" << encoder << "\", "; | ||
| 38 | + os << "decoder=\"" << decoder << "\")"; | ||
| 39 | + | ||
| 40 | + return os.str(); | ||
| 41 | +} | ||
| 42 | + | ||
| 43 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/online-paraformer-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OnlineParaformerModelConfig { | ||
| 14 | + std::string encoder; | ||
| 15 | + std::string decoder; | ||
| 16 | + | ||
| 17 | + OnlineParaformerModelConfig() = default; | ||
| 18 | + | ||
| 19 | + OnlineParaformerModelConfig(const std::string &encoder, | ||
| 20 | + const std::string &decoder) | ||
| 21 | + : encoder(encoder), decoder(decoder) {} | ||
| 22 | + | ||
| 23 | + void Register(ParseOptions *po); | ||
| 24 | + bool Validate() const; | ||
| 25 | + | ||
| 26 | + std::string ToString() const; | ||
| 27 | +}; | ||
| 28 | + | ||
| 29 | +} // namespace sherpa_onnx | ||
| 30 | + | ||
| 31 | +#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/online-paraformer-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-paraformer-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-paraformer-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <cmath> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 17 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 18 | +#include "sherpa-onnx/csrc/session.h" | ||
| 19 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 20 | + | ||
| 21 | +namespace sherpa_onnx { | ||
| 22 | + | ||
| 23 | +class OnlineParaformerModel::Impl { | ||
| 24 | + public: | ||
| 25 | + explicit Impl(const OnlineModelConfig &config) | ||
| 26 | + : config_(config), | ||
| 27 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 28 | + sess_opts_(GetSessionOptions(config)), | ||
| 29 | + allocator_{} { | ||
| 30 | + { | ||
| 31 | + auto buf = ReadFile(config.paraformer.encoder); | ||
| 32 | + InitEncoder(buf.data(), buf.size()); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + { | ||
| 36 | + auto buf = ReadFile(config.paraformer.decoder); | ||
| 37 | + InitDecoder(buf.data(), buf.size()); | ||
| 38 | + } | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | +#if __ANDROID_API__ >= 9 | ||
| 42 | + Impl(AAssetManager *mgr, const OnlineModelConfig &config) | ||
| 43 | + : config_(config), | ||
| 44 | + env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 45 | + sess_opts_(GetSessionOptions(config)), | ||
| 46 | + allocator_{} { | ||
| 47 | + { | ||
| 48 | + auto buf = ReadFile(mgr, config.paraformer.encoder); | ||
| 49 | + InitEncoder(buf.data(), buf.size()); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + { | ||
| 53 | + auto buf = ReadFile(mgr, config.paraformer.decoder); | ||
| 54 | + InitDecoder(buf.data(), buf.size()); | ||
| 55 | + } | ||
| 56 | + } | ||
| 57 | +#endif | ||
| 58 | + | ||
| 59 | + std::vector<Ort::Value> ForwardEncoder(Ort::Value features, | ||
| 60 | + Ort::Value features_length) { | ||
| 61 | + std::array<Ort::Value, 2> inputs = {std::move(features), | ||
| 62 | + std::move(features_length)}; | ||
| 63 | + | ||
| 64 | + return encoder_sess_->Run( | ||
| 65 | + {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 66 | + encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size()); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out, | ||
| 70 | + Ort::Value encoder_out_length, | ||
| 71 | + Ort::Value acoustic_embedding, | ||
| 72 | + Ort::Value acoustic_embedding_length, | ||
| 73 | + std::vector<Ort::Value> states) { | ||
| 74 | + std::vector<Ort::Value> decoder_inputs; | ||
| 75 | + decoder_inputs.reserve(4 + states.size()); | ||
| 76 | + | ||
| 77 | + decoder_inputs.push_back(std::move(encoder_out)); | ||
| 78 | + decoder_inputs.push_back(std::move(encoder_out_length)); | ||
| 79 | + decoder_inputs.push_back(std::move(acoustic_embedding)); | ||
| 80 | + decoder_inputs.push_back(std::move(acoustic_embedding_length)); | ||
| 81 | + | ||
| 82 | + for (auto &v : states) { | ||
| 83 | + decoder_inputs.push_back(std::move(v)); | ||
| 84 | + } | ||
| 85 | + | ||
| 86 | + return decoder_sess_->Run({}, decoder_input_names_ptr_.data(), | ||
| 87 | + decoder_inputs.data(), decoder_inputs.size(), | ||
| 88 | + decoder_output_names_ptr_.data(), | ||
| 89 | + decoder_output_names_ptr_.size()); | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + int32_t VocabSize() const { return vocab_size_; } | ||
| 93 | + | ||
| 94 | + int32_t LfrWindowSize() const { return lfr_window_size_; } | ||
| 95 | + | ||
| 96 | + int32_t LfrWindowShift() const { return lfr_window_shift_; } | ||
| 97 | + | ||
| 98 | + int32_t EncoderOutputSize() const { return encoder_output_size_; } | ||
| 99 | + | ||
| 100 | + int32_t DecoderKernelSize() const { return decoder_kernel_size_; } | ||
| 101 | + | ||
| 102 | + int32_t DecoderNumBlocks() const { return decoder_num_blocks_; } | ||
| 103 | + | ||
| 104 | + const std::vector<float> &NegativeMean() const { return neg_mean_; } | ||
| 105 | + | ||
| 106 | + const std::vector<float> &InverseStdDev() const { return inv_stddev_; } | ||
| 107 | + | ||
| 108 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 109 | + | ||
| 110 | + private: | ||
| 111 | + void InitEncoder(void *model_data, size_t model_data_length) { | ||
| 112 | + encoder_sess_ = std::make_unique<Ort::Session>( | ||
| 113 | + env_, model_data, model_data_length, sess_opts_); | ||
| 114 | + | ||
| 115 | + GetInputNames(encoder_sess_.get(), &encoder_input_names_, | ||
| 116 | + &encoder_input_names_ptr_); | ||
| 117 | + | ||
| 118 | + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, | ||
| 119 | + &encoder_output_names_ptr_); | ||
| 120 | + | ||
| 121 | + // get meta data | ||
| 122 | + Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata(); | ||
| 123 | + if (config_.debug) { | ||
| 124 | + std::ostringstream os; | ||
| 125 | + PrintModelMetadata(os, meta_data); | ||
| 126 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 127 | + } | ||
| 128 | + | ||
| 129 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 130 | + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); | ||
| 131 | + SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size"); | ||
| 132 | + SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift"); | ||
| 133 | + SHERPA_ONNX_READ_META_DATA(encoder_output_size_, "encoder_output_size"); | ||
| 134 | + SHERPA_ONNX_READ_META_DATA(decoder_num_blocks_, "decoder_num_blocks"); | ||
| 135 | + SHERPA_ONNX_READ_META_DATA(decoder_kernel_size_, "decoder_kernel_size"); | ||
| 136 | + | ||
| 137 | + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean"); | ||
| 138 | + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev"); | ||
| 139 | + | ||
| 140 | + float scale = std::sqrt(encoder_output_size_); | ||
| 141 | + for (auto &f : inv_stddev_) { | ||
| 142 | + f *= scale; | ||
| 143 | + } | ||
| 144 | + } | ||
| 145 | + | ||
| 146 | + void InitDecoder(void *model_data, size_t model_data_length) { | ||
| 147 | + decoder_sess_ = std::make_unique<Ort::Session>( | ||
| 148 | + env_, model_data, model_data_length, sess_opts_); | ||
| 149 | + | ||
| 150 | + GetInputNames(decoder_sess_.get(), &decoder_input_names_, | ||
| 151 | + &decoder_input_names_ptr_); | ||
| 152 | + | ||
| 153 | + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, | ||
| 154 | + &decoder_output_names_ptr_); | ||
| 155 | + } | ||
| 156 | + | ||
| 157 | + private: | ||
| 158 | + OnlineModelConfig config_; | ||
| 159 | + Ort::Env env_; | ||
| 160 | + Ort::SessionOptions sess_opts_; | ||
| 161 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 162 | + | ||
| 163 | + std::unique_ptr<Ort::Session> encoder_sess_; | ||
| 164 | + | ||
| 165 | + std::vector<std::string> encoder_input_names_; | ||
| 166 | + std::vector<const char *> encoder_input_names_ptr_; | ||
| 167 | + | ||
| 168 | + std::vector<std::string> encoder_output_names_; | ||
| 169 | + std::vector<const char *> encoder_output_names_ptr_; | ||
| 170 | + | ||
| 171 | + std::unique_ptr<Ort::Session> decoder_sess_; | ||
| 172 | + | ||
| 173 | + std::vector<std::string> decoder_input_names_; | ||
| 174 | + std::vector<const char *> decoder_input_names_ptr_; | ||
| 175 | + | ||
| 176 | + std::vector<std::string> decoder_output_names_; | ||
| 177 | + std::vector<const char *> decoder_output_names_ptr_; | ||
| 178 | + | ||
| 179 | + std::vector<float> neg_mean_; | ||
| 180 | + std::vector<float> inv_stddev_; | ||
| 181 | + | ||
| 182 | + int32_t vocab_size_ = 0; // initialized in Init | ||
| 183 | + int32_t lfr_window_size_ = 0; | ||
| 184 | + int32_t lfr_window_shift_ = 0; | ||
| 185 | + | ||
| 186 | + int32_t encoder_output_size_ = 0; | ||
| 187 | + int32_t decoder_num_blocks_ = 0; | ||
| 188 | + int32_t decoder_kernel_size_ = 0; | ||
| 189 | +}; | ||
| 190 | + | ||
| 191 | +OnlineParaformerModel::OnlineParaformerModel(const OnlineModelConfig &config) | ||
| 192 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 193 | + | ||
| 194 | +#if __ANDROID_API__ >= 9 | ||
| 195 | +OnlineParaformerModel::OnlineParaformerModel(AAssetManager *mgr, | ||
| 196 | + const OnlineModelConfig &config) | ||
| 197 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 198 | +#endif | ||
| 199 | + | ||
| 200 | +OnlineParaformerModel::~OnlineParaformerModel() = default; | ||
| 201 | + | ||
| 202 | +std::vector<Ort::Value> OnlineParaformerModel::ForwardEncoder( | ||
| 203 | + Ort::Value features, Ort::Value features_length) const { | ||
| 204 | + return impl_->ForwardEncoder(std::move(features), std::move(features_length)); | ||
| 205 | +} | ||
| 206 | + | ||
| 207 | +std::vector<Ort::Value> OnlineParaformerModel::ForwardDecoder( | ||
| 208 | + Ort::Value encoder_out, Ort::Value encoder_out_length, | ||
| 209 | + Ort::Value acoustic_embedding, Ort::Value acoustic_embedding_length, | ||
| 210 | + std::vector<Ort::Value> states) const { | ||
| 211 | + return impl_->ForwardDecoder( | ||
| 212 | + std::move(encoder_out), std::move(encoder_out_length), | ||
| 213 | + std::move(acoustic_embedding), std::move(acoustic_embedding_length), | ||
| 214 | + std::move(states)); | ||
| 215 | +} | ||
| 216 | + | ||
| 217 | +int32_t OnlineParaformerModel::VocabSize() const { return impl_->VocabSize(); } | ||
| 218 | + | ||
| 219 | +int32_t OnlineParaformerModel::LfrWindowSize() const { | ||
| 220 | + return impl_->LfrWindowSize(); | ||
| 221 | +} | ||
| 222 | +int32_t OnlineParaformerModel::LfrWindowShift() const { | ||
| 223 | + return impl_->LfrWindowShift(); | ||
| 224 | +} | ||
| 225 | + | ||
| 226 | +int32_t OnlineParaformerModel::EncoderOutputSize() const { | ||
| 227 | + return impl_->EncoderOutputSize(); | ||
| 228 | +} | ||
| 229 | + | ||
| 230 | +int32_t OnlineParaformerModel::DecoderKernelSize() const { | ||
| 231 | + return impl_->DecoderKernelSize(); | ||
| 232 | +} | ||
| 233 | + | ||
| 234 | +int32_t OnlineParaformerModel::DecoderNumBlocks() const { | ||
| 235 | + return impl_->DecoderNumBlocks(); | ||
| 236 | +} | ||
| 237 | + | ||
| 238 | +const std::vector<float> &OnlineParaformerModel::NegativeMean() const { | ||
| 239 | + return impl_->NegativeMean(); | ||
| 240 | +} | ||
| 241 | +const std::vector<float> &OnlineParaformerModel::InverseStdDev() const { | ||
| 242 | + return impl_->InverseStdDev(); | ||
| 243 | +} | ||
| 244 | + | ||
| 245 | +OrtAllocator *OnlineParaformerModel::Allocator() const { | ||
| 246 | + return impl_->Allocator(); | ||
| 247 | +} | ||
| 248 | + | ||
| 249 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/online-paraformer-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-paraformer-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 17 | +#include "sherpa-onnx/csrc/online-model-config.h" | ||
| 18 | + | ||
| 19 | +namespace sherpa_onnx { | ||
| 20 | + | ||
| 21 | +class OnlineParaformerModel { | ||
| 22 | + public: | ||
| 23 | + explicit OnlineParaformerModel(const OnlineModelConfig &config); | ||
| 24 | + | ||
| 25 | +#if __ANDROID_API__ >= 9 | ||
| 26 | + OnlineParaformerModel(AAssetManager *mgr, const OnlineModelConfig &config); | ||
| 27 | +#endif | ||
| 28 | + | ||
| 29 | + ~OnlineParaformerModel(); | ||
| 30 | + | ||
| 31 | + std::vector<Ort::Value> ForwardEncoder(Ort::Value features, | ||
| 32 | + Ort::Value features_length) const; | ||
| 33 | + | ||
| 34 | + std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out, | ||
| 35 | + Ort::Value encoder_out_length, | ||
| 36 | + Ort::Value acoustic_embedding, | ||
| 37 | + Ort::Value acoustic_embedding_length, | ||
| 38 | + std::vector<Ort::Value> states) const; | ||
| 39 | + | ||
| 40 | + /** Return the vocabulary size of the model | ||
| 41 | + */ | ||
| 42 | + int32_t VocabSize() const; | ||
| 43 | + | ||
| 44 | + /** It is lfr_m in config.yaml | ||
| 45 | + */ | ||
| 46 | + int32_t LfrWindowSize() const; | ||
| 47 | + | ||
| 48 | + /** It is lfr_n in config.yaml | ||
| 49 | + */ | ||
| 50 | + int32_t LfrWindowShift() const; | ||
| 51 | + | ||
| 52 | + int32_t EncoderOutputSize() const; | ||
| 53 | + | ||
| 54 | + int32_t DecoderKernelSize() const; | ||
| 55 | + int32_t DecoderNumBlocks() const; | ||
| 56 | + | ||
| 57 | + /** Return negative mean for CMVN | ||
| 58 | + */ | ||
| 59 | + const std::vector<float> &NegativeMean() const; | ||
| 60 | + | ||
| 61 | + /** Return inverse stddev for CMVN | ||
| 62 | + */ | ||
| 63 | + const std::vector<float> &InverseStdDev() const; | ||
| 64 | + | ||
| 65 | + /** Return an allocator for allocating memory | ||
| 66 | + */ | ||
| 67 | + OrtAllocator *Allocator() const; | ||
| 68 | + | ||
| 69 | + private: | ||
| 70 | + class Impl; | ||
| 71 | + std::unique_ptr<Impl> impl_; | ||
| 72 | +}; | ||
| 73 | + | ||
| 74 | +} // namespace sherpa_onnx | ||
| 75 | + | ||
| 76 | +#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_ |
| @@ -4,6 +4,7 @@ | @@ -4,6 +4,7 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" | 5 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" |
| 6 | 6 | ||
| 7 | +#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" | ||
| 7 | #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" | 8 | #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" |
| 8 | 9 | ||
| 9 | namespace sherpa_onnx { | 10 | namespace sherpa_onnx { |
| @@ -14,6 +15,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -14,6 +15,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 14 | return std::make_unique<OnlineRecognizerTransducerImpl>(config); | 15 | return std::make_unique<OnlineRecognizerTransducerImpl>(config); |
| 15 | } | 16 | } |
| 16 | 17 | ||
| 18 | + if (!config.model_config.paraformer.encoder.empty()) { | ||
| 19 | + return std::make_unique<OnlineRecognizerParaformerImpl>(config); | ||
| 20 | + } | ||
| 21 | + | ||
| 17 | SHERPA_ONNX_LOGE("Please specify a model"); | 22 | SHERPA_ONNX_LOGE("Please specify a model"); |
| 18 | exit(-1); | 23 | exit(-1); |
| 19 | } | 24 | } |
| @@ -25,6 +30,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -25,6 +30,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 25 | return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config); | 30 | return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config); |
| 26 | } | 31 | } |
| 27 | 32 | ||
| 33 | + if (!config.model_config.paraformer.encoder.empty()) { | ||
| 34 | + return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config); | ||
| 35 | + } | ||
| 36 | + | ||
| 28 | SHERPA_ONNX_LOGE("Please specify a model"); | 37 | SHERPA_ONNX_LOGE("Please specify a model"); |
| 29 | exit(-1); | 38 | exit(-1); |
| 30 | } | 39 | } |
| @@ -26,8 +26,6 @@ class OnlineRecognizerImpl { | @@ -26,8 +26,6 @@ class OnlineRecognizerImpl { | ||
| 26 | 26 | ||
| 27 | virtual ~OnlineRecognizerImpl() = default; | 27 | virtual ~OnlineRecognizerImpl() = default; |
| 28 | 28 | ||
| 29 | - virtual void InitOnlineStream(OnlineStream *stream) const = 0; | ||
| 30 | - | ||
| 31 | virtual std::unique_ptr<OnlineStream> CreateStream() const = 0; | 29 | virtual std::unique_ptr<OnlineStream> CreateStream() const = 0; |
| 32 | 30 | ||
| 33 | virtual std::unique_ptr<OnlineStream> CreateStream( | 31 | virtual std::unique_ptr<OnlineStream> CreateStream( |
| 1 | +// sherpa-onnx/csrc/online-recognizer-paraformer-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <memory> | ||
| 10 | +#include <string> | ||
| 11 | +#include <utility> | ||
| 12 | +#include <vector> | ||
| 13 | + | ||
| 14 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 15 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 16 | +#include "sherpa-onnx/csrc/online-lm.h" | ||
| 17 | +#include "sherpa-onnx/csrc/online-paraformer-decoder.h" | ||
| 18 | +#include "sherpa-onnx/csrc/online-paraformer-model.h" | ||
| 19 | +#include "sherpa-onnx/csrc/online-recognizer-impl.h" | ||
| 20 | +#include "sherpa-onnx/csrc/online-recognizer.h" | ||
| 21 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 22 | + | ||
| 23 | +namespace sherpa_onnx { | ||
| 24 | + | ||
| 25 | +static OnlineRecognizerResult Convert(const OnlineParaformerDecoderResult &src, | ||
| 26 | + const SymbolTable &sym_table) { | ||
| 27 | + OnlineRecognizerResult r; | ||
| 28 | + r.tokens.reserve(src.tokens.size()); | ||
| 29 | + | ||
| 30 | + std::string text; | ||
| 31 | + | ||
| 32 | + // When the current token ends with "@@" we set mergeable to true | ||
| 33 | + bool mergeable = false; | ||
| 34 | + | ||
| 35 | + for (int32_t i = 0; i != src.tokens.size(); ++i) { | ||
| 36 | + auto sym = sym_table[src.tokens[i]]; | ||
| 37 | + r.tokens.push_back(sym); | ||
| 38 | + | ||
| 39 | + if ((sym.back() != '@') || (sym.size() > 2 && sym[sym.size() - 2] != '@')) { | ||
| 40 | + // sym does not end with "@@" | ||
| 41 | + const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str()); | ||
| 42 | + if (p[0] < 0x80) { | ||
| 43 | + // an ascii | ||
| 44 | + if (mergeable) { | ||
| 45 | + mergeable = false; | ||
| 46 | + text.append(sym); | ||
| 47 | + } else { | ||
| 48 | + text.append(" "); | ||
| 49 | + text.append(sym); | ||
| 50 | + } | ||
| 51 | + } else { | ||
| 52 | + // not an ascii | ||
| 53 | + mergeable = false; | ||
| 54 | + | ||
| 55 | + if (i > 0) { | ||
| 56 | + const uint8_t *p = reinterpret_cast<const uint8_t *>( | ||
| 57 | + sym_table[src.tokens[i - 1]].c_str()); | ||
| 58 | + if (p[0] < 0x80) { | ||
| 59 | + // put a space between ascii and non-ascii | ||
| 60 | + text.append(" "); | ||
| 61 | + } | ||
| 62 | + } | ||
| 63 | + text.append(sym); | ||
| 64 | + } | ||
| 65 | + } else { | ||
| 66 | + // this sym ends with @@ | ||
| 67 | + sym = std::string(sym.data(), sym.size() - 2); | ||
| 68 | + if (mergeable) { | ||
| 69 | + text.append(sym); | ||
| 70 | + } else { | ||
| 71 | + text.append(" "); | ||
| 72 | + text.append(sym); | ||
| 73 | + mergeable = true; | ||
| 74 | + } | ||
| 75 | + } | ||
| 76 | + } | ||
| 77 | + r.text = std::move(text); | ||
| 78 | + | ||
| 79 | + return r; | ||
| 80 | +} | ||
| 81 | + | ||
// Accumulate a scaled copy of x into y: y[i] += x[i] * scale for i in [0, n).
// x and y must each point to at least n floats.
static void ScaleAddInPlace(const float *x, int32_t n, float scale, float *y) {
  const float *end = x + n;
  while (x != end) {
    *y++ += *x++ * scale;
  }
}
| 88 | + | ||
// Write a scaled copy of x into y: y[i] = x[i] * scale for i in [0, n).
// x and y must each point to at least n floats.
static void Scale(const float *x, int32_t n, float scale, float *y) {
  for (int32_t i = n - 1; i >= 0; --i) {
    y[i] = x[i] * scale;
  }
}
| 95 | + | ||
// Streaming (online) recognizer backed by a Paraformer encoder/decoder pair.
//
// Only greedy_search decoding is supported, and DecodeStreams() processes
// streams one at a time (batch size 1). Token emission uses a CIF
// (continuous integrate-and-fire) search over the encoder's alpha weights.
class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
 public:
  explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config)
      : config_(config),
        model_(config.model_config),
        sym_(config.model_config.tokens),
        endpoint_(config_.endpoint_config) {
    if (config.decoding_method != "greedy_search") {
      SHERPA_ONNX_LOGE(
          "Unsupported decoding method: %s. Support only greedy_search at "
          "present",
          config.decoding_method.c_str());
      exit(-1);
    }

    // Paraformer models assume input samples are in the range
    // [-32768, 32767], so we set normalize_samples to false
    config_.feat_config.normalize_samples = false;
  }

#if __ANDROID_API__ >= 9
  // Android variant: model files and the token table are loaded through the
  // asset manager instead of the filesystem.
  explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr,
                                          const OnlineRecognizerConfig &config)
      : config_(config),
        model_(mgr, config.model_config),
        sym_(mgr, config.model_config.tokens),
        endpoint_(config_.endpoint_config) {
    if (config.decoding_method == "greedy_search") {
      // add greedy search decoder
      // SHERPA_ONNX_LOGE("to be implemented");
      // exit(-1);
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config.decoding_method.c_str());
      exit(-1);
    }

    // Paraformer models assume input samples are in the range
    // [-32768, 32767], so we set normalize_samples to false
    config_.feat_config.normalize_samples = false;
  }
#endif
  OnlineRecognizerParaformerImpl(const OnlineRecognizerParaformerImpl &) =
      delete;

  OnlineRecognizerParaformerImpl operator=(
      const OnlineRecognizerParaformerImpl &) = delete;

  // Create a stream with an empty Paraformer result attached. The model
  // caches (feature/encoder-out/alpha) are created lazily on first decode.
  std::unique_ptr<OnlineStream> CreateStream() const override {
    auto stream = std::make_unique<OnlineStream>(config_.feat_config);

    OnlineParaformerDecoderResult r;
    stream->SetParaformerResult(r);

    return stream;
  }

  // A stream is decodable once at least one full chunk of un-processed
  // feature frames is available.
  bool IsReady(OnlineStream *s) const override {
    return s->GetNumProcessedFrames() + chunk_size_ < s->NumFramesReady();
  }

  void DecodeStreams(OnlineStream **ss, int32_t n) const override {
    // TODO(fangjun): Support batch size > 1
    for (int32_t i = 0; i != n; ++i) {
      DecodeStream(ss[i]);
    }
  }

  OnlineRecognizerResult GetResult(OnlineStream *s) const override {
    auto decoder_result = s->GetParaformerResult();

    return Convert(decoder_result, sym_);
  }

  // Endpointing is based on the number of frames decoded since the last
  // non-blank token was emitted.
  bool IsEndpoint(OnlineStream *s) const override {
    if (!config_.enable_endpoint) {
      return false;
    }

    const auto &result = s->GetParaformerResult();

    int32_t num_processed_frames = s->GetNumProcessedFrames();

    // frame shift is 10 milliseconds
    float frame_shift_in_seconds = 0.01;

    int32_t trailing_silence_frames =
        num_processed_frames - result.last_non_blank_frame_index;

    return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
                                frame_shift_in_seconds);
  }

  void Reset(OnlineStream *s) const override {
    OnlineParaformerDecoderResult r;
    s->SetParaformerResult(r);

    // the internal model caches are not reset

    // Note: We only update counters. The underlying audio samples
    // are not discarded.
    s->Reset();
  }

 private:
  // Decode one chunk from the stream: extract LFR+CMVN+positional-encoded
  // features, run the encoder, run the CIF search over the alpha weights,
  // and feed the fired acoustic embeddings through the decoder. Emitted
  // non-blank token ids are appended to the stream's Paraformer result.
  void DecodeStream(OnlineStream *s) const {
    const auto num_processed_frames = s->GetNumProcessedFrames();
    std::vector<float> frames = s->GetFrames(num_processed_frames, chunk_size_);
    // Advance by chunk_size_ - 1 so consecutive chunks overlap by one frame.
    s->GetNumProcessedFrames() += chunk_size_ - 1;

    frames = ApplyLFR(frames);
    ApplyCMVN(&frames);
    PositionalEncoding(&frames, num_processed_frames / model_.LfrWindowShift());

    // feat_dim after LFR; NegativeMean() has one entry per LFR feature dim.
    int32_t feat_dim = model_.NegativeMean().size();

    // We have scaled inv_stddev by sqrt(encoder_output_size)
    // so the following line can be commented out
    // frames *= encoder_output_size ** 0.5

    // add overlap chunk: prepend the cached left/right context frames and
    // save the tail of this chunk as the context for the next call.
    std::vector<float> &feat_cache = s->GetParaformerFeatCache();
    if (feat_cache.empty()) {
      int32_t n = (left_chunk_size_ + right_chunk_size_) * feat_dim;
      feat_cache.resize(n, 0);
    }

    frames.insert(frames.begin(), feat_cache.begin(), feat_cache.end());
    std::copy(frames.end() - feat_cache.size(), frames.end(),
              feat_cache.begin());

    int32_t num_frames = frames.size() / feat_dim;

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
    Ort::Value x =
        Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(),
                                 x_shape.data(), x_shape.size());

    int64_t x_len_shape = 1;
    int32_t x_len_val = num_frames;

    Ort::Value x_length =
        Ort::Value::CreateTensor(memory_info, &x_len_val, 1, &x_len_shape, 1);

    auto encoder_out_vec =
        model_.ForwardEncoder(std::move(x), std::move(x_length));

    // CIF search
    // NOTE(review): assumes encoder outputs are ordered
    // {encoder_out, encoder_out_len, alpha} — confirm against
    // online-paraformer-model.
    auto &encoder_out = encoder_out_vec[0];
    auto &encoder_out_len = encoder_out_vec[1];
    auto &alpha = encoder_out_vec[2];

    float *p_alpha = alpha.GetTensorMutableData<float>();

    std::vector<int64_t> alpha_shape =
        alpha.GetTensorTypeAndShapeInfo().GetShape();

    // Zero out alpha in the left/right context regions so the overlapping
    // context frames never fire tokens.
    std::fill(p_alpha, p_alpha + left_chunk_size_, 0);
    std::fill(p_alpha + alpha_shape[1] - right_chunk_size_,
              p_alpha + alpha_shape[1], 0);

    const float *p_encoder_out = encoder_out.GetTensorData<float>();

    std::vector<int64_t> encoder_out_shape =
        encoder_out.GetTensorTypeAndShapeInfo().GetShape();

    // Partially-integrated hidden vector carried across chunks.
    std::vector<float> &initial_hidden = s->GetParaformerEncoderOutCache();
    if (initial_hidden.empty()) {
      initial_hidden.resize(encoder_out_shape[2]);
    }

    // alpha_cache[0] is the integrated alpha weight carried across chunks.
    std::vector<float> &alpha_cache = s->GetParaformerAlphaCache();
    if (alpha_cache.empty()) {
      alpha_cache.resize(1);
    }

    std::vector<float> acoustic_embedding;
    acoustic_embedding.reserve(encoder_out_shape[1] * encoder_out_shape[2]);

    // A token "fires" each time the integrated alpha weight reaches 1.0.
    float threshold = 1.0;

    float integrate = alpha_cache[0];

    for (int32_t i = 0; i != encoder_out_shape[1]; ++i) {
      float this_alpha = p_alpha[i];
      if (integrate + this_alpha < threshold) {
        // Not enough weight yet: keep accumulating this frame's
        // contribution into the hidden vector.
        integrate += this_alpha;
        ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
                        encoder_out_shape[2], this_alpha,
                        initial_hidden.data());
        continue;
      }

      // fire: complete the current embedding with just enough of this
      // frame's weight, emit it, and start the next embedding with the
      // leftover weight.
      ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
                      encoder_out_shape[2], threshold - integrate,
                      initial_hidden.data());
      acoustic_embedding.insert(acoustic_embedding.end(),
                                initial_hidden.begin(), initial_hidden.end());
      integrate += this_alpha - threshold;

      Scale(p_encoder_out + i * encoder_out_shape[2], encoder_out_shape[2],
            integrate, initial_hidden.data());
    }

    alpha_cache[0] = integrate;

    if (acoustic_embedding.empty()) {
      // No token fired in this chunk; nothing to decode.
      return;
    }

    // Lazily create zero-initialized decoder conv states on first use.
    auto &states = s->GetStates();
    if (states.empty()) {
      states.reserve(model_.DecoderNumBlocks());

      std::array<int64_t, 3> shape{1, model_.EncoderOutputSize(),
                                   model_.DecoderKernelSize() - 1};

      int32_t num_bytes = sizeof(float) * shape[0] * shape[1] * shape[2];

      for (int32_t i = 0; i != model_.DecoderNumBlocks(); ++i) {
        Ort::Value this_state = Ort::Value::CreateTensor<float>(
            model_.Allocator(), shape.data(), shape.size());

        memset(this_state.GetTensorMutableData<float>(), 0, num_bytes);

        states.push_back(std::move(this_state));
      }
    }

    int32_t num_tokens = acoustic_embedding.size() / initial_hidden.size();
    std::array<int64_t, 3> acoustic_embedding_shape{
        1, num_tokens, static_cast<int32_t>(initial_hidden.size())};

    Ort::Value acoustic_embedding_tensor = Ort::Value::CreateTensor(
        memory_info, acoustic_embedding.data(), acoustic_embedding.size(),
        acoustic_embedding_shape.data(), acoustic_embedding_shape.size());

    std::array<int64_t, 1> acoustic_embedding_length_shape{1};
    Ort::Value acoustic_embedding_length_tensor = Ort::Value::CreateTensor(
        memory_info, &num_tokens, 1, acoustic_embedding_length_shape.data(),
        acoustic_embedding_length_shape.size());

    auto decoder_out_vec = model_.ForwardDecoder(
        std::move(encoder_out), std::move(encoder_out_len),
        std::move(acoustic_embedding_tensor),
        std::move(acoustic_embedding_length_tensor), std::move(states));

    // Outputs from index 2 onward are the updated decoder states.
    states.reserve(model_.DecoderNumBlocks());
    for (int32_t i = 2; i != decoder_out_vec.size(); ++i) {
      // TODO(fangjun): When we change chunk_size_, we need to
      // slice decoder_out_vec[i] accordingly.
      states.push_back(std::move(decoder_out_vec[i]));
    }

    // decoder_out_vec[1] holds the greedy-search token ids; id 0 is blank.
    const auto &sample_ids = decoder_out_vec[1];
    const int64_t *p_sample_ids = sample_ids.GetTensorData<int64_t>();

    bool non_blank_detected = false;

    auto &result = s->GetParaformerResult();

    for (int32_t i = 0; i != num_tokens; ++i) {
      int32_t t = p_sample_ids[i];
      if (t == 0) {
        continue;
      }

      non_blank_detected = true;
      result.tokens.push_back(t);
    }

    if (non_blank_detected) {
      result.last_non_blank_frame_index = num_processed_frames;
    }
  }

  // Stack lfr_window_size consecutive frames into one low-frame-rate frame,
  // advancing by lfr_window_shift input frames per output frame.
  std::vector<float> ApplyLFR(const std::vector<float> &in) const {
    int32_t lfr_window_size = model_.LfrWindowSize();
    int32_t lfr_window_shift = model_.LfrWindowShift();
    int32_t in_feat_dim = config_.feat_config.feature_dim;

    int32_t in_num_frames = in.size() / in_feat_dim;
    int32_t out_num_frames =
        (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
    int32_t out_feat_dim = in_feat_dim * lfr_window_size;

    std::vector<float> out(out_num_frames * out_feat_dim);

    const float *p_in = in.data();
    float *p_out = out.data();

    for (int32_t i = 0; i != out_num_frames; ++i) {
      std::copy(p_in, p_in + out_feat_dim, p_out);

      p_out += out_feat_dim;
      p_in += lfr_window_shift * in_feat_dim;
    }

    return out;
  }

  // In-place cepstral mean/variance normalization:
  // v[k] = (v[k] + neg_mean[k]) * inv_stddev[k] for every frame.
  void ApplyCMVN(std::vector<float> *v) const {
    const std::vector<float> &neg_mean = model_.NegativeMean();
    const std::vector<float> &inv_stddev = model_.InverseStdDev();

    int32_t dim = neg_mean.size();
    int32_t num_frames = v->size() / dim;

    float *p = v->data();

    for (int32_t i = 0; i != num_frames; ++i) {
      for (int32_t k = 0; k != dim; ++k) {
        p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
      }

      p += dim;
    }
  }

  // Add sinusoidal positional encodings in place. t_offset is the absolute
  // position (in LFR frames) of the first frame in v, so encodings stay
  // continuous across chunks.
  void PositionalEncoding(std::vector<float> *v, int32_t t_offset) const {
    int32_t lfr_window_size = model_.LfrWindowSize();
    int32_t in_feat_dim = config_.feat_config.feature_dim;

    int32_t feat_dim = in_feat_dim * lfr_window_size;
    int32_t T = v->size() / feat_dim;

    // log(10000)/(7*80/2-1) == 0.03301197265941284
    // 7 is lfr_window_size
    // 80 is in_feat_dim
    // 7*80 is feat_dim
    // NOTE(review): this constant hard-codes lfr_window_size=7 and
    // in_feat_dim=80 — verify if other model configurations are used.
    constexpr float kScale = -0.03301197265941284;

    for (int32_t t = 0; t != T; ++t) {
      float *p = v->data() + t * feat_dim;

      int32_t offset = t + 1 + t_offset;

      for (int32_t d = 0; d < feat_dim / 2; ++d) {
        float inv_timescale = offset * std::exp(d * kScale);

        float sin_d = std::sin(inv_timescale);
        float cos_d = std::cos(inv_timescale);

        // First half of the dims gets sin, second half gets cos.
        p[d] += sin_d;
        p[d + feat_dim / 2] += cos_d;
      }
    }
  }

 private:
  OnlineRecognizerConfig config_;
  OnlineParaformerModel model_;
  SymbolTable sym_;
  Endpoint endpoint_;

  // Number of feature frames consumed per decode call.
  // 0.61 seconds
  int32_t chunk_size_ = 61;
  // (61 - 7) / 6 + 1 = 10

  // Left/right context sizes (in LFR frames) for the overlap cache.
  int32_t left_chunk_size_ = 5;
  int32_t right_chunk_size_ = 5;
};
| 462 | + | ||
| 463 | +} // namespace sherpa_onnx | ||
| 464 | + | ||
| 465 | +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_ |
| @@ -94,21 +94,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -94,21 +94,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 94 | } | 94 | } |
| 95 | #endif | 95 | #endif |
| 96 | 96 | ||
| 97 | - void InitOnlineStream(OnlineStream *stream) const override { | ||
| 98 | - auto r = decoder_->GetEmptyResult(); | ||
| 99 | - | ||
| 100 | - if (config_.decoding_method == "modified_beam_search" && | ||
| 101 | - nullptr != stream->GetContextGraph()) { | ||
| 102 | - // r.hyps has only one element. | ||
| 103 | - for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { | ||
| 104 | - it->second.context_state = stream->GetContextGraph()->Root(); | ||
| 105 | - } | ||
| 106 | - } | ||
| 107 | - | ||
| 108 | - stream->SetResult(r); | ||
| 109 | - stream->SetStates(model_->GetEncoderInitStates()); | ||
| 110 | - } | ||
| 111 | - | ||
| 112 | std::unique_ptr<OnlineStream> CreateStream() const override { | 97 | std::unique_ptr<OnlineStream> CreateStream() const override { |
| 113 | auto stream = std::make_unique<OnlineStream>(config_.feat_config); | 98 | auto stream = std::make_unique<OnlineStream>(config_.feat_config); |
| 114 | InitOnlineStream(stream.get()); | 99 | InitOnlineStream(stream.get()); |
| @@ -211,7 +196,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -211,7 +196,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 211 | } | 196 | } |
| 212 | 197 | ||
| 213 | bool IsEndpoint(OnlineStream *s) const override { | 198 | bool IsEndpoint(OnlineStream *s) const override { |
| 214 | - if (!config_.enable_endpoint) return false; | 199 | + if (!config_.enable_endpoint) { |
| 200 | + return false; | ||
| 201 | + } | ||
| 202 | + | ||
| 215 | int32_t num_processed_frames = s->GetNumProcessedFrames(); | 203 | int32_t num_processed_frames = s->GetNumProcessedFrames(); |
| 216 | 204 | ||
| 217 | // frame shift is 10 milliseconds | 205 | // frame shift is 10 milliseconds |
| @@ -245,6 +233,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -245,6 +233,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 245 | } | 233 | } |
| 246 | 234 | ||
| 247 | private: | 235 | private: |
| 236 | + void InitOnlineStream(OnlineStream *stream) const { | ||
| 237 | + auto r = decoder_->GetEmptyResult(); | ||
| 238 | + | ||
| 239 | + if (config_.decoding_method == "modified_beam_search" && | ||
| 240 | + nullptr != stream->GetContextGraph()) { | ||
| 241 | + // r.hyps has only one element. | ||
| 242 | + for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { | ||
| 243 | + it->second.context_state = stream->GetContextGraph()->Root(); | ||
| 244 | + } | ||
| 245 | + } | ||
| 246 | + | ||
| 247 | + stream->SetResult(r); | ||
| 248 | + stream->SetStates(model_->GetEncoderInitStates()); | ||
| 249 | + } | ||
| 250 | + | ||
| 251 | + private: | ||
| 248 | OnlineRecognizerConfig config_; | 252 | OnlineRecognizerConfig config_; |
| 249 | std::unique_ptr<OnlineTransducerModel> model_; | 253 | std::unique_ptr<OnlineTransducerModel> model_; |
| 250 | std::unique_ptr<OnlineLM> lm_; | 254 | std::unique_ptr<OnlineLM> lm_; |
| @@ -47,6 +47,14 @@ class OnlineStream::Impl { | @@ -47,6 +47,14 @@ class OnlineStream::Impl { | ||
| 47 | 47 | ||
| 48 | OnlineTransducerDecoderResult &GetResult() { return result_; } | 48 | OnlineTransducerDecoderResult &GetResult() { return result_; } |
| 49 | 49 | ||
| 50 | + void SetParaformerResult(const OnlineParaformerDecoderResult &r) { | ||
| 51 | + paraformer_result_ = r; | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + OnlineParaformerDecoderResult &GetParaformerResult() { | ||
| 55 | + return paraformer_result_; | ||
| 56 | + } | ||
| 57 | + | ||
| 50 | int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } | 58 | int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } |
| 51 | 59 | ||
| 52 | void SetStates(std::vector<Ort::Value> states) { | 60 | void SetStates(std::vector<Ort::Value> states) { |
| @@ -57,6 +65,18 @@ class OnlineStream::Impl { | @@ -57,6 +65,18 @@ class OnlineStream::Impl { | ||
| 57 | 65 | ||
| 58 | const ContextGraphPtr &GetContextGraph() const { return context_graph_; } | 66 | const ContextGraphPtr &GetContextGraph() const { return context_graph_; } |
| 59 | 67 | ||
| 68 | + std::vector<float> &GetParaformerFeatCache() { | ||
| 69 | + return paraformer_feat_cache_; | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + std::vector<float> &GetParaformerEncoderOutCache() { | ||
| 73 | + return paraformer_encoder_out_cache_; | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + std::vector<float> &GetParaformerAlphaCache() { | ||
| 77 | + return paraformer_alpha_cache_; | ||
| 78 | + } | ||
| 79 | + | ||
| 60 | private: | 80 | private: |
| 61 | FeatureExtractor feat_extractor_; | 81 | FeatureExtractor feat_extractor_; |
| 62 | /// For contextual-biasing | 82 | /// For contextual-biasing |
| @@ -65,6 +85,10 @@ class OnlineStream::Impl { | @@ -65,6 +85,10 @@ class OnlineStream::Impl { | ||
| 65 | int32_t start_frame_index_ = 0; // never reset | 85 | int32_t start_frame_index_ = 0; // never reset |
| 66 | OnlineTransducerDecoderResult result_; | 86 | OnlineTransducerDecoderResult result_; |
| 67 | std::vector<Ort::Value> states_; | 87 | std::vector<Ort::Value> states_; |
| 88 | + std::vector<float> paraformer_feat_cache_; | ||
| 89 | + std::vector<float> paraformer_encoder_out_cache_; | ||
| 90 | + std::vector<float> paraformer_alpha_cache_; | ||
| 91 | + OnlineParaformerDecoderResult paraformer_result_; | ||
| 68 | }; | 92 | }; |
| 69 | 93 | ||
| 70 | OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, | 94 | OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, |
| @@ -107,6 +131,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { | @@ -107,6 +131,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { | ||
| 107 | return impl_->GetResult(); | 131 | return impl_->GetResult(); |
| 108 | } | 132 | } |
| 109 | 133 | ||
| 134 | +void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) { | ||
| 135 | + impl_->SetParaformerResult(r); | ||
| 136 | +} | ||
| 137 | + | ||
| 138 | +OnlineParaformerDecoderResult &OnlineStream::GetParaformerResult() { | ||
| 139 | + return impl_->GetParaformerResult(); | ||
| 140 | +} | ||
| 141 | + | ||
| 110 | void OnlineStream::SetStates(std::vector<Ort::Value> states) { | 142 | void OnlineStream::SetStates(std::vector<Ort::Value> states) { |
| 111 | impl_->SetStates(std::move(states)); | 143 | impl_->SetStates(std::move(states)); |
| 112 | } | 144 | } |
| @@ -119,4 +151,16 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const { | @@ -119,4 +151,16 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const { | ||
| 119 | return impl_->GetContextGraph(); | 151 | return impl_->GetContextGraph(); |
| 120 | } | 152 | } |
| 121 | 153 | ||
| 154 | +std::vector<float> &OnlineStream::GetParaformerFeatCache() { | ||
| 155 | + return impl_->GetParaformerFeatCache(); | ||
| 156 | +} | ||
| 157 | + | ||
| 158 | +std::vector<float> &OnlineStream::GetParaformerEncoderOutCache() { | ||
| 159 | + return impl_->GetParaformerEncoderOutCache(); | ||
| 160 | +} | ||
| 161 | + | ||
| 162 | +std::vector<float> &OnlineStream::GetParaformerAlphaCache() { | ||
| 163 | + return impl_->GetParaformerAlphaCache(); | ||
| 164 | +} | ||
| 165 | + | ||
| 122 | } // namespace sherpa_onnx | 166 | } // namespace sherpa_onnx |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 11 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | #include "sherpa-onnx/csrc/context-graph.h" | 12 | #include "sherpa-onnx/csrc/context-graph.h" |
| 13 | #include "sherpa-onnx/csrc/features.h" | 13 | #include "sherpa-onnx/csrc/features.h" |
| 14 | +#include "sherpa-onnx/csrc/online-paraformer-decoder.h" | ||
| 14 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 15 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 15 | 16 | ||
| 16 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| @@ -70,6 +71,9 @@ class OnlineStream { | @@ -70,6 +71,9 @@ class OnlineStream { | ||
| 70 | void SetResult(const OnlineTransducerDecoderResult &r); | 71 | void SetResult(const OnlineTransducerDecoderResult &r); |
| 71 | OnlineTransducerDecoderResult &GetResult(); | 72 | OnlineTransducerDecoderResult &GetResult(); |
| 72 | 73 | ||
| 74 | + void SetParaformerResult(const OnlineParaformerDecoderResult &r); | ||
| 75 | + OnlineParaformerDecoderResult &GetParaformerResult(); | ||
| 76 | + | ||
| 73 | void SetStates(std::vector<Ort::Value> states); | 77 | void SetStates(std::vector<Ort::Value> states); |
| 74 | std::vector<Ort::Value> &GetStates(); | 78 | std::vector<Ort::Value> &GetStates(); |
| 75 | 79 | ||
| @@ -80,6 +84,11 @@ class OnlineStream { | @@ -80,6 +84,11 @@ class OnlineStream { | ||
| 80 | */ | 84 | */ |
| 81 | const ContextGraphPtr &GetContextGraph() const; | 85 | const ContextGraphPtr &GetContextGraph() const; |
| 82 | 86 | ||
| 87 | + // for streaming paraformer |  | 
| 88 | + std::vector<float> &GetParaformerFeatCache(); | ||
| 89 | + std::vector<float> &GetParaformerEncoderOutCache(); | ||
| 90 | + std::vector<float> &GetParaformerAlphaCache(); | ||
| 91 | + | ||
| 83 | private: | 92 | private: |
| 84 | class Impl; | 93 | class Impl; |
| 85 | std::unique_ptr<Impl> impl_; | 94 | std::unique_ptr<Impl> impl_; |
| @@ -12,8 +12,8 @@ | @@ -12,8 +12,8 @@ | ||
| 12 | 12 | ||
| 13 | #include "sherpa-onnx/csrc/online-recognizer.h" | 13 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| 14 | #include "sherpa-onnx/csrc/online-stream.h" | 14 | #include "sherpa-onnx/csrc/online-stream.h" |
| 15 | -#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 16 | #include "sherpa-onnx/csrc/parse-options.h" | 15 | #include "sherpa-onnx/csrc/parse-options.h" |
| 16 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 17 | #include "sherpa-onnx/csrc/wave-reader.h" | 17 | #include "sherpa-onnx/csrc/wave-reader.h" |
| 18 | 18 | ||
| 19 | typedef struct { | 19 | typedef struct { |
| @@ -80,7 +80,7 @@ for a list of pre-trained models to download. | @@ -80,7 +80,7 @@ for a list of pre-trained models to download. | ||
| 80 | 80 | ||
| 81 | bool is_ok = false; | 81 | bool is_ok = false; |
| 82 | const std::vector<float> samples = | 82 | const std::vector<float> samples = |
| 83 | - sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | 83 | + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); |
| 84 | 84 | ||
| 85 | if (!is_ok) { | 85 | if (!is_ok) { |
| 86 | fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | 86 | fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); |
| @@ -92,14 +92,14 @@ for a list of pre-trained models to download. | @@ -92,14 +92,14 @@ for a list of pre-trained models to download. | ||
| 92 | auto s = recognizer.CreateStream(); | 92 | auto s = recognizer.CreateStream(); |
| 93 | s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | 93 | s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); |
| 94 | 94 | ||
| 95 | - std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate)); | 95 | + std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate)); |
| 96 | // Note: We can call AcceptWaveform() multiple times. | 96 | // Note: We can call AcceptWaveform() multiple times. |
| 97 | - s->AcceptWaveform( | ||
| 98 | - sampling_rate, tail_paddings.data(), tail_paddings.size()); | 97 | + s->AcceptWaveform(sampling_rate, tail_paddings.data(), |
| 98 | + tail_paddings.size()); | ||
| 99 | 99 | ||
| 100 | // Call InputFinished() to indicate that no audio samples are available | 100 | // Call InputFinished() to indicate that no audio samples are available |
| 101 | s->InputFinished(); | 101 | s->InputFinished(); |
| 102 | - ss.push_back({ std::move(s), duration, 0 }); | 102 | + ss.push_back({std::move(s), duration, 0}); |
| 103 | } | 103 | } |
| 104 | 104 | ||
| 105 | std::vector<sherpa_onnx::OnlineStream *> ready_streams; | 105 | std::vector<sherpa_onnx::OnlineStream *> ready_streams; |
| @@ -112,8 +112,9 @@ for a list of pre-trained models to download. | @@ -112,8 +112,9 @@ for a list of pre-trained models to download. | ||
| 112 | } else if (s.elapsed_seconds == 0) { | 112 | } else if (s.elapsed_seconds == 0) { |
| 113 | const auto end = std::chrono::steady_clock::now(); | 113 | const auto end = std::chrono::steady_clock::now(); |
| 114 | const float elapsed_seconds = | 114 | const float elapsed_seconds = |
| 115 | - std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 116 | - .count() / 1000.; | 115 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) |
| 116 | + .count() / | ||
| 117 | + 1000.; | ||
| 117 | s.elapsed_seconds = elapsed_seconds; | 118 | s.elapsed_seconds = elapsed_seconds; |
| 118 | } | 119 | } |
| 119 | } | 120 | } |
| @@ -15,6 +15,7 @@ pybind11_add_module(_sherpa_onnx | @@ -15,6 +15,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 15 | offline-whisper-model-config.cc | 15 | offline-whisper-model-config.cc |
| 16 | online-lm-config.cc | 16 | online-lm-config.cc |
| 17 | online-model-config.cc | 17 | online-model-config.cc |
| 18 | + online-paraformer-model-config.cc | ||
| 18 | online-recognizer.cc | 19 | online-recognizer.cc |
| 19 | online-stream.cc | 20 | online-stream.cc |
| 20 | online-transducer-model-config.cc | 21 | online-transducer-model-config.cc |
| 1 | // sherpa-onnx/python/csrc/online-model-config.cc | 1 | // sherpa-onnx/python/csrc/online-model-config.cc |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 by manyeyes | 3 | +// Copyright (c) 2023 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/python/csrc/online-model-config.h" | 5 | #include "sherpa-onnx/python/csrc/online-model-config.h" |
| 6 | 6 | ||
| @@ -9,21 +9,26 @@ | @@ -9,21 +9,26 @@ | ||
| 9 | 9 | ||
| 10 | #include "sherpa-onnx/csrc/online-model-config.h" | 10 | #include "sherpa-onnx/csrc/online-model-config.h" |
| 11 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 11 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 12 | +#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" | ||
| 12 | #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" |
| 13 | 14 | ||
| 14 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| 15 | 16 | ||
| 16 | void PybindOnlineModelConfig(py::module *m) { | 17 | void PybindOnlineModelConfig(py::module *m) { |
| 17 | PybindOnlineTransducerModelConfig(m); | 18 | PybindOnlineTransducerModelConfig(m); |
| 19 | + PybindOnlineParaformerModelConfig(m); | ||
| 18 | 20 | ||
| 19 | using PyClass = OnlineModelConfig; | 21 | using PyClass = OnlineModelConfig; |
| 20 | py::class_<PyClass>(*m, "OnlineModelConfig") | 22 | py::class_<PyClass>(*m, "OnlineModelConfig") |
| 21 | - .def(py::init<const OnlineTransducerModelConfig &, std::string &, int32_t, | 23 | + .def(py::init<const OnlineTransducerModelConfig &, |
| 24 | + const OnlineParaformerModelConfig &, std::string &, int32_t, | ||
| 22 | bool, const std::string &, const std::string &>(), | 25 | bool, const std::string &, const std::string &>(), |
| 23 | py::arg("transducer") = OnlineTransducerModelConfig(), | 26 | py::arg("transducer") = OnlineTransducerModelConfig(), |
| 27 | + py::arg("paraformer") = OnlineParaformerModelConfig(), | ||
| 24 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, | 28 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, |
| 25 | py::arg("provider") = "cpu", py::arg("model_type") = "") | 29 | py::arg("provider") = "cpu", py::arg("model_type") = "") |
| 26 | .def_readwrite("transducer", &PyClass::transducer) | 30 | .def_readwrite("transducer", &PyClass::transducer) |
| 31 | + .def_readwrite("paraformer", &PyClass::paraformer) | ||
| 27 | .def_readwrite("tokens", &PyClass::tokens) | 32 | .def_readwrite("tokens", &PyClass::tokens) |
| 28 | .def_readwrite("num_threads", &PyClass::num_threads) | 33 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 29 | .def_readwrite("debug", &PyClass::debug) | 34 | .def_readwrite("debug", &PyClass::debug) |
| 1 | // sherpa-onnx/python/csrc/online-model-config.h | 1 | // sherpa-onnx/python/csrc/online-model-config.h |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 by manyeyes | 3 | +// Copyright (c) 2023 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ | 5 | #ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ |
| 6 | #define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ | 6 | #define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/python/csrc/online-paraformer-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/online-paraformer-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindOnlineParaformerModelConfig(py::module *m) { | ||
| 15 | + using PyClass = OnlineParaformerModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "OnlineParaformerModelConfig") | ||
| 17 | + .def(py::init<const std::string &, const std::string &>(), | ||
| 18 | + py::arg("encoder"), py::arg("decoder")) | ||
| 19 | + .def_readwrite("encoder", &PyClass::encoder) | ||
| 20 | + .def_readwrite("decoder", &PyClass::decoder) | ||
| 21 | + .def("__str__", &PyClass::ToString); | ||
| 22 | +} | ||
| 23 | + | ||
| 24 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/online-paraformer-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOnlineParaformerModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ |
| @@ -33,7 +33,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -33,7 +33,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 33 | py::arg("feat_config"), py::arg("model_config"), | 33 | py::arg("feat_config"), py::arg("model_config"), |
| 34 | py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), | 34 | py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), |
| 35 | py::arg("enable_endpoint"), py::arg("decoding_method"), | 35 | py::arg("enable_endpoint"), py::arg("decoding_method"), |
| 36 | - py::arg("max_active_paths"), py::arg("context_score")) | 36 | + py::arg("max_active_paths") = 4, py::arg("context_score") = 0) |
| 37 | .def_readwrite("feat_config", &PyClass::feat_config) | 37 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 38 | .def_readwrite("model_config", &PyClass::model_config) | 38 | .def_readwrite("model_config", &PyClass::model_config) |
| 39 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) | 39 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) |
| @@ -6,6 +6,7 @@ from _sherpa_onnx import ( | @@ -6,6 +6,7 @@ from _sherpa_onnx import ( | ||
| 6 | EndpointConfig, | 6 | EndpointConfig, |
| 7 | FeatureExtractorConfig, | 7 | FeatureExtractorConfig, |
| 8 | OnlineModelConfig, | 8 | OnlineModelConfig, |
| 9 | + OnlineParaformerModelConfig, | ||
| 9 | OnlineRecognizer as _Recognizer, | 10 | OnlineRecognizer as _Recognizer, |
| 10 | OnlineRecognizerConfig, | 11 | OnlineRecognizerConfig, |
| 11 | OnlineStream, | 12 | OnlineStream, |
| @@ -32,7 +33,7 @@ class OnlineRecognizer(object): | @@ -32,7 +33,7 @@ class OnlineRecognizer(object): | ||
| 32 | encoder: str, | 33 | encoder: str, |
| 33 | decoder: str, | 34 | decoder: str, |
| 34 | joiner: str, | 35 | joiner: str, |
| 35 | - num_threads: int = 4, | 36 | + num_threads: int = 2, |
| 36 | sample_rate: float = 16000, | 37 | sample_rate: float = 16000, |
| 37 | feature_dim: int = 80, | 38 | feature_dim: int = 80, |
| 38 | enable_endpoint_detection: bool = False, | 39 | enable_endpoint_detection: bool = False, |
| @@ -144,6 +145,109 @@ class OnlineRecognizer(object): | @@ -144,6 +145,109 @@ class OnlineRecognizer(object): | ||
| 144 | self.config = recognizer_config | 145 | self.config = recognizer_config |
| 145 | return self | 146 | return self |
| 146 | 147 | ||
| 148 | + @classmethod | ||
| 149 | + def from_paraformer( | ||
| 150 | + cls, | ||
| 151 | + tokens: str, | ||
| 152 | + encoder: str, | ||
| 153 | + decoder: str, | ||
| 154 | + num_threads: int = 2, | ||
| 155 | + sample_rate: float = 16000, | ||
| 156 | + feature_dim: int = 80, | ||
| 157 | + enable_endpoint_detection: bool = False, | ||
| 158 | + rule1_min_trailing_silence: float = 2.4, | ||
| 159 | + rule2_min_trailing_silence: float = 1.2, | ||
| 160 | + rule3_min_utterance_length: float = 20.0, | ||
| 161 | + decoding_method: str = "greedy_search", | ||
| 162 | + provider: str = "cpu", | ||
| 163 | + ): | ||
| 164 | + """ | ||
| 165 | + Please refer to | ||
| 166 | + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_ | ||
| 167 | + to download pre-trained models for different languages, e.g., Chinese, | ||
| 168 | + English, etc. | ||
| 169 | + | ||
| 170 | + Args: | ||
| 171 | + tokens: | ||
| 172 | + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | ||
| 173 | + columns:: | ||
| 174 | + | ||
| 175 | + symbol integer_id | ||
| 176 | + | ||
| 177 | + encoder: | ||
| 178 | + Path to ``encoder.onnx``. | ||
| 179 | + decoder: | ||
| 180 | + Path to ``decoder.onnx``. | ||
| 181 | + num_threads: | ||
| 182 | + Number of threads for neural network computation. | ||
| 183 | + sample_rate: | ||
| 184 | + Sample rate of the training data used to train the model. | ||
| 185 | + feature_dim: | ||
| 186 | + Dimension of the feature used to train the model. | ||
| 187 | + enable_endpoint_detection: | ||
| 188 | + True to enable endpoint detection. False to disable endpoint | ||
| 189 | + detection. | ||
| 190 | + rule1_min_trailing_silence: | ||
| 191 | + Used only when enable_endpoint_detection is True. If the duration | ||
| 192 | + of trailing silence in seconds is larger than this value, we assume | ||
| 193 | + an endpoint is detected. | ||
| 194 | + rule2_min_trailing_silence: | ||
| 195 | + Used only when enable_endpoint_detection is True. If we have decoded | ||
| 196 | + something that is nonsilence and if the duration of trailing silence | ||
| 197 | + in seconds is larger than this value, we assume an endpoint is | ||
| 198 | + detected. | ||
| 199 | + rule3_min_utterance_length: | ||
| 200 | + Used only when enable_endpoint_detection is True. If the utterance | ||
| 201 | + length in seconds is larger than this value, we assume an endpoint | ||
| 202 | + is detected. | ||
| 203 | + decoding_method: | ||
| 204 | + The only valid value is greedy_search. | ||
| 205 | + provider: | ||
| 206 | + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | ||
| 207 | + """ | ||
| 208 | + self = cls.__new__(cls) | ||
| 209 | + _assert_file_exists(tokens) | ||
| 210 | + _assert_file_exists(encoder) | ||
| 211 | + _assert_file_exists(decoder) | ||
| 212 | + | ||
| 213 | + assert num_threads > 0, num_threads | ||
| 214 | + | ||
| 215 | + paraformer_config = OnlineParaformerModelConfig( | ||
| 216 | + encoder=encoder, | ||
| 217 | + decoder=decoder, | ||
| 218 | + ) | ||
| 219 | + | ||
| 220 | + model_config = OnlineModelConfig( | ||
| 221 | + paraformer=paraformer_config, | ||
| 222 | + tokens=tokens, | ||
| 223 | + num_threads=num_threads, | ||
| 224 | + provider=provider, | ||
| 225 | + model_type="paraformer", | ||
| 226 | + ) | ||
| 227 | + | ||
| 228 | + feat_config = FeatureExtractorConfig( | ||
| 229 | + sampling_rate=sample_rate, | ||
| 230 | + feature_dim=feature_dim, | ||
| 231 | + ) | ||
| 232 | + | ||
| 233 | + endpoint_config = EndpointConfig( | ||
| 234 | + rule1_min_trailing_silence=rule1_min_trailing_silence, | ||
| 235 | + rule2_min_trailing_silence=rule2_min_trailing_silence, | ||
| 236 | + rule3_min_utterance_length=rule3_min_utterance_length, | ||
| 237 | + ) | ||
| 238 | + | ||
| 239 | + recognizer_config = OnlineRecognizerConfig( | ||
| 240 | + feat_config=feat_config, | ||
| 241 | + model_config=model_config, | ||
| 242 | + endpoint_config=endpoint_config, | ||
| 243 | + enable_endpoint=enable_endpoint_detection, | ||
| 244 | + decoding_method=decoding_method, | ||
| 245 | + ) | ||
| 246 | + | ||
| 247 | + self.recognizer = _Recognizer(recognizer_config) | ||
| 248 | + self.config = recognizer_config | ||
| 249 | + return self | ||
| 250 | + | ||
| 147 | def create_stream(self, contexts_list: Optional[List[List[int]]] = None): | 251 | def create_stream(self, contexts_list: Optional[List[List[int]]] = None): |
| 148 | if contexts_list is None: | 252 | if contexts_list is None: |
| 149 | return self.recognizer.create_stream() | 253 | return self.recognizer.create_stream() |
-
请 注册 或 登录 后发表评论