Fangjun Kuang
Committed by GitHub

Support streaming paraformer (#263)

Showing 38 changed files with 1,488 additions and 112 deletions.
+#!/usr/bin/env bash
+
+set -e
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+echo "EXE is $EXE"
+echo "PATH: $PATH"
+
+which $EXE
+
+log "------------------------------------------------------------"
+log "Run streaming Paraformer"
+log "------------------------------------------------------------"
+
+repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
+log "Start testing ${repo_url}"
+repo=$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.onnx"
+ls -lh *.onnx
+popd
+
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --paraformer-encoder=$repo/encoder.onnx \
+  --paraformer-decoder=$repo/decoder.onnx \
+  --num-threads=2 \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/3.wav \
+  $repo/test_wavs/8k.wav
+
+time $EXE \
+  --tokens=$repo/tokens.txt \
+  --paraformer-encoder=$repo/encoder.int8.onnx \
+  --paraformer-decoder=$repo/decoder.int8.onnx \
+  --num-threads=2 \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/3.wav \
+  $repo/test_wavs/8k.wav
+
+rm -rf $repo
@@ -9,6 +9,7 @@ on:
     paths:
       - '.github/workflows/linux-gpu.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -22,6 +23,7 @@ on:
     paths:
       - '.github/workflows/linux-gpu.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
      - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -85,6 +87,14 @@ jobs:
           file build/bin/sherpa-onnx
           readelf -d build/bin/sherpa-onnx

+      - name: Test online paraformer
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx
+
+          .github/scripts/test-online-paraformer.sh
+
       - name: Test offline Whisper
         shell: bash
         run: |
@@ -9,6 +9,7 @@ on:
     paths:
       - '.github/workflows/linux.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -22,6 +23,7 @@ on:
     paths:
       - '.github/workflows/linux.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -84,6 +86,14 @@ jobs:
           file build/bin/sherpa-onnx
           readelf -d build/bin/sherpa-onnx

+      - name: Test online paraformer
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx
+
+          .github/scripts/test-online-paraformer.sh
+
       - name: Test offline Whisper
         shell: bash
         run: |
@@ -7,6 +7,7 @@ on:
     paths:
       - '.github/workflows/macos.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -18,6 +19,7 @@ on:
     paths:
       - '.github/workflows/macos.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -82,6 +84,14 @@ jobs:
           otool -L build/bin/sherpa-onnx
           otool -l build/bin/sherpa-onnx

+      - name: Test online paraformer
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin:$PATH
+          export EXE=sherpa-onnx
+
+          .github/scripts/test-online-paraformer.sh
+
       - name: Test offline Whisper
         shell: bash
         run: |
@@ -58,7 +58,6 @@ jobs:
           sherpa-onnx-microphone-offline --help

           sherpa-onnx-offline-websocket-server --help
-          sherpa-onnx-offline-websocket-client --help

           sherpa-onnx-online-websocket-server --help
           sherpa-onnx-online-websocket-client --help
@@ -84,14 +84,14 @@ jobs:
         if: matrix.model_type == 'paraformer'
         shell: bash
        run: |
-          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
-          cd sherpa-onnx-paraformer-zh-2023-03-28
+          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
+          cd sherpa-onnx-paraformer-bilingual-zh-en
           git lfs pull --include "*.onnx"
           cd ..

           python3 ./python-api-examples/non_streaming_server.py \
-            --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
-            --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt &
+            --paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \
+            --tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt &

           echo "sleep 10 seconds to wait the server start"
           sleep 10
@@ -101,16 +101,16 @@ jobs:
         shell: bash
         run: |
           python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
-            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
-            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
-            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
-            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
+            ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \
+            ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \
+            ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \
+            ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav

           python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
-            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav \
-            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav \
-            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav \
-            ./sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav
+            ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/0.wav \
+            ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/1.wav \
+            ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/2.wav \
+            ./sherpa-onnx-paraformer-bilingual-zh-en/test_wavs/8k.wav

       - name: Start server for nemo_ctc models
         if: matrix.model_type == 'nemo_ctc'
@@ -24,7 +24,7 @@ jobs:
       matrix:
         os: [ubuntu-latest, windows-latest, macos-latest]
         python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
-        model_type: ["transducer"]
+        model_type: ["transducer", "paraformer"]

     steps:
       - uses: actions/checkout@v2
@@ -71,3 +71,36 @@ jobs:
         run: |
           python3 ./python-api-examples/online-websocket-client-decode-file.py \
             ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav
+
+      - name: Start server for paraformer models
+        if: matrix.model_type == 'paraformer'
+        shell: bash
+        run: |
+          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
+          cd sherpa-onnx-streaming-paraformer-bilingual-zh-en
+          git lfs pull --include "*.onnx"
+          cd ..
+
+          python3 ./python-api-examples/streaming_server.py \
+            --tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
+            --paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
+            --paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx &
+
+          echo "sleep 10 seconds to wait the server start"
+          sleep 10
+
+      - name: Start client for paraformer models
+        if: matrix.model_type == 'paraformer'
+        shell: bash
+        run: |
+          python3 ./python-api-examples/online-websocket-client-decode-file.py \
+            ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav
+
+          python3 ./python-api-examples/online-websocket-client-decode-file.py \
+            ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav
+
+          python3 ./python-api-examples/online-websocket-client-decode-file.py \
+            ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav
+
+          python3 ./python-api-examples/online-websocket-client-decode-file.py \
+            ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav
@@ -9,6 +9,7 @@ on:
     paths:
       - '.github/workflows/windows-x64-cuda.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -20,6 +21,7 @@ on:
     paths:
       - '.github/workflows/windows-x64-cuda.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -74,6 +76,14 @@ jobs:

         ls -lh ./bin/Release/sherpa-onnx.exe

+      - name: Test online paraformer for windows x64
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx.exe
+
+          .github/scripts/test-online-paraformer.sh
+
       - name: Test offline Whisper for windows x64
         shell: bash
         run: |
@@ -9,6 +9,7 @@ on:
     paths:
       - '.github/workflows/windows-x64.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -20,6 +21,7 @@ on:
     paths:
       - '.github/workflows/windows-x64.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -75,6 +77,14 @@ jobs:

         ls -lh ./bin/Release/sherpa-onnx.exe

+      - name: Test online paraformer for windows x64
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx.exe
+
+          .github/scripts/test-online-paraformer.sh
+
       - name: Test offline Whisper for windows x64
         shell: bash
         run: |
@@ -7,6 +7,7 @@ on:
     paths:
       - '.github/workflows/windows-x86.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -18,6 +19,7 @@ on:
     paths:
       - '.github/workflows/windows-x86.yaml'
       - '.github/scripts/test-online-transducer.sh'
+      - '.github/scripts/test-online-paraformer.sh'
       - '.github/scripts/test-offline-transducer.sh'
       - '.github/scripts/test-offline-ctc.sh'
       - 'CMakeLists.txt'
@@ -73,6 +75,14 @@ jobs:

         ls -lh ./bin/Release/sherpa-onnx.exe

+      - name: Test online paraformer for windows x86
+        shell: bash
+        run: |
+          export PATH=$PWD/build/bin/Release:$PATH
+          export EXE=sherpa-onnx.exe
+
+          .github/scripts/test-online-paraformer.sh
+
       - name: Test offline Whisper for windows x86
         shell: bash
         run: |
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-onnx)

-set(SHERPA_ONNX_VERSION "1.7.3")
+set(SHERPA_ONNX_VERSION "1.7.4")

 # Disable warning about
 #
@@ -37,14 +37,14 @@ python3 ./python-api-examples/non_streaming_server.py \
 (2) Use a non-streaming paraformer

 cd /path/to/sherpa-onnx
-GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
-cd sherpa-onnx-paraformer-zh-2023-03-28
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
+cd sherpa-onnx-paraformer-bilingual-zh-en/
 git lfs pull --include "*.onnx"
 cd ..

 python3 ./python-api-examples/non_streaming_server.py \
-  --paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
-  --tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt
+  --paraformer ./sherpa-onnx-paraformer-bilingual-zh-en/model.int8.onnx \
+  --tokens ./sherpa-onnx-paraformer-bilingual-zh-en/tokens.txt

 (3) Use a non-streaming CTC model from NeMo
@@ -5,16 +5,41 @@ This file demonstrates how to use sherpa-onnx Python API to transcribe
 file(s) with a streaming model.

 Usage:
-  ./online-decode-files.py \
-    /path/to/foo.wav \
-    /path/to/bar.wav \
-    /path/to/16kHz.wav \
-    /path/to/8kHz.wav
+
+(1) Streaming transducer
+
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
+cd sherpa-onnx-streaming-zipformer-en-2023-06-26
+git lfs pull --include "*.onnx"
+
+./python-api-examples/online-decode-files.py \
+  --tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \
+  --encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
+  --decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
+  --joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
+  ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \
+  ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
+  ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
+
+(2) Streaming paraformer
+
+GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
+cd sherpa-onnx-streaming-paraformer-bilingual-zh-en
+git lfs pull --include "*.onnx"
+
+./python-api-examples/online-decode-files.py \
+  --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
+  --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \
+  --paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx \
+  ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav \
+  ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav \
+  ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav \
+  ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \
+  ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav

 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/index.html
-to install sherpa-onnx and to download the pre-trained models
-used in this file.
+to install sherpa-onnx and to download streaming pre-trained models.
 """
 import argparse
 import time
@@ -41,19 +66,31 @@ def get_args():
     parser.add_argument(
         "--encoder",
         type=str,
-        help="Path to the encoder model",
+        help="Path to the transducer encoder model",
     )

     parser.add_argument(
         "--decoder",
         type=str,
-        help="Path to the decoder model",
+        help="Path to the transducer decoder model",
     )

     parser.add_argument(
         "--joiner",
         type=str,
-        help="Path to the joiner model",
+        help="Path to the transducer joiner model",
+    )
+
+    parser.add_argument(
+        "--paraformer-encoder",
+        type=str,
+        help="Path to the paraformer encoder model",
+    )
+
+    parser.add_argument(
+        "--paraformer-decoder",
+        type=str,
+        help="Path to the paraformer decoder model",
     )

     parser.add_argument(
@@ -200,24 +237,42 @@ def encode_contexts(args, contexts: List[str]) -> List[List[int]]:

 def main():
     args = get_args()
-    assert_file_exists(args.encoder)
-    assert_file_exists(args.decoder)
-    assert_file_exists(args.joiner)
     assert_file_exists(args.tokens)

-    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
-        tokens=args.tokens,
-        encoder=args.encoder,
-        decoder=args.decoder,
-        joiner=args.joiner,
-        num_threads=args.num_threads,
-        provider=args.provider,
-        sample_rate=16000,
-        feature_dim=80,
-        decoding_method=args.decoding_method,
-        max_active_paths=args.max_active_paths,
-        context_score=args.context_score,
-    )
+    if args.encoder:
+        assert_file_exists(args.encoder)
+        assert_file_exists(args.decoder)
+        assert_file_exists(args.joiner)
+
+        assert not args.paraformer_encoder, args.paraformer_encoder
+        assert not args.paraformer_decoder, args.paraformer_decoder
+
+        recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
+            tokens=args.tokens,
+            encoder=args.encoder,
+            decoder=args.decoder,
+            joiner=args.joiner,
+            num_threads=args.num_threads,
+            provider=args.provider,
+            sample_rate=16000,
+            feature_dim=80,
+            decoding_method=args.decoding_method,
+            max_active_paths=args.max_active_paths,
+            context_score=args.context_score,
+        )
+    elif args.paraformer_encoder:
+        recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
+            tokens=args.tokens,
+            encoder=args.paraformer_encoder,
+            decoder=args.paraformer_decoder,
+            num_threads=args.num_threads,
+            provider=args.provider,
+            sample_rate=16000,
+            feature_dim=80,
+            decoding_method="greedy_search",
+        )
+    else:
+        raise ValueError("Please provide a model")

     print("Started!")
     start_time = time.time()
@@ -243,7 +298,7 @@ def main():

         s.accept_waveform(sample_rate, samples)

-        tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
+        tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
         s.accept_waveform(sample_rate, tail_paddings)

         s.input_finished()
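The new paraformer branch above can also be driven directly from Python. A minimal sketch, not part of this commit, assuming sherpa-onnx is installed and the sherpa-onnx-streaming-paraformer-bilingual-zh-en repo has been cloned as in the docstring; the decode loop mirrors the one already in this file:

import numpy as np
import sherpa_onnx

d = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en"  # cloned as above

recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
    tokens=f"{d}/tokens.txt",
    encoder=f"{d}/encoder.int8.onnx",
    decoder=f"{d}/decoder.int8.onnx",
    num_threads=2,
    sample_rate=16000,
    feature_dim=80,
    decoding_method="greedy_search",
)

s = recognizer.create_stream()

# One second of silence stands in for real audio; online-decode-files.py
# feeds float32 samples in [-1, 1] read from a wave file instead.
s.accept_waveform(16000, np.zeros(16000, dtype=np.float32))

# Tail padding, as in the loop above.
s.accept_waveform(16000, np.zeros(int(0.66 * 16000), dtype=np.float32))
s.input_finished()

while recognizer.is_ready(s):
    recognizer.decode_stream(s)

print(recognizer.get_result(s))

The tail padding was raised from 0.2 s to 0.66 s in this commit, presumably so the streaming paraformer sees enough trailing context to emit its final tokens.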
@@ -16,9 +16,9 @@ Example:
 (1) Without a certificate

 python3 ./python-api-examples/streaming_server.py \
-  --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
-  --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
-  --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+  --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+  --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+  --joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
   --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt

 (2) With a certificate
@@ -32,9 +32,9 @@ python3 ./python-api-examples/streaming_server.py \
 (b) Start the server

 python3 ./python-api-examples/streaming_server.py \
-  --encoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
-  --decoder-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
-  --joiner-model ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
+  --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \
+  --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \
+  --joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \
   --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \
   --certificate ./python-api-examples/web/cert.pem

@@ -113,24 +113,33 @@ def setup_logger(

 def add_model_args(parser: argparse.ArgumentParser):
     parser.add_argument(
-        "--encoder-model",
+        "--encoder",
         type=str,
-        required=True,
-        help="Path to the encoder model",
+        help="Path to the transducer encoder model",
     )

     parser.add_argument(
-        "--decoder-model",
+        "--decoder",
         type=str,
-        required=True,
-        help="Path to the decoder model.",
+        help="Path to the transducer decoder model.",
     )

     parser.add_argument(
-        "--joiner-model",
+        "--joiner",
         type=str,
-        required=True,
-        help="Path to the joiner model.",
+        help="Path to the transducer joiner model.",
+    )
+
+    parser.add_argument(
+        "--paraformer-encoder",
+        type=str,
+        help="Path to the paraformer encoder model",
+    )
+
+    parser.add_argument(
+        "--paraformer-decoder",
+        type=str,
+        help="Path to the paraformer decoder model.",
     )

     parser.add_argument(
@@ -323,22 +332,40 @@ def get_args():


 def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
-    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
-        tokens=args.tokens,
-        encoder=args.encoder_model,
-        decoder=args.decoder_model,
-        joiner=args.joiner_model,
-        num_threads=args.num_threads,
-        sample_rate=args.sample_rate,
-        feature_dim=args.feat_dim,
-        decoding_method=args.decoding_method,
-        max_active_paths=args.num_active_paths,
-        enable_endpoint_detection=args.use_endpoint != 0,
-        rule1_min_trailing_silence=args.rule1_min_trailing_silence,
-        rule2_min_trailing_silence=args.rule2_min_trailing_silence,
-        rule3_min_utterance_length=args.rule3_min_utterance_length,
-        provider=args.provider,
-    )
+    if args.encoder:
+        recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
+            tokens=args.tokens,
+            encoder=args.encoder,
+            decoder=args.decoder,
+            joiner=args.joiner,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feat_dim,
+            decoding_method=args.decoding_method,
+            max_active_paths=args.num_active_paths,
+            enable_endpoint_detection=args.use_endpoint != 0,
+            rule1_min_trailing_silence=args.rule1_min_trailing_silence,
+            rule2_min_trailing_silence=args.rule2_min_trailing_silence,
+            rule3_min_utterance_length=args.rule3_min_utterance_length,
+            provider=args.provider,
+        )
+    elif args.paraformer_encoder:
+        recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
+            tokens=args.tokens,
+            encoder=args.paraformer_encoder,
+            decoder=args.paraformer_decoder,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feat_dim,
+            decoding_method=args.decoding_method,
+            enable_endpoint_detection=args.use_endpoint != 0,
+            rule1_min_trailing_silence=args.rule1_min_trailing_silence,
+            rule2_min_trailing_silence=args.rule2_min_trailing_silence,
+            rule3_min_utterance_length=args.rule3_min_utterance_length,
+            provider=args.provider,
+        )
+    else:
+        raise ValueError("Please provide a model")

     return recognizer

@@ -654,11 +681,25 @@ Go back to <a href="/streaming_record.html">/streaming_record.html</a>


 def check_args(args):
-    assert Path(args.encoder_model).is_file(), f"{args.encoder_model} does not exist"
+    if args.encoder:
+        assert Path(args.encoder).is_file(), f"{args.encoder} does not exist"
+
+        assert Path(args.decoder).is_file(), f"{args.decoder} does not exist"
+
+        assert Path(args.joiner).is_file(), f"{args.joiner} does not exist"

-    assert Path(args.decoder_model).is_file(), f"{args.decoder_model} does not exist"
+        assert args.paraformer_encoder is None, args.paraformer_encoder
+        assert args.paraformer_decoder is None, args.paraformer_decoder
+    elif args.paraformer_encoder:
+        assert Path(
+            args.paraformer_encoder
+        ).is_file(), f"{args.paraformer_encoder} does not exist"

-    assert Path(args.joiner_model).is_file(), f"{args.joiner_model} does not exist"
+        assert Path(
+            args.paraformer_decoder
+        ).is_file(), f"{args.paraformer_decoder} does not exist"
+    else:
+        raise ValueError("Please provide a model")

     if not Path(args.tokens).is_file():
         raise ValueError(f"{args.tokens} does not exist")
@@ -46,6 +46,8 @@ set(sources
   online-lm.cc
   online-lstm-transducer-model.cc
   online-model-config.cc
+  online-paraformer-model-config.cc
+  online-paraformer-model.cc
   online-recognizer-impl.cc
   online-recognizer.cc
   online-rnn-lm.cc
@@ -39,7 +39,7 @@ std::string FeatureExtractorConfig::ToString() const {

 class FeatureExtractor::Impl {
  public:
-  explicit Impl(const FeatureExtractorConfig &config) {
+  explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
     opts_.frame_opts.dither = 0;
     opts_.frame_opts.snip_edges = false;
     opts_.frame_opts.samp_freq = config.sampling_rate;
@@ -50,6 +50,19 @@ class FeatureExtractor::Impl {
   }

   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
+    if (config_.normalize_samples) {
+      AcceptWaveformImpl(sampling_rate, waveform, n);
+    } else {
+      std::vector<float> buf(n);
+      for (int32_t i = 0; i != n; ++i) {
+        buf[i] = waveform[i] * 32768;
+      }
+      AcceptWaveformImpl(sampling_rate, buf.data(), n);
+    }
+  }
+
+  void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
+                          int32_t n) {
     std::lock_guard<std::mutex> lock(mutex_);

     if (resampler_) {
@@ -146,6 +159,7 @@ class FeatureExtractor::Impl {
  private:
   std::unique_ptr<knf::OnlineFbank> fbank_;
   knf::FbankOptions opts_;
+  FeatureExtractorConfig config_;
   mutable std::mutex mutex_;
   std::unique_ptr<LinearResample> resampler_;
   int32_t last_frame_index_ = 0;
@@ -21,6 +21,13 @@ struct FeatureExtractorConfig {
   // Feature dimension
   int32_t feature_dim = 80;

+  // Set internally by some models, e.g., paraformer sets it to false.
+  // This parameter is not exposed to users from the command line.
+  // If true, the feature extractor expects input samples to be normalized
+  // to the range [-1, 1].
+  // If false, we multiply the input samples by 32768.
+  bool normalize_samples = true;
+
   std::string ToString() const;

   void Register(ParseOptions *po);
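In effect, when normalize_samples is false the extractor maps [-1, 1] float samples into the int16 range before computing fbank features. A tiny numpy illustration, not from this commit, of the rescaling done in AcceptWaveform above:

import numpy as np

# sherpa-onnx reads wave files as float32 samples in [-1, 1].
samples = np.array([0.0, 0.25, -0.5], dtype=np.float32)

# With normalize_samples = false (the paraformer setting), the feature
# extractor multiplies each sample by 32768 before computing features.
print(samples * 32768)  # [     0.   8192. -16384.]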
@@ -12,6 +12,7 @@ namespace sherpa_onnx {

 void OnlineModelConfig::Register(ParseOptions *po) {
   transducer.Register(po);
+  paraformer.Register(po);

   po->Register("tokens", &tokens, "Path to tokens.txt");

@@ -41,6 +42,10 @@ bool OnlineModelConfig::Validate() const {
     return false;
   }

+  if (!paraformer.encoder.empty()) {
+    return paraformer.Validate();
+  }
+
   return transducer.Validate();
 }

@@ -49,6 +54,7 @@ std::string OnlineModelConfig::ToString() const {

   os << "OnlineModelConfig(";
   os << "transducer=" << transducer.ToString() << ", ";
+  os << "paraformer=" << paraformer.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
@@ -6,12 +6,14 @@

 #include <string>

+#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"

 namespace sherpa_onnx {

 struct OnlineModelConfig {
   OnlineTransducerModelConfig transducer;
+  OnlineParaformerModelConfig paraformer;
   std::string tokens;
   int32_t num_threads = 1;
   bool debug = false;
@@ -28,9 +30,11 @@ struct OnlineModelConfig {

   OnlineModelConfig() = default;
   OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
+                    const OnlineParaformerModelConfig &paraformer,
                     const std::string &tokens, int32_t num_threads, bool debug,
                     const std::string &provider, const std::string &model_type)
       : transducer(transducer),
+        paraformer(paraformer),
         tokens(tokens),
         num_threads(num_threads),
         debug(debug),
+// sherpa-onnx/csrc/online-paraformer-decoder.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
+#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
+
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+
+namespace sherpa_onnx {
+
+struct OnlineParaformerDecoderResult {
+  /// The decoded token IDs
+  std::vector<int32_t> tokens;
+
+  int32_t last_non_blank_frame_index = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
+// sherpa-onnx/csrc/online-paraformer-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OnlineParaformerModelConfig::Register(ParseOptions *po) {
+  po->Register("paraformer-encoder", &encoder,
+               "Path to encoder.onnx of paraformer.");
+  po->Register("paraformer-decoder", &decoder,
+               "Path to decoder.onnx of paraformer.");
+}
+
+bool OnlineParaformerModelConfig::Validate() const {
+  if (!FileExists(encoder)) {
+    SHERPA_ONNX_LOGE("Paraformer encoder %s does not exist", encoder.c_str());
+    return false;
+  }
+
+  if (!FileExists(decoder)) {
+    SHERPA_ONNX_LOGE("Paraformer decoder %s does not exist", decoder.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OnlineParaformerModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OnlineParaformerModelConfig(";
+  os << "encoder=\"" << encoder << "\", ";
+  os << "decoder=\"" << decoder << "\")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/online-paraformer-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OnlineParaformerModelConfig {
+  std::string encoder;
+  std::string decoder;
+
+  OnlineParaformerModelConfig() = default;
+
+  OnlineParaformerModelConfig(const std::string &encoder,
+                              const std::string &decoder)
+      : encoder(encoder), decoder(decoder) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
+// sherpa-onnx/csrc/online-paraformer-model.cc
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-paraformer-model.h"
+
+#include <algorithm>
+#include <cmath>
+#include <string>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OnlineParaformerModel::Impl {
+ public:
+  explicit Impl(const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.paraformer.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.paraformer.decoder);
+      InitDecoder(buf.data(), buf.size());
+    }
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OnlineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_WARNING),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(mgr, config.paraformer.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(mgr, config.paraformer.decoder);
+      InitDecoder(buf.data(), buf.size());
+    }
+  }
+#endif
+
+  std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
+                                         Ort::Value features_length) {
+    std::array<Ort::Value, 2> inputs = {std::move(features),
+                                        std::move(features_length)};
+
+    return encoder_sess_->Run(
+        {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
+        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
+  }
+
+  std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
+                                         Ort::Value encoder_out_length,
+                                         Ort::Value acoustic_embedding,
+                                         Ort::Value acoustic_embedding_length,
+                                         std::vector<Ort::Value> states) {
+    std::vector<Ort::Value> decoder_inputs;
+    decoder_inputs.reserve(4 + states.size());
+
+    decoder_inputs.push_back(std::move(encoder_out));
+    decoder_inputs.push_back(std::move(encoder_out_length));
+    decoder_inputs.push_back(std::move(acoustic_embedding));
+    decoder_inputs.push_back(std::move(acoustic_embedding_length));
+
+    for (auto &v : states) {
+      decoder_inputs.push_back(std::move(v));
+    }
+
+    return decoder_sess_->Run({}, decoder_input_names_ptr_.data(),
+                              decoder_inputs.data(), decoder_inputs.size(),
+                              decoder_output_names_ptr_.data(),
+                              decoder_output_names_ptr_.size());
+  }
+
+  int32_t VocabSize() const { return vocab_size_; }
+
+  int32_t LfrWindowSize() const { return lfr_window_size_; }
+
+  int32_t LfrWindowShift() const { return lfr_window_shift_; }
+
+  int32_t EncoderOutputSize() const { return encoder_output_size_; }
+
+  int32_t DecoderKernelSize() const { return decoder_kernel_size_; }
+
+  int32_t DecoderNumBlocks() const { return decoder_num_blocks_; }
+
+  const std::vector<float> &NegativeMean() const { return neg_mean_; }
+
+  const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+ private:
+  void InitEncoder(void *model_data, size_t model_data_length) {
+    encoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                  &encoder_input_names_ptr_);
+
+    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                   &encoder_output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+    SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size");
+    SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift");
+    SHERPA_ONNX_READ_META_DATA(encoder_output_size_, "encoder_output_size");
+    SHERPA_ONNX_READ_META_DATA(decoder_num_blocks_, "decoder_num_blocks");
+    SHERPA_ONNX_READ_META_DATA(decoder_kernel_size_, "decoder_kernel_size");
+
+    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean");
+    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
+
+    float scale = std::sqrt(encoder_output_size_);
+    for (auto &f : inv_stddev_) {
+      f *= scale;
+    }
+  }
+
+  void InitDecoder(void *model_data, size_t model_data_length) {
+    decoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                  &decoder_input_names_ptr_);
+
+    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                   &decoder_output_names_ptr_);
+  }
+
+ private:
+  OnlineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::unique_ptr<Ort::Session> decoder_sess_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+
+  std::vector<float> neg_mean_;
+  std::vector<float> inv_stddev_;
+
+  int32_t vocab_size_ = 0;  // initialized in Init
+  int32_t lfr_window_size_ = 0;
+  int32_t lfr_window_shift_ = 0;
+
+  int32_t encoder_output_size_ = 0;
+  int32_t decoder_num_blocks_ = 0;
+  int32_t decoder_kernel_size_ = 0;
+};
+
+OnlineParaformerModel::OnlineParaformerModel(const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OnlineParaformerModel::OnlineParaformerModel(AAssetManager *mgr,
+                                             const OnlineModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OnlineParaformerModel::~OnlineParaformerModel() = default;
+
+std::vector<Ort::Value> OnlineParaformerModel::ForwardEncoder(
+    Ort::Value features, Ort::Value features_length) const {
+  return impl_->ForwardEncoder(std::move(features), std::move(features_length));
+}
+
+std::vector<Ort::Value> OnlineParaformerModel::ForwardDecoder(
+    Ort::Value encoder_out, Ort::Value encoder_out_length,
+    Ort::Value acoustic_embedding, Ort::Value acoustic_embedding_length,
+    std::vector<Ort::Value> states) const {
+  return impl_->ForwardDecoder(
+      std::move(encoder_out), std::move(encoder_out_length),
+      std::move(acoustic_embedding), std::move(acoustic_embedding_length),
+      std::move(states));
+}
+
+int32_t OnlineParaformerModel::VocabSize() const { return impl_->VocabSize(); }
+
+int32_t OnlineParaformerModel::LfrWindowSize() const {
+  return impl_->LfrWindowSize();
+}
+int32_t OnlineParaformerModel::LfrWindowShift() const {
+  return impl_->LfrWindowShift();
+}
+
+int32_t OnlineParaformerModel::EncoderOutputSize() const {
+  return impl_->EncoderOutputSize();
+}
+
+int32_t OnlineParaformerModel::DecoderKernelSize() const {
+  return impl_->DecoderKernelSize();
+}
+
+int32_t OnlineParaformerModel::DecoderNumBlocks() const {
+  return impl_->DecoderNumBlocks();
+}
+
+const std::vector<float> &OnlineParaformerModel::NegativeMean() const {
+  return impl_->NegativeMean();
+}
+const std::vector<float> &OnlineParaformerModel::InverseStdDev() const {
+  return impl_->InverseStdDev();
+}
+
+OrtAllocator *OnlineParaformerModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+}  // namespace sherpa_onnx
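The metadata that InitEncoder() reads can be checked without the C++ code. A small sketch, not part of this commit, using onnxruntime's Python API, assuming encoder.onnx comes from the streaming paraformer repo used in the tests above:

import onnxruntime

sess = onnxruntime.InferenceSession("encoder.onnx")
meta = sess.get_modelmeta().custom_metadata_map

# The same keys SHERPA_ONNX_READ_META_DATA reads above.
for key in ("vocab_size", "lfr_window_size", "lfr_window_shift",
            "encoder_output_size", "decoder_num_blocks",
            "decoder_kernel_size"):
    print(key, "=", meta[key])

# neg_mean and inv_stddev (used for CMVN) are also stored as strings.
print(meta["neg_mean"][:60], "...")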
+// sherpa-onnx/csrc/online-paraformer-model.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-model-config.h"
+
+namespace sherpa_onnx {
+
+class OnlineParaformerModel {
+ public:
+  explicit OnlineParaformerModel(const OnlineModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OnlineParaformerModel(AAssetManager *mgr, const OnlineModelConfig &config);
+#endif
+
+  ~OnlineParaformerModel();
+
+  std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
+                                         Ort::Value features_length) const;
+
+  std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
+                                         Ort::Value encoder_out_length,
+                                         Ort::Value acoustic_embedding,
+                                         Ort::Value acoustic_embedding_length,
+                                         std::vector<Ort::Value> states) const;
+
+  /** Return the vocabulary size of the model
+   */
+  int32_t VocabSize() const;
+
+  /** It is lfr_m in config.yaml
+   */
+  int32_t LfrWindowSize() const;
+
+  /** It is lfr_n in config.yaml
+   */
+  int32_t LfrWindowShift() const;
+
+  int32_t EncoderOutputSize() const;
+
+  int32_t DecoderKernelSize() const;
+  int32_t DecoderNumBlocks() const;
+
+  /** Return negative mean for CMVN
+   */
+  const std::vector<float> &NegativeMean() const;
+
+  /** Return inverse stddev for CMVN
+   */
+  const std::vector<float> &InverseStdDev() const;
+
+  /** Return an allocator for allocating memory
+   */
+  OrtAllocator *Allocator() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
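Similarly, the inputs that ForwardDecoder() must supply, including the streaming states, can be listed from Python; a sketch under the same assumption about decoder.onnx:

import onnxruntime

sess = onnxruntime.InferenceSession("decoder.onnx")

# The first four inputs correspond to ForwardDecoder()'s arguments:
# encoder_out, encoder_out_length, acoustic_embedding,
# acoustic_embedding_length; the remaining ones are the per-stream states.
for i in sess.get_inputs():
    print("input:", i.name, i.shape, i.type)

for o in sess.get_outputs():
    print("output:", o.name, o.shape, o.type)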
@@ -4,6 +4,7 @@

 #include "sherpa-onnx/csrc/online-recognizer-impl.h"

+#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
 #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"

 namespace sherpa_onnx {
@@ -14,6 +15,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     return std::make_unique<OnlineRecognizerTransducerImpl>(config);
   }

+  if (!config.model_config.paraformer.encoder.empty()) {
+    return std::make_unique<OnlineRecognizerParaformerImpl>(config);
+  }
+
   SHERPA_ONNX_LOGE("Please specify a model");
   exit(-1);
 }
@@ -25,6 +30,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
   }

+  if (!config.model_config.paraformer.encoder.empty()) {
+    return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config);
+  }
+
   SHERPA_ONNX_LOGE("Please specify a model");
   exit(-1);
 }
@@ -26,8 +26,6 @@ class OnlineRecognizerImpl {

   virtual ~OnlineRecognizerImpl() = default;

-  virtual void InitOnlineStream(OnlineStream *stream) const = 0;
-
   virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;

   virtual std::unique_ptr<OnlineStream> CreateStream(
  1 +// sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <memory>
  10 +#include <string>
  11 +#include <utility>
  12 +#include <vector>
  13 +
  14 +#include "sherpa-onnx/csrc/file-utils.h"
  15 +#include "sherpa-onnx/csrc/macros.h"
  16 +#include "sherpa-onnx/csrc/online-lm.h"
  17 +#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
  18 +#include "sherpa-onnx/csrc/online-paraformer-model.h"
  19 +#include "sherpa-onnx/csrc/online-recognizer-impl.h"
  20 +#include "sherpa-onnx/csrc/online-recognizer.h"
  21 +#include "sherpa-onnx/csrc/symbol-table.h"
  22 +
  23 +namespace sherpa_onnx {
  24 +
  25 +static OnlineRecognizerResult Convert(const OnlineParaformerDecoderResult &src,
  26 + const SymbolTable &sym_table) {
  27 + OnlineRecognizerResult r;
  28 + r.tokens.reserve(src.tokens.size());
  29 +
  30 + std::string text;
  31 +
  32 + // When the current token ends with "@@" we set mergeable to true
  33 + bool mergeable = false;
  34 +
  35 + for (int32_t i = 0; i != src.tokens.size(); ++i) {
  36 + auto sym = sym_table[src.tokens[i]];
  37 + r.tokens.push_back(sym);
  38 +
  39 + if ((sym.back() != '@') || (sym.size() > 2 && sym[sym.size() - 2] != '@')) {
  40 + // sym does not end with "@@"
  41 + const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
  42 + if (p[0] < 0x80) {
  43 + // an ascii
  44 + if (mergeable) {
  45 + mergeable = false;
  46 + text.append(sym);
  47 + } else {
  48 + text.append(" ");
  49 + text.append(sym);
  50 + }
  51 + } else {
  52 + // not an ascii
  53 + mergeable = false;
  54 +
  55 + if (i > 0) {
  56 + const uint8_t *p = reinterpret_cast<const uint8_t *>(
  57 + sym_table[src.tokens[i - 1]].c_str());
  58 + if (p[0] < 0x80) {
  59 + // put a space between ascii and non-ascii
  60 + text.append(" ");
  61 + }
  62 + }
  63 + text.append(sym);
  64 + }
  65 + } else {
  66 + // this sym ends with @@
  67 + sym = std::string(sym.data(), sym.size() - 2);
  68 + if (mergeable) {
  69 + text.append(sym);
  70 + } else {
  71 + text.append(" ");
  72 + text.append(sym);
  73 + mergeable = true;
  74 + }
  75 + }
  76 + }
  77 + r.text = std::move(text);
  78 +
  79 + return r;
  80 +}
  81 +
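
For reference, the merging rule implemented by Convert() can be sketched in a few
lines of Python (merge_bpe and the sample tokens are illustrative, not part of
this patch):

    def merge_bpe(tokens):
        text = ""
        mergeable = False  # True if the previous token ended with "@@"
        for i, sym in enumerate(tokens):
            if len(sym) >= 2 and sym.endswith("@@"):
                # a subword that continues into the next token
                text += sym[:-2] if mergeable else " " + sym[:-2]
                mergeable = True
            elif ord(sym[0]) < 0x80:
                # an ASCII word: space-separated unless merging
                text += sym if mergeable else " " + sym
                mergeable = False
            else:
                # non-ASCII (e.g., CJK); add a space after an ASCII token
                if i > 0 and ord(tokens[i - 1][0]) < 0x80:
                    text += " "
                text += sym
                mergeable = False
        return text

    print(merge_bpe(["hel@@", "lo", "世", "界"]))  # " hello 世界" (leading space included)
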
  82 +// y[i] += x[i] * scale
  83 +static void ScaleAddInPlace(const float *x, int32_t n, float scale, float *y) {
  84 + for (int32_t i = 0; i != n; ++i) {
  85 + y[i] += x[i] * scale;
  86 + }
  87 +}
  88 +
  89 +// y[i] = x[i] * scale
  90 +static void Scale(const float *x, int32_t n, float scale, float *y) {
  91 + for (int32_t i = 0; i != n; ++i) {
  92 + y[i] = x[i] * scale;
  93 + }
  94 +}
  95 +
  96 +class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
  97 + public:
  98 + explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config)
  99 + : config_(config),
  100 + model_(config.model_config),
  101 + sym_(config.model_config.tokens),
  102 + endpoint_(config_.endpoint_config) {
  103 + if (config.decoding_method != "greedy_search") {
  104 + SHERPA_ONNX_LOGE(
  105 + "Unsupported decoding method: %s. Support only greedy_search at "
  106 + "present",
  107 + config.decoding_method.c_str());
  108 + exit(-1);
  109 + }
  110 +
  111 + // Paraformer models assume input samples are in the range
  112 + // [-32768, 32767], so we set normalize_samples to false
  113 + config_.feat_config.normalize_samples = false;
  114 + }
  115 +
  116 +#if __ANDROID_API__ >= 9
  117 + explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr,
  118 + const OnlineRecognizerConfig &config)
  119 + : config_(config),
  120 + model_(mgr, config.model_config),
  121 + sym_(mgr, config.model_config.tokens),
  122 + endpoint_(config_.endpoint_config) {
  123 + if (config.decoding_method == "greedy_search") {
  124 + // add greedy search decoder
  125 + // SHERPA_ONNX_LOGE("to be implemented");
  126 + // exit(-1);
  127 + } else {
  128 + SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
  129 + config.decoding_method.c_str());
  130 + exit(-1);
  131 + }
  132 +
  133 + // Paraformer models assume input samples are in the range
  134 + // [-32768, 32767], so we set normalize_samples to false
  135 + config_.feat_config.normalize_samples = false;
  136 + }
  137 +#endif
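
Because normalize_samples is false here, the feature extractor computes fbank
features as if the waveform were 16-bit PCM rather than floats in [-1, 1]. A
sketch of the effect (not the actual FeatureExtractor code; samples is assumed
to hold float32 values in [-1, 1]):

    scaled = [s * 32768 for s in samples]  # the range Paraformer was trained on
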
  138 + OnlineRecognizerParaformerImpl(const OnlineRecognizerParaformerImpl &) =
  139 + delete;
  140 +
  141 + OnlineRecognizerParaformerImpl &operator=(
  142 + const OnlineRecognizerParaformerImpl &) = delete;
  143 +
  144 + std::unique_ptr<OnlineStream> CreateStream() const override {
  145 + auto stream = std::make_unique<OnlineStream>(config_.feat_config);
  146 +
  147 + OnlineParaformerDecoderResult r;
  148 + stream->SetParaformerResult(r);
  149 +
  150 + return stream;
  151 + }
  152 +
  153 + bool IsReady(OnlineStream *s) const override {
  154 + return s->GetNumProcessedFrames() + chunk_size_ < s->NumFramesReady();
  155 + }
  156 +
  157 + void DecodeStreams(OnlineStream **ss, int32_t n) const override {
  158 + // TODO(fangjun): Support batch size > 1
  159 + for (int32_t i = 0; i != n; ++i) {
  160 + DecodeStream(ss[i]);
  161 + }
  162 + }
  163 +
  164 + OnlineRecognizerResult GetResult(OnlineStream *s) const override {
  165 + auto decoder_result = s->GetParaformerResult();
  166 +
  167 + return Convert(decoder_result, sym_);
  168 + }
  169 +
  170 + bool IsEndpoint(OnlineStream *s) const override {
  171 + if (!config_.enable_endpoint) {
  172 + return false;
  173 + }
  174 +
  175 + const auto &result = s->GetParaformerResult();
  176 +
  177 + int32_t num_processed_frames = s->GetNumProcessedFrames();
  178 +
  179 + // frame shift is 10 milliseconds
  180 + float frame_shift_in_seconds = 0.01;
  181 +
  182 + int32_t trailing_silence_frames =
  183 + num_processed_frames - result.last_non_blank_frame_index;
  184 +
  185 + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
  186 + frame_shift_in_seconds);
  187 + }
  188 +
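
A quick check of the endpoint arithmetic, using the default rule1 threshold of
2.4 s that appears in the Python API below (the frame count here is made up):

    frame_shift_in_seconds = 0.01           # 10 ms per frame
    trailing_silence_frames = 260           # hypothetical
    trailing = trailing_silence_frames * frame_shift_in_seconds
    print(trailing)        # 2.6
    print(trailing > 2.4)  # True -> rule1 would report an endpoint
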
  189 + void Reset(OnlineStream *s) const override {
  190 + OnlineParaformerDecoderResult r;
  191 + s->SetParaformerResult(r);
  192 +
  193 + // Note: the internal model caches are not reset
  194 +
  195 + // Note: We only update counters. The underlying audio samples
  196 + // are not discarded.
  197 + s->Reset();
  198 + }
  199 +
  200 + private:
  201 + void DecodeStream(OnlineStream *s) const {
  202 + const auto num_processed_frames = s->GetNumProcessedFrames();
  203 + std::vector<float> frames = s->GetFrames(num_processed_frames, chunk_size_);
  204 + s->GetNumProcessedFrames() += chunk_size_ - 1;
  205 +
  206 + frames = ApplyLFR(frames);
  207 + ApplyCMVN(&frames);
  208 + PositionalEncoding(&frames, num_processed_frames / model_.LfrWindowShift());
  209 +
  210 + int32_t feat_dim = model_.NegativeMean().size();
  211 +
  212 + // We have scaled inv_stddev by sqrt(encoder_output_size)
  213 + // so the following line can be commented out
  214 + // frames *= encoder_output_size ** 0.5
  215 +
  216 + // add overlap chunk
  217 + std::vector<float> &feat_cache = s->GetParaformerFeatCache();
  218 + if (feat_cache.empty()) {
  219 + int32_t n = (left_chunk_size_ + right_chunk_size_) * feat_dim;
  220 + feat_cache.resize(n, 0);
  221 + }
  222 +
  223 + frames.insert(frames.begin(), feat_cache.begin(), feat_cache.end());
  224 + std::copy(frames.end() - feat_cache.size(), frames.end(),
  225 + feat_cache.begin());
  226 +
  227 + int32_t num_frames = frames.size() / feat_dim;
  228 +
  229 + auto memory_info =
  230 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  231 +
  232 + std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
  233 + Ort::Value x =
  234 + Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(),
  235 + x_shape.data(), x_shape.size());
  236 +
  237 + int64_t x_len_shape = 1;
  238 + int32_t x_len_val = num_frames;
  239 +
  240 + Ort::Value x_length =
  241 + Ort::Value::CreateTensor(memory_info, &x_len_val, 1, &x_len_shape, 1);
  242 +
  243 + auto encoder_out_vec =
  244 + model_.ForwardEncoder(std::move(x), std::move(x_length));
  245 +
  246 + // CIF search
  247 + auto &encoder_out = encoder_out_vec[0];
  248 + auto &encoder_out_len = encoder_out_vec[1];
  249 + auto &alpha = encoder_out_vec[2];
  250 +
  251 + float *p_alpha = alpha.GetTensorMutableData<float>();
  252 +
  253 + std::vector<int64_t> alpha_shape =
  254 + alpha.GetTensorTypeAndShapeInfo().GetShape();
  255 +
  256 + std::fill(p_alpha, p_alpha + left_chunk_size_, 0);
  257 + std::fill(p_alpha + alpha_shape[1] - right_chunk_size_,
  258 + p_alpha + alpha_shape[1], 0);
  259 +
  260 + const float *p_encoder_out = encoder_out.GetTensorData<float>();
  261 +
  262 + std::vector<int64_t> encoder_out_shape =
  263 + encoder_out.GetTensorTypeAndShapeInfo().GetShape();
  264 +
  265 + std::vector<float> &initial_hidden = s->GetParaformerEncoderOutCache();
  266 + if (initial_hidden.empty()) {
  267 + initial_hidden.resize(encoder_out_shape[2]);
  268 + }
  269 +
  270 + std::vector<float> &alpha_cache = s->GetParaformerAlphaCache();
  271 + if (alpha_cache.empty()) {
  272 + alpha_cache.resize(1);
  273 + }
  274 +
  275 + std::vector<float> acoustic_embedding;
  276 + acoustic_embedding.reserve(encoder_out_shape[1] * encoder_out_shape[2]);
  277 +
  278 + float threshold = 1.0;
  279 +
  280 + float integrate = alpha_cache[0];
  281 +
  282 + for (int32_t i = 0; i != encoder_out_shape[1]; ++i) {
  283 + float this_alpha = p_alpha[i];
  284 + if (integrate + this_alpha < threshold) {
  285 + integrate += this_alpha;
  286 + ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
  287 + encoder_out_shape[2], this_alpha,
  288 + initial_hidden.data());
  289 + continue;
  290 + }
  291 +
  292 + // fire
  293 + ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
  294 + encoder_out_shape[2], threshold - integrate,
  295 + initial_hidden.data());
  296 + acoustic_embedding.insert(acoustic_embedding.end(),
  297 + initial_hidden.begin(), initial_hidden.end());
  298 + integrate += this_alpha - threshold;
  299 +
  300 + Scale(p_encoder_out + i * encoder_out_shape[2], encoder_out_shape[2],
  301 + integrate, initial_hidden.data());
  302 + }
  303 +
  304 + alpha_cache[0] = integrate;
  305 +
  306 + if (acoustic_embedding.empty()) {
  307 + return;
  308 + }
  309 +
  310 + auto &states = s->GetStates();
  311 + if (states.empty()) {
  312 + states.reserve(model_.DecoderNumBlocks());
  313 +
  314 + std::array<int64_t, 3> shape{1, model_.EncoderOutputSize(),
  315 + model_.DecoderKernelSize() - 1};
  316 +
  317 + int32_t num_bytes = sizeof(float) * shape[0] * shape[1] * shape[2];
  318 +
  319 + for (int32_t i = 0; i != model_.DecoderNumBlocks(); ++i) {
  320 + Ort::Value this_state = Ort::Value::CreateTensor<float>(
  321 + model_.Allocator(), shape.data(), shape.size());
  322 +
  323 + memset(this_state.GetTensorMutableData<float>(), 0, num_bytes);
  324 +
  325 + states.push_back(std::move(this_state));
  326 + }
  327 + }
  328 +
  329 + int32_t num_tokens = acoustic_embedding.size() / initial_hidden.size();
  330 + std::array<int64_t, 3> acoustic_embedding_shape{
  331 + 1, num_tokens, static_cast<int64_t>(initial_hidden.size())};
  332 +
  333 + Ort::Value acoustic_embedding_tensor = Ort::Value::CreateTensor(
  334 + memory_info, acoustic_embedding.data(), acoustic_embedding.size(),
  335 + acoustic_embedding_shape.data(), acoustic_embedding_shape.size());
  336 +
  337 + std::array<int64_t, 1> acoustic_embedding_length_shape{1};
  338 + Ort::Value acoustic_embedding_length_tensor = Ort::Value::CreateTensor(
  339 + memory_info, &num_tokens, 1, acoustic_embedding_length_shape.data(),
  340 + acoustic_embedding_length_shape.size());
  341 +
  342 + auto decoder_out_vec = model_.ForwardDecoder(
  343 + std::move(encoder_out), std::move(encoder_out_len),
  344 + std::move(acoustic_embedding_tensor),
  345 + std::move(acoustic_embedding_length_tensor), std::move(states));
  346 +
  347 + states.reserve(model_.DecoderNumBlocks());
  348 + for (int32_t i = 2; i != decoder_out_vec.size(); ++i) {
  349 + // TODO(fangjun): When we change chunk_size_, we need to
  350 + // slice decoder_out_vec[i] accordingly.
  351 + states.push_back(std::move(decoder_out_vec[i]));
  352 + }
  353 +
  354 + const auto &sample_ids = decoder_out_vec[1];
  355 + const int64_t *p_sample_ids = sample_ids.GetTensorData<int64_t>();
  356 +
  357 + bool non_blank_detected = false;
  358 +
  359 + auto &result = s->GetParaformerResult();
  360 +
  361 + for (int32_t i = 0; i != num_tokens; ++i) {
  362 + int32_t t = p_sample_ids[i];
  363 + if (t == 0) {
  364 + continue;
  365 + }
  366 +
  367 + non_blank_detected = true;
  368 + result.tokens.push_back(t);
  369 + }
  370 +
  371 + if (non_blank_detected) {
  372 + result.last_non_blank_frame_index = num_processed_frames;
  373 + }
  374 + }
  375 +
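
The alpha loop above is the CIF (continuous integrate-and-fire) predictor of
Paraformer: per-frame weights are accumulated until they cross threshold = 1.0,
and each crossing "fires" one acoustic embedding, the alpha-weighted sum of the
frames that contributed to it. A self-contained sketch of the idea (pure Python
with illustrative names; the real code additionally carries integrate and hidden
across chunks via the alpha/encoder-out caches):

    def cif(frames, alpha, threshold=1.0):
        dim = len(frames[0])
        hidden = [0.0] * dim  # partially integrated embedding
        integrate = 0.0       # accumulated alpha mass
        embeddings = []
        for frame, a in zip(frames, alpha):
            if integrate + a < threshold:
                integrate += a
                hidden = [h + f * a for h, f in zip(hidden, frame)]
                continue
            # fire: take just enough weight from this frame to reach the threshold
            w = threshold - integrate
            embeddings.append([h + f * w for h, f in zip(hidden, frame)])
            integrate += a - threshold
            # the leftover weight starts the next embedding
            hidden = [f * integrate for f in frame]
        return embeddings

    print(cif([[1.0], [1.0], [1.0]], [0.5, 0.5, 0.5]))  # [[1.0]] -- fires at frame 2
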
  376 + std::vector<float> ApplyLFR(const std::vector<float> &in) const {
  377 + int32_t lfr_window_size = model_.LfrWindowSize();
  378 + int32_t lfr_window_shift = model_.LfrWindowShift();
  379 + int32_t in_feat_dim = config_.feat_config.feature_dim;
  380 +
  381 + int32_t in_num_frames = in.size() / in_feat_dim;
  382 + int32_t out_num_frames =
  383 + (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
  384 + int32_t out_feat_dim = in_feat_dim * lfr_window_size;
  385 +
  386 + std::vector<float> out(out_num_frames * out_feat_dim);
  387 +
  388 + const float *p_in = in.data();
  389 + float *p_out = out.data();
  390 +
  391 + for (int32_t i = 0; i != out_num_frames; ++i) {
  392 + std::copy(p_in, p_in + out_feat_dim, p_out);
  393 +
  394 + p_out += out_feat_dim;
  395 + p_in += lfr_window_shift * in_feat_dim;
  396 + }
  397 +
  398 + return out;
  399 + }
  400 +
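
ApplyLFR() concatenates lfr_window_size consecutive frames and hops by
lfr_window_shift, producing low-frame-rate features. A sketch with the window of
7 and shift of 6 used by this model (apply_lfr is an illustrative name):

    def apply_lfr(frames, window=7, shift=6):
        out = []
        i = 0
        while i + window <= len(frames):
            stacked = []
            for f in frames[i:i + window]:
                stacked.extend(f)  # 7 x 80-dim -> one 560-dim frame
            out.append(stacked)
            i += shift
        return out

    # 61 input frames -> (61 - 7) // 6 + 1 = 10 output frames
    print(len(apply_lfr([[0.0] * 80 for _ in range(61)])))  # 10
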
  401 + void ApplyCMVN(std::vector<float> *v) const {
  402 + const std::vector<float> &neg_mean = model_.NegativeMean();
  403 + const std::vector<float> &inv_stddev = model_.InverseStdDev();
  404 +
  405 + int32_t dim = neg_mean.size();
  406 + int32_t num_frames = v->size() / dim;
  407 +
  408 + float *p = v->data();
  409 +
  410 + for (int32_t i = 0; i != num_frames; ++i) {
  411 + for (int32_t k = 0; k != dim; ++k) {
  412 + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
  413 + }
  414 +
  415 + p += dim;
  416 + }
  417 + }
  418 +
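
ApplyCMVN() is plain per-dimension normalization; note that inv_stddev is
exported pre-scaled by sqrt(encoder_output_size), which is why DecodeStream()
can skip the usual frames *= encoder_output_size ** 0.5 step. As a sketch:

    def apply_cmvn(frame, neg_mean, inv_stddev):
        # inv_stddev already includes the sqrt(encoder_output_size) factor
        return [(x + m) * s for x, m, s in zip(frame, neg_mean, inv_stddev)]
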
  419 + void PositionalEncoding(std::vector<float> *v, int32_t t_offset) const {
  420 + int32_t lfr_window_size = model_.LfrWindowSize();
  421 + int32_t in_feat_dim = config_.feat_config.feature_dim;
  422 +
  423 + int32_t feat_dim = in_feat_dim * lfr_window_size;
  424 + int32_t T = v->size() / feat_dim;
  425 +
  426 + // log(10000)/(7*80/2-1) == 0.03301197265941284
  427 + // 7 is lfr_window_size
  428 + // 80 is in_feat_dim
  429 + // 7*80 is feat_dim
  430 + constexpr float kScale = -0.03301197265941284;
  431 +
  432 + for (int32_t t = 0; t != T; ++t) {
  433 + float *p = v->data() + t * feat_dim;
  434 +
  435 + int32_t offset = t + 1 + t_offset;
  436 +
  437 + for (int32_t d = 0; d < feat_dim / 2; ++d) {
  438 + float inv_timescale = offset * std::exp(d * kScale);
  439 +
  440 + float sin_d = std::sin(inv_timescale);
  441 + float cos_d = std::cos(inv_timescale);
  442 +
  443 + p[d] += sin_d;
  444 + p[d + feat_dim / 2] += cos_d;
  445 + }
  446 + }
  447 + }
  448 +
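
kScale above is -log(10000) / (feat_dim / 2 - 1) with feat_dim = 7 * 80 = 560,
so the loop adds the standard sinusoidal encoding
sin/cos(pos * 10000^(-d / (D/2 - 1))) to the features. A quick check of the
constant:

    import math

    feat_dim = 7 * 80  # lfr_window_size * in_feat_dim
    print(-math.log(10000) / (feat_dim / 2 - 1))  # -0.03301197265941284
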
  449 + private:
  450 + OnlineRecognizerConfig config_;
  451 + OnlineParaformerModel model_;
  452 + SymbolTable sym_;
  453 + Endpoint endpoint_;
  454 +
  455 + // 61 feature frames with a 10 ms frame shift, i.e., 0.61 seconds
  456 + int32_t chunk_size_ = 61;
  457 + // After LFR (window 7, shift 6): (61 - 7) / 6 + 1 = 10 frames per chunk
  458 +
  459 + int32_t left_chunk_size_ = 5;
  460 + int32_t right_chunk_size_ = 5;
  461 +};
  462 +
  463 +} // namespace sherpa_onnx
  464 +
  465 +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
@@ -94,21 +94,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -94,21 +94,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
94 } 94 }
95 #endif 95 #endif
96 96
97 - void InitOnlineStream(OnlineStream *stream) const override {  
98 - auto r = decoder_->GetEmptyResult();  
99 -  
100 - if (config_.decoding_method == "modified_beam_search" &&  
101 - nullptr != stream->GetContextGraph()) {  
102 - // r.hyps has only one element.  
103 - for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {  
104 - it->second.context_state = stream->GetContextGraph()->Root();  
105 - }  
106 - }  
107 -  
108 - stream->SetResult(r);  
109 - stream->SetStates(model_->GetEncoderInitStates());  
110 - }  
111 -  
112 std::unique_ptr<OnlineStream> CreateStream() const override { 97 std::unique_ptr<OnlineStream> CreateStream() const override {
113 auto stream = std::make_unique<OnlineStream>(config_.feat_config); 98 auto stream = std::make_unique<OnlineStream>(config_.feat_config);
114 InitOnlineStream(stream.get()); 99 InitOnlineStream(stream.get());
@@ -211,7 +196,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -211,7 +196,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
211 } 196 }
212 197
213 bool IsEndpoint(OnlineStream *s) const override { 198 bool IsEndpoint(OnlineStream *s) const override {
214 - if (!config_.enable_endpoint) return false; 199 + if (!config_.enable_endpoint) {
  200 + return false;
  201 + }
  202 +
215 int32_t num_processed_frames = s->GetNumProcessedFrames(); 203 int32_t num_processed_frames = s->GetNumProcessedFrames();
216 204
217 // frame shift is 10 milliseconds 205 // frame shift is 10 milliseconds
@@ -245,6 +233,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -245,6 +233,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
245 } 233 }
246 234
247 private: 235 private:
  236 + void InitOnlineStream(OnlineStream *stream) const {
  237 + auto r = decoder_->GetEmptyResult();
  238 +
  239 + if (config_.decoding_method == "modified_beam_search" &&
  240 + nullptr != stream->GetContextGraph()) {
  241 + // r.hyps has only one element.
  242 + for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
  243 + it->second.context_state = stream->GetContextGraph()->Root();
  244 + }
  245 + }
  246 +
  247 + stream->SetResult(r);
  248 + stream->SetStates(model_->GetEncoderInitStates());
  249 + }
  250 +
  251 + private:
248 OnlineRecognizerConfig config_; 252 OnlineRecognizerConfig config_;
249 std::unique_ptr<OnlineTransducerModel> model_; 253 std::unique_ptr<OnlineTransducerModel> model_;
250 std::unique_ptr<OnlineLM> lm_; 254 std::unique_ptr<OnlineLM> lm_;
@@ -47,6 +47,14 @@ class OnlineStream::Impl { @@ -47,6 +47,14 @@ class OnlineStream::Impl {
47 47
48 OnlineTransducerDecoderResult &GetResult() { return result_; } 48 OnlineTransducerDecoderResult &GetResult() { return result_; }
49 49
  50 + void SetParaformerResult(const OnlineParaformerDecoderResult &r) {
  51 + paraformer_result_ = r;
  52 + }
  53 +
  54 + OnlineParaformerDecoderResult &GetParaformerResult() {
  55 + return paraformer_result_;
  56 + }
  57 +
50 int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } 58 int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
51 59
52 void SetStates(std::vector<Ort::Value> states) { 60 void SetStates(std::vector<Ort::Value> states) {
@@ -57,6 +65,18 @@ class OnlineStream::Impl { @@ -57,6 +65,18 @@ class OnlineStream::Impl {
57 65
58 const ContextGraphPtr &GetContextGraph() const { return context_graph_; } 66 const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
59 67
  68 + std::vector<float> &GetParaformerFeatCache() {
  69 + return paraformer_feat_cache_;
  70 + }
  71 +
  72 + std::vector<float> &GetParaformerEncoderOutCache() {
  73 + return paraformer_encoder_out_cache_;
  74 + }
  75 +
  76 + std::vector<float> &GetParaformerAlphaCache() {
  77 + return paraformer_alpha_cache_;
  78 + }
  79 +
60 private: 80 private:
61 FeatureExtractor feat_extractor_; 81 FeatureExtractor feat_extractor_;
62 /// For contextual-biasing 82 /// For contextual-biasing
@@ -65,6 +85,10 @@ class OnlineStream::Impl { @@ -65,6 +85,10 @@ class OnlineStream::Impl {
65 int32_t start_frame_index_ = 0; // never reset 85 int32_t start_frame_index_ = 0; // never reset
66 OnlineTransducerDecoderResult result_; 86 OnlineTransducerDecoderResult result_;
67 std::vector<Ort::Value> states_; 87 std::vector<Ort::Value> states_;
  88 + std::vector<float> paraformer_feat_cache_;
  89 + std::vector<float> paraformer_encoder_out_cache_;
  90 + std::vector<float> paraformer_alpha_cache_;
  91 + OnlineParaformerDecoderResult paraformer_result_;
68 }; 92 };
69 93
70 OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, 94 OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
@@ -107,6 +131,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { @@ -107,6 +131,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
107 return impl_->GetResult(); 131 return impl_->GetResult();
108 } 132 }
109 133
  134 +void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) {
  135 + impl_->SetParaformerResult(r);
  136 +}
  137 +
  138 +OnlineParaformerDecoderResult &OnlineStream::GetParaformerResult() {
  139 + return impl_->GetParaformerResult();
  140 +}
  141 +
110 void OnlineStream::SetStates(std::vector<Ort::Value> states) { 142 void OnlineStream::SetStates(std::vector<Ort::Value> states) {
111 impl_->SetStates(std::move(states)); 143 impl_->SetStates(std::move(states));
112 } 144 }
@@ -119,4 +151,16 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const { @@ -119,4 +151,16 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
119 return impl_->GetContextGraph(); 151 return impl_->GetContextGraph();
120 } 152 }
121 153
  154 +std::vector<float> &OnlineStream::GetParaformerFeatCache() {
  155 + return impl_->GetParaformerFeatCache();
  156 +}
  157 +
  158 +std::vector<float> &OnlineStream::GetParaformerEncoderOutCache() {
  159 + return impl_->GetParaformerEncoderOutCache();
  160 +}
  161 +
  162 +std::vector<float> &OnlineStream::GetParaformerAlphaCache() {
  163 + return impl_->GetParaformerAlphaCache();
  164 +}
  165 +
122 } // namespace sherpa_onnx 166 } // namespace sherpa_onnx
@@ -11,6 +11,7 @@ @@ -11,6 +11,7 @@
11 #include "onnxruntime_cxx_api.h" // NOLINT 11 #include "onnxruntime_cxx_api.h" // NOLINT
12 #include "sherpa-onnx/csrc/context-graph.h" 12 #include "sherpa-onnx/csrc/context-graph.h"
13 #include "sherpa-onnx/csrc/features.h" 13 #include "sherpa-onnx/csrc/features.h"
  14 +#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
14 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 15 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
15 16
16 namespace sherpa_onnx { 17 namespace sherpa_onnx {
@@ -70,6 +71,9 @@ class OnlineStream { @@ -70,6 +71,9 @@ class OnlineStream {
70 void SetResult(const OnlineTransducerDecoderResult &r); 71 void SetResult(const OnlineTransducerDecoderResult &r);
71 OnlineTransducerDecoderResult &GetResult(); 72 OnlineTransducerDecoderResult &GetResult();
72 73
  74 + void SetParaformerResult(const OnlineParaformerDecoderResult &r);
  75 + OnlineParaformerDecoderResult &GetParaformerResult();
  76 +
73 void SetStates(std::vector<Ort::Value> states); 77 void SetStates(std::vector<Ort::Value> states);
74 std::vector<Ort::Value> &GetStates(); 78 std::vector<Ort::Value> &GetStates();
75 79
@@ -80,6 +84,11 @@ class OnlineStream { @@ -80,6 +84,11 @@ class OnlineStream {
80 */ 84 */
81 const ContextGraphPtr &GetContextGraph() const; 85 const ContextGraphPtr &GetContextGraph() const;
82 86
  87 + // for streaming paraformer
  88 + std::vector<float> &GetParaformerFeatCache();
  89 + std::vector<float> &GetParaformerEncoderOutCache();
  90 + std::vector<float> &GetParaformerAlphaCache();
  91 +
83 private: 92 private:
84 class Impl; 93 class Impl;
85 std::unique_ptr<Impl> impl_; 94 std::unique_ptr<Impl> impl_;
@@ -12,8 +12,8 @@ @@ -12,8 +12,8 @@
12 12
13 #include "sherpa-onnx/csrc/online-recognizer.h" 13 #include "sherpa-onnx/csrc/online-recognizer.h"
14 #include "sherpa-onnx/csrc/online-stream.h" 14 #include "sherpa-onnx/csrc/online-stream.h"
15 -#include "sherpa-onnx/csrc/symbol-table.h"  
16 #include "sherpa-onnx/csrc/parse-options.h" 15 #include "sherpa-onnx/csrc/parse-options.h"
  16 +#include "sherpa-onnx/csrc/symbol-table.h"
17 #include "sherpa-onnx/csrc/wave-reader.h" 17 #include "sherpa-onnx/csrc/wave-reader.h"
18 18
19 typedef struct { 19 typedef struct {
@@ -80,7 +80,7 @@ for a list of pre-trained models to download. @@ -80,7 +80,7 @@ for a list of pre-trained models to download.
80 80
81 bool is_ok = false; 81 bool is_ok = false;
82 const std::vector<float> samples = 82 const std::vector<float> samples =
83 - sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); 83 + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
84 84
85 if (!is_ok) { 85 if (!is_ok) {
86 fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); 86 fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
@@ -92,14 +92,14 @@ for a list of pre-trained models to download. @@ -92,14 +92,14 @@ for a list of pre-trained models to download.
92 auto s = recognizer.CreateStream(); 92 auto s = recognizer.CreateStream();
93 s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); 93 s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
94 94
95 - std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate)); 95 + std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
96 // Note: We can call AcceptWaveform() multiple times. 96 // Note: We can call AcceptWaveform() multiple times.
97 - s->AcceptWaveform(  
98 - sampling_rate, tail_paddings.data(), tail_paddings.size()); 97 + s->AcceptWaveform(sampling_rate, tail_paddings.data(),
  98 + tail_paddings.size());
99 99
100 // Call InputFinished() to indicate that no audio samples are available 100 // Call InputFinished() to indicate that no audio samples are available
101 s->InputFinished(); 101 s->InputFinished();
102 - ss.push_back({ std::move(s), duration, 0 }); 102 + ss.push_back({std::move(s), duration, 0});
103 } 103 }
104 104
105 std::vector<sherpa_onnx::OnlineStream *> ready_streams; 105 std::vector<sherpa_onnx::OnlineStream *> ready_streams;
@@ -112,8 +112,9 @@ for a list of pre-trained models to download. @@ -112,8 +112,9 @@ for a list of pre-trained models to download.
112 } else if (s.elapsed_seconds == 0) { 112 } else if (s.elapsed_seconds == 0) {
113 const auto end = std::chrono::steady_clock::now(); 113 const auto end = std::chrono::steady_clock::now();
114 const float elapsed_seconds = 114 const float elapsed_seconds =
115 - std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)  
116 - .count() / 1000.; 115 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  116 + .count() /
  117 + 1000.;
117 s.elapsed_seconds = elapsed_seconds; 118 s.elapsed_seconds = elapsed_seconds;
118 } 119 }
119 } 120 }
@@ -15,6 +15,7 @@ pybind11_add_module(_sherpa_onnx @@ -15,6 +15,7 @@ pybind11_add_module(_sherpa_onnx
15 offline-whisper-model-config.cc 15 offline-whisper-model-config.cc
16 online-lm-config.cc 16 online-lm-config.cc
17 online-model-config.cc 17 online-model-config.cc
  18 + online-paraformer-model-config.cc
18 online-recognizer.cc 19 online-recognizer.cc
19 online-stream.cc 20 online-stream.cc
20 online-transducer-model-config.cc 21 online-transducer-model-config.cc
1 // sherpa-onnx/python/csrc/online-model-config.cc 1 // sherpa-onnx/python/csrc/online-model-config.cc
2 // 2 //
3 -// Copyright (c) 2023 by manyeyes 3 +// Copyright (c) 2023 Xiaomi Corporation
4 4
5 #include "sherpa-onnx/python/csrc/online-model-config.h" 5 #include "sherpa-onnx/python/csrc/online-model-config.h"
6 6
@@ -9,21 +9,26 @@ @@ -9,21 +9,26 @@
9 9
10 #include "sherpa-onnx/csrc/online-model-config.h" 10 #include "sherpa-onnx/csrc/online-model-config.h"
11 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 11 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
  12 +#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
12 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" 13 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
13 14
14 namespace sherpa_onnx { 15 namespace sherpa_onnx {
15 16
16 void PybindOnlineModelConfig(py::module *m) { 17 void PybindOnlineModelConfig(py::module *m) {
17 PybindOnlineTransducerModelConfig(m); 18 PybindOnlineTransducerModelConfig(m);
  19 + PybindOnlineParaformerModelConfig(m);
18 20
19 using PyClass = OnlineModelConfig; 21 using PyClass = OnlineModelConfig;
20 py::class_<PyClass>(*m, "OnlineModelConfig") 22 py::class_<PyClass>(*m, "OnlineModelConfig")
21 - .def(py::init<const OnlineTransducerModelConfig &, std::string &, int32_t, 23 + .def(py::init<const OnlineTransducerModelConfig &,
  24 + const OnlineParaformerModelConfig &, std::string &, int32_t,
22 bool, const std::string &, const std::string &>(), 25 bool, const std::string &, const std::string &>(),
23 py::arg("transducer") = OnlineTransducerModelConfig(), 26 py::arg("transducer") = OnlineTransducerModelConfig(),
  27 + py::arg("paraformer") = OnlineParaformerModelConfig(),
24 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, 28 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
25 py::arg("provider") = "cpu", py::arg("model_type") = "") 29 py::arg("provider") = "cpu", py::arg("model_type") = "")
26 .def_readwrite("transducer", &PyClass::transducer) 30 .def_readwrite("transducer", &PyClass::transducer)
  31 + .def_readwrite("paraformer", &PyClass::paraformer)
27 .def_readwrite("tokens", &PyClass::tokens) 32 .def_readwrite("tokens", &PyClass::tokens)
28 .def_readwrite("num_threads", &PyClass::num_threads) 33 .def_readwrite("num_threads", &PyClass::num_threads)
29 .def_readwrite("debug", &PyClass::debug) 34 .def_readwrite("debug", &PyClass::debug)
1 // sherpa-onnx/python/csrc/online-model-config.h 1 // sherpa-onnx/python/csrc/online-model-config.h
2 // 2 //
3 -// Copyright (c) 2023 by manyeyes 3 +// Copyright (c) 2023 Xiaomi Corporation
4 4
5 #ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ 5 #ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
6 #define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ 6 #define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
  1 +// sherpa-onnx/python/csrc/online-paraformer-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOnlineParaformerModelConfig(py::module *m) {
  15 + using PyClass = OnlineParaformerModelConfig;
  16 + py::class_<PyClass>(*m, "OnlineParaformerModelConfig")
  17 + .def(py::init<const std::string &, const std::string &>(),
  18 + py::arg("encoder"), py::arg("decoder"))
  19 + .def_readwrite("encoder", &PyClass::encoder)
  20 + .def_readwrite("decoder", &PyClass::decoder)
  21 + .def("__str__", &PyClass::ToString);
  22 +}
  23 +
  24 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/online-paraformer-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlineParaformerModelConfig(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
@@ -33,7 +33,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -33,7 +33,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
33 py::arg("feat_config"), py::arg("model_config"), 33 py::arg("feat_config"), py::arg("model_config"),
34 py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), 34 py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
35 py::arg("enable_endpoint"), py::arg("decoding_method"), 35 py::arg("enable_endpoint"), py::arg("decoding_method"),
36 - py::arg("max_active_paths"), py::arg("context_score")) 36 + py::arg("max_active_paths") = 4, py::arg("context_score") = 0)
37 .def_readwrite("feat_config", &PyClass::feat_config) 37 .def_readwrite("feat_config", &PyClass::feat_config)
38 .def_readwrite("model_config", &PyClass::model_config) 38 .def_readwrite("model_config", &PyClass::model_config)
39 .def_readwrite("endpoint_config", &PyClass::endpoint_config) 39 .def_readwrite("endpoint_config", &PyClass::endpoint_config)
@@ -6,6 +6,7 @@ from _sherpa_onnx import ( @@ -6,6 +6,7 @@ from _sherpa_onnx import (
6 EndpointConfig, 6 EndpointConfig,
7 FeatureExtractorConfig, 7 FeatureExtractorConfig,
8 OnlineModelConfig, 8 OnlineModelConfig,
  9 + OnlineParaformerModelConfig,
9 OnlineRecognizer as _Recognizer, 10 OnlineRecognizer as _Recognizer,
10 OnlineRecognizerConfig, 11 OnlineRecognizerConfig,
11 OnlineStream, 12 OnlineStream,
@@ -32,7 +33,7 @@ class OnlineRecognizer(object): @@ -32,7 +33,7 @@ class OnlineRecognizer(object):
32 encoder: str, 33 encoder: str,
33 decoder: str, 34 decoder: str,
34 joiner: str, 35 joiner: str,
35 - num_threads: int = 4, 36 + num_threads: int = 2,
36 sample_rate: float = 16000, 37 sample_rate: float = 16000,
37 feature_dim: int = 80, 38 feature_dim: int = 80,
38 enable_endpoint_detection: bool = False, 39 enable_endpoint_detection: bool = False,
@@ -144,6 +145,109 @@ class OnlineRecognizer(object): @@ -144,6 +145,109 @@ class OnlineRecognizer(object):
144 self.config = recognizer_config 145 self.config = recognizer_config
145 return self 146 return self
146 147
  148 + @classmethod
  149 + def from_paraformer(
  150 + cls,
  151 + tokens: str,
  152 + encoder: str,
  153 + decoder: str,
  154 + num_threads: int = 2,
  155 + sample_rate: float = 16000,
  156 + feature_dim: int = 80,
  157 + enable_endpoint_detection: bool = False,
  158 + rule1_min_trailing_silence: float = 2.4,
  159 + rule2_min_trailing_silence: float = 1.2,
  160 + rule3_min_utterance_length: float = 20.0,
  161 + decoding_method: str = "greedy_search",
  162 + provider: str = "cpu",
  163 + ):
  164 + """
  165 + Please refer to
  166 + `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
  167 + to download pre-trained models for different languages, e.g., Chinese,
  168 + English, etc.
  169 +
  170 + Args:
  171 + tokens:
  172 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  173 + columns::
  174 +
  175 + symbol integer_id
  176 +
  177 + encoder:
  178 + Path to ``encoder.onnx``.
  179 + decoder:
  180 + Path to ``decoder.onnx``.
  181 + num_threads:
  182 + Number of threads for neural network computation.
  183 + sample_rate:
  184 + Sample rate of the training data used to train the model.
  185 + feature_dim:
  186 + Dimension of the feature used to train the model.
  187 + enable_endpoint_detection:
  188 + True to enable endpoint detection. False to disable endpoint
  189 + detection.
  190 + rule1_min_trailing_silence:
  191 + Used only when enable_endpoint_detection is True. If the duration
  192 + of trailing silence in seconds is larger than this value, we assume
  193 + an endpoint is detected.
  194 + rule2_min_trailing_silence:
  195 + Used only when enable_endpoint_detection is True. If we have decoded
  196 + something that is nonsilence and if the duration of trailing silence
  197 + in seconds is larger than this value, we assume an endpoint is
  198 + detected.
  199 + rule3_min_utterance_length:
  200 + Used only when enable_endpoint_detection is True. If the utterance
  201 + length in seconds is larger than this value, we assume an endpoint
  202 + is detected.
  203 + decoding_method:
  204 + The only valid value is greedy_search.
  205 + provider:
  206 + onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  207 + """
  208 + self = cls.__new__(cls)
  209 + _assert_file_exists(tokens)
  210 + _assert_file_exists(encoder)
  211 + _assert_file_exists(decoder)
  212 +
  213 + assert num_threads > 0, num_threads
  214 +
  215 + paraformer_config = OnlineParaformerModelConfig(
  216 + encoder=encoder,
  217 + decoder=decoder,
  218 + )
  219 +
  220 + model_config = OnlineModelConfig(
  221 + paraformer=paraformer_config,
  222 + tokens=tokens,
  223 + num_threads=num_threads,
  224 + provider=provider,
  225 + model_type="paraformer",
  226 + )
  227 +
  228 + feat_config = FeatureExtractorConfig(
  229 + sampling_rate=sample_rate,
  230 + feature_dim=feature_dim,
  231 + )
  232 +
  233 + endpoint_config = EndpointConfig(
  234 + rule1_min_trailing_silence=rule1_min_trailing_silence,
  235 + rule2_min_trailing_silence=rule2_min_trailing_silence,
  236 + rule3_min_utterance_length=rule3_min_utterance_length,
  237 + )
  238 +
  239 + recognizer_config = OnlineRecognizerConfig(
  240 + feat_config=feat_config,
  241 + model_config=model_config,
  242 + endpoint_config=endpoint_config,
  243 + enable_endpoint=enable_endpoint_detection,
  244 + decoding_method=decoding_method,
  245 + )
  246 +
  247 + self.recognizer = _Recognizer(recognizer_config)
  248 + self.config = recognizer_config
  249 + return self
  250 +
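
A usage sketch for the new classmethod (paths are placeholders; the streaming
loop in the comments follows the existing Python API):

    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
        tokens="/path/to/tokens.txt",
        encoder="/path/to/encoder.onnx",
        decoder="/path/to/decoder.onnx",
        num_threads=2,
    )

    stream = recognizer.create_stream()
    # stream.accept_waveform(16000, samples)  # 16 kHz float32 samples in [-1, 1]
    # while recognizer.is_ready(stream):
    #     recognizer.decode_stream(stream)
    # print(recognizer.get_result(stream))
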
147 def create_stream(self, contexts_list: Optional[List[List[int]]] = None): 251 def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
148 if contexts_list is None: 252 if contexts_list is None:
149 return self.recognizer.create_stream() 253 return self.recognizer.create_stream()