Add scripts to export ASR models from wenet to ONNX (#425)
See https://user-images.githubusercontent.com/5284924/282995968-f6d39118-8008-4ce7-9d7c-d1d6387ac183.png
Showing 7 changed files with 1,163 additions and 0 deletions.
.github/workflows/export-wenet-to-onnx.yaml
0 → 100644
name: export-wenet-to-onnx

on:
  push:
    branches:
      - master
    paths:
      - 'scripts/wenet/**'
      - '.github/workflows/export-wenet-to-onnx.yaml'
  pull_request:
    paths:
      - 'scripts/wenet/**'
      - '.github/workflows/export-wenet-to-onnx.yaml'

  workflow_dispatch:

concurrency:
  group: export-wenet-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-wenet-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export wenet
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.8"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run
        shell: bash
        run: |
          sudo apt-get install -y tree sox
          cd scripts/wenet
          ./run.sh

      - name: Publish to huggingface (aishell)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v2
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1

            git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-aishell huggingface
            cd huggingface
            git fetch
            git pull

            cp -v ../scripts/wenet/aishell_u2pp_conformer_exp/*.onnx .
            cp -v ../scripts/wenet/aishell_u2pp_conformer_exp/units.txt tokens.txt
            cp -v ../scripts/wenet/aishell_u2pp_conformer_exp/README.md .

            if [ ! -d test_wavs ]; then
              mkdir test_wavs
              cd test_wavs
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/0.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/1.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/8k.wav
              cd ..
            fi
            git lfs track "*.onnx"
            git add .

            git commit -m "add aishell models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-aishell main || true

            cd ..
            rm -rf huggingface

      - name: Publish to huggingface (aishell2)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v2
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1

            git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-aishell2 huggingface
            cd huggingface
            git fetch
            git pull

            cp -v ../scripts/wenet/aishell2_u2pp_conformer_exp/*.onnx .
            cp -v ../scripts/wenet/aishell2_u2pp_conformer_exp/units.txt tokens.txt
            cp -v ../scripts/wenet/aishell2_u2pp_conformer_exp/README.md .

            if [ ! -d test_wavs ]; then
              mkdir test_wavs
              cd test_wavs
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/0.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/1.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/8k.wav
              cd ..
            fi
            git lfs track "*.onnx"
            git add .

            git commit -m "add aishell2 models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-aishell2 main || true

            cd ..
            rm -rf huggingface

      - name: Publish to huggingface (multi_cn)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v2
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1

            git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-multi-cn huggingface
            cd huggingface
            git fetch
            git pull

            cp -v ../scripts/wenet/multi_cn_unified_conformer_exp/*.onnx .
            cp -v ../scripts/wenet/multi_cn_unified_conformer_exp/units.txt tokens.txt
            cp -v ../scripts/wenet/multi_cn_unified_conformer_exp/README.md .

            if [ ! -d test_wavs ]; then
              mkdir test_wavs
              cd test_wavs
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/0.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/1.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/8k.wav
              cd ..
            fi
            git lfs track "*.onnx"
            git add .

            git commit -m "add multi_cn models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-multi-cn main || true

            cd ..
            rm -rf huggingface

      - name: Publish to huggingface (wenetspeech)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v2
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1

            git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech huggingface
            cd huggingface
            git fetch
            git pull

            cp -v ../scripts/wenet/20220506_u2pp_conformer_exp/*.onnx .
            cp -v ../scripts/wenet/20220506_u2pp_conformer_exp/units.txt tokens.txt
            cp -v ../scripts/wenet/20220506_u2pp_conformer_exp/README.md .

            if [ ! -d test_wavs ]; then
              mkdir test_wavs
              cd test_wavs
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/0.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/1.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/8k.wav
              cd ..
            fi
            git lfs track "*.onnx"
            git add .

            git commit -m "add wenetspeech models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech main || true

            cd ..
            rm -rf huggingface

      - name: Publish to huggingface (librispeech)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v2
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1

            git clone https://huggingface.co/csukuangfj/sherpa-onnx-en-wenet-librispeech huggingface
            cd huggingface
            git fetch
            git pull

            cp -v ../scripts/wenet/librispeech_u2pp_conformer_exp/*.onnx .
            cp -v ../scripts/wenet/librispeech_u2pp_conformer_exp/units.txt tokens.txt
            cp -v ../scripts/wenet/librispeech_u2pp_conformer_exp/README.md .

            if [ ! -d test_wavs ]; then
              mkdir test_wavs
              cd test_wavs
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/0.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/1.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/8k.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/trans.txt
              cd ..
            fi
            git lfs track "*.onnx"
            git add .

            git commit -m "add librispeech models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-en-wenet-librispeech main || true

            cd ..
            rm -rf huggingface

      - name: Publish to huggingface (gigaspeech)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v2
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1

            git clone https://huggingface.co/csukuangfj/sherpa-onnx-en-wenet-gigaspeech huggingface
            cd huggingface
            git fetch
            git pull

            cp -v ../scripts/wenet/20210728_u2pp_conformer_exp/*.onnx .
            cp -v ../scripts/wenet/20210728_u2pp_conformer_exp/units.txt tokens.txt
            cp -v ../scripts/wenet/20210728_u2pp_conformer_exp/README.md .

            if [ ! -d test_wavs ]; then
              mkdir test_wavs
              cd test_wavs
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/0.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/1.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/8k.wav
              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/trans.txt
              cd ..
            fi
            git lfs track "*.onnx"
            git add .

            git commit -m "add gigaspeech models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-en-wenet-gigaspeech main || true

            cd ..
            rm -rf huggingface
scripts/wenet/README.md
0 → 100644
# Introduction

This folder contains scripts for exporting models
from [wenet](https://github.com/wenet-e2e/wenet)
to ONNX. You can use the exported models in sherpa-onnx.

Note that both **streaming** and **non-streaming** models are supported.

We only use the CTC branch. Rescoring with the attention decoder
is not supported, though decoding with H, HL, and HLG is supported.
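The export scripts write metadata into the exported ONNX file (the model type, and for streaming models the cache geometry); sherpa-onnx and the test scripts below read it back at load time. As a quick sanity check you can inspect that metadata with onnxruntime (a minimal sketch, assuming the streaming export has produced model-streaming.onnx in the current directory):

import onnxruntime as ort

session = ort.InferenceSession(
    "model-streaming.onnx", providers=["CPUExecutionProvider"]
)
# custom_metadata_map holds the key-value pairs written by add_meta_data()
for key, value in session.get_modelmeta().custom_metadata_map.items():
    print(f"{key} = {value}")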
scripts/wenet/export-onnx-streaming.py
0 → 100755
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

# pip install git+https://github.com/wenet-e2e/wenet.git
# pip install onnxruntime onnx pyyaml
# cp -a ~/open-source/wenet/wenet/transducer/search .
# cp -a ~/open-source/wenet/wenet/e_branchformer .
# cp -a ~/open-source/wenet/wenet/ctl_model .

import os
from typing import Dict

import onnx
import torch
import yaml
from onnxruntime.quantization import QuantType, quantize_dynamic

from wenet.utils.init_model import init_model


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


class OnnxModel(torch.nn.Module):
    def __init__(self, encoder: torch.nn.Module, ctc: torch.nn.Module):
        super().__init__()
        self.encoder = encoder
        self.ctc = ctc

    def forward(
        self,
        x: torch.Tensor,
        offset: torch.Tensor,
        required_cache_size: torch.Tensor,
        attn_cache: torch.Tensor,
        conv_cache: torch.Tensor,
        attn_mask: torch.Tensor,
    ):
        """
        Args:
          x:
            A 3-D float32 tensor of shape (N, T, C). It supports only N == 1.
          offset:
            A scalar of dtype torch.int64.
          required_cache_size:
            A scalar of dtype torch.int64.
          attn_cache:
            A 4-D float32 tensor of shape
            (num_blocks, head, required_cache_size, encoder_output_size / head * 2).
          conv_cache:
            A 4-D float32 tensor of shape
            (num_blocks, N, encoder_output_size, cnn_module_kernel - 1).
          attn_mask:
            A 3-D bool tensor of shape (N, 1, required_cache_size + chunk_size).
        Returns:
          Return a tuple of 3 tensors:
            - A 3-D float32 tensor of shape (N, T, C) containing log_probs
            - next_att_cache
            - next_conv_cache
        """
        encoder_out, next_att_cache, next_conv_cache = self.encoder.forward_chunk(
            xs=x,
            offset=offset,
            required_cache_size=required_cache_size,
            att_cache=attn_cache,
            cnn_cache=conv_cache,
            att_mask=attn_mask,
        )
        log_probs = self.ctc.log_softmax(encoder_out)

        return log_probs, next_att_cache, next_conv_cache


class Foo:
    # A simple namespace holding the attributes expected by init_model()
    pass


@torch.no_grad()
def main():
    args = Foo()
    args.checkpoint = "./final.pt"
    config_file = "./train.yaml"

    with open(config_file, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    torch_model, configs = init_model(args, configs)
    torch_model.eval()

    head = configs["encoder_conf"]["attention_heads"]
    num_blocks = configs["encoder_conf"]["num_blocks"]
    output_size = configs["encoder_conf"]["output_size"]
    cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1)

    right_context = torch_model.right_context()
    subsampling_factor = torch_model.encoder.embed.subsampling_rate
    chunk_size = 16
    left_chunks = 4

    decoding_window = (chunk_size - 1) * subsampling_factor + right_context + 1

    required_cache_size = chunk_size * left_chunks

    offset = required_cache_size

    attn_cache = torch.zeros(
        num_blocks,
        head,
        required_cache_size,
        output_size // head * 2,
        dtype=torch.float32,
    )

    attn_mask = torch.ones(1, 1, required_cache_size + chunk_size, dtype=torch.bool)
    attn_mask[:, :, :required_cache_size] = 0

    conv_cache = torch.zeros(
        num_blocks, 1, output_size, cnn_module_kernel - 1, dtype=torch.float32
    )

    # sos/eos are not used by the CTC export; kept for reference
    sos = torch_model.sos_symbol()
    eos = torch_model.eos_symbol()

    onnx_model = OnnxModel(
        encoder=torch_model.encoder,
        ctc=torch_model.ctc,
    )
    filename = "model-streaming.onnx"

    N = 1
    T = decoding_window
    C = 80
    x = torch.rand(N, T, C, dtype=torch.float32)
    offset = torch.tensor([offset], dtype=torch.int64)
    required_cache_size = torch.tensor([required_cache_size], dtype=torch.int64)

    opset_version = 13
    torch.onnx.export(
        onnx_model,
        (x, offset, required_cache_size, attn_cache, conv_cache, attn_mask),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "offset",
            "required_cache_size",
            "attn_cache",
            "conv_cache",
            "attn_mask",
        ],
        output_names=["log_probs", "next_att_cache", "next_conv_cache"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "attn_cache": {2: "T"},
            "log_probs": {0: "N"},
            "next_att_cache": {2: "T"},
        },
    )

    # e.g., https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
    url = os.environ.get("WENET_URL", "")
    meta_data = {
        "model_type": "wenet-ctc",
        "version": "1",
        "model_author": "wenet",
        "comment": "streaming",
        "url": url,
        "chunk_size": chunk_size,
        "left_chunks": left_chunks,
        "head": head,
        "num_blocks": num_blocks,
        "output_size": output_size,
        "cnn_module_kernel": cnn_module_kernel,
        "right_context": right_context,
        "subsampling_factor": subsampling_factor,
    }
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "model-streaming.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    main()
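After exporting, it can be worth confirming that the ONNX graph exposes exactly the inputs and outputs the test script further below expects. A minimal sketch using onnxruntime (the filename matches what the export script writes):

import onnxruntime as ort

session = ort.InferenceSession(
    "model-streaming.onnx", providers=["CPUExecutionProvider"]
)
print("inputs: ", [(i.name, i.shape) for i in session.get_inputs()])
print("outputs:", [(o.name, o.shape) for o in session.get_outputs()])
# Expected inputs:  x, offset, required_cache_size, attn_cache, conv_cache, attn_mask
# Expected outputs: log_probs, next_att_cache, next_conv_cache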
scripts/wenet/export-onnx.py
0 → 100755
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

# pip install git+https://github.com/wenet-e2e/wenet.git
# pip install onnxruntime onnx pyyaml
# cp -a ~/open-source/wenet/wenet/transducer/search .
# cp -a ~/open-source/wenet/wenet/e_branchformer .
# cp -a ~/open-source/wenet/wenet/ctl_model .

import os
from typing import Dict

import onnx
import torch
import yaml
from onnxruntime.quantization import QuantType, quantize_dynamic

from wenet.utils.init_model import init_model


class Foo:
    # A simple namespace holding the attributes expected by init_model()
    pass


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


class OnnxModel(torch.nn.Module):
    def __init__(self, encoder: torch.nn.Module, ctc: torch.nn.Module):
        super().__init__()
        self.encoder = encoder
        self.ctc = ctc

    def forward(self, x, x_lens):
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,) containing valid lengths in x before
            padding. Its dtype is torch.int64.
        """
        encoder_out, encoder_out_mask = self.encoder(
            x,
            x_lens,
            decoding_chunk_size=-1,
            num_decoding_left_chunks=-1,
        )
        log_probs = self.ctc.log_softmax(encoder_out)
        log_probs_lens = encoder_out_mask.int().squeeze(1).sum(1)

        return log_probs, log_probs_lens


@torch.no_grad()
def main():
    args = Foo()
    args.checkpoint = "./final.pt"
    config_file = "./train.yaml"

    with open(config_file, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    torch_model, configs = init_model(args, configs)
    torch_model.eval()

    onnx_model = OnnxModel(encoder=torch_model.encoder, ctc=torch_model.ctc)
    filename = "model.onnx"

    N = 1
    T = 1000
    C = 80
    x = torch.rand(N, T, C, dtype=torch.float)
    x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)

    opset_version = 13
    onnx_model = torch.jit.script(onnx_model)
    torch.onnx.export(
        onnx_model,
        (x, x_lens),
        filename,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["log_probs", "log_probs_lens"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "log_probs": {0: "N", 1: "T"},
            "log_probs_lens": {0: "N"},
        },
    )

    # e.g., https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
    url = os.environ.get("WENET_URL", "")
    meta_data = {
        "model_type": "wenet-ctc",
        "version": "1",
        "model_author": "wenet",
        "comment": "non-streaming",
        "url": url,
    }
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


if __name__ == "__main__":
    main()
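quantize_dynamic() above converts only the MatMul weights to int8; activations stay float and are quantized on the fly at run time, so the most visible effect is a roughly 4x smaller weight payload. A quick way to check, after running the export in the model directory (a minimal sketch):

import os

for name in ["model.onnx", "model.int8.onnx"]:
    size_mb = os.path.getsize(name) / 1024 / 1024
    print(f"{name}: {size_mb:.1f} MB")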
scripts/wenet/run.sh
0 → 100755
#!/usr/bin/env bash
#
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
#
# Please refer to
# https://github.com/wenet-e2e/wenet/blob/main/docs/pretrained_models.en.md
# for a table of pre-trained models.
# Please select the column "Checkpoint Model" for downloading.

set -ex

function install_dependencies() {
  pip install soundfile
  pip install torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
  pip install k2==1.24.4.dev20231022+cpu.torch2.1.0 -f https://k2-fsa.github.io/k2/cpu.html

  pip install onnxruntime onnx kaldi-native-fbank pyyaml

  pip install git+https://github.com/wenet-e2e/wenet.git
  wenet_dir=$(dirname $(python3 -c "import wenet; print(wenet.__file__)"))
  git clone https://github.com/wenet-e2e/wenet
  if [ ! -d $wenet_dir/transducer/search ]; then
    cp -av ./wenet/wenet/transducer/search $wenet_dir/transducer
  fi

  if [ ! -d $wenet_dir/e_branchformer ]; then
    cp -a ./wenet/wenet/e_branchformer $wenet_dir
  fi

  if [ ! -d $wenet_dir/ctl_model ]; then
    cp -a ./wenet/wenet/ctl_model $wenet_dir
  fi

  rm -rf wenet
}

function aishell() {
  echo "aishell"
  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/aishell_u2pp_conformer_exp.tar.gz
  tar xvf aishell_u2pp_conformer_exp.tar.gz
  rm -v aishell_u2pp_conformer_exp.tar.gz

  pushd aishell_u2pp_conformer_exp
  mkdir -p exp/20210601_u2++_conformer_exp
  cp global_cmvn ./exp/20210601_u2++_conformer_exp
  cp ../*.py .

  # Quote the URL: it contains '&', which would otherwise background the command.
  export WENET_URL="https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz"
  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/zh.wav
  soxi 0.wav

  echo "Test streaming"
  ./export-onnx-streaming.py
  ls -lh
  ./test-onnx-streaming.py

  echo "Test non-streaming"
  ./export-onnx.py
  ls -lh
  ./test-onnx.py

  cat > README.md <<EOF
# Introduction
This model is converted from https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
EOF

  popd
}

function aishell2() {
  echo "aishell2"
  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/aishell2_u2pp_conformer_exp.tar.gz
  tar xvf aishell2_u2pp_conformer_exp.tar.gz
  rm -v aishell2_u2pp_conformer_exp.tar.gz

  pushd aishell2_u2pp_conformer_exp
  mkdir -p exp/u2++_conformer
  cp global_cmvn ./exp/u2++_conformer
  cp ../*.py .

  export WENET_URL="https://wenet.org.cn/downloads?models=wenet&version=aishell2_u2pp_conformer_exp.tar.gz"
  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/zh.wav
  soxi 0.wav

  echo "Test streaming"
  ./export-onnx-streaming.py
  ls -lh
  ./test-onnx-streaming.py

  echo "Test non-streaming"
  ./export-onnx.py
  ls -lh
  ./test-onnx.py

  cat > README.md <<EOF
# Introduction
This model is converted from https://wenet.org.cn/downloads?models=wenet&version=aishell2_u2pp_conformer_exp.tar.gz
EOF

  popd
}

function multi_cn() {
  echo "multi_cn"
  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/multi_cn_unified_conformer_exp.tar.gz
  tar xvf multi_cn_unified_conformer_exp.tar.gz
  rm -v multi_cn_unified_conformer_exp.tar.gz

  pushd multi_cn_unified_conformer_exp
  mkdir -p exp/20210815_unified_conformer_exp
  cp global_cmvn ./exp/20210815_unified_conformer_exp
  cp ../*.py .

  export WENET_URL="https://wenet.org.cn/downloads?models=wenet&version=multi_cn_unified_conformer_exp.tar.gz"
  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/zh.wav
  soxi 0.wav

  echo "Test streaming"
  ./export-onnx-streaming.py
  ls -lh
  ./test-onnx-streaming.py

  echo "Test non-streaming"
  ./export-onnx.py
  ls -lh
  ./test-onnx.py

  cat > README.md <<EOF
# Introduction
This model is converted from https://wenet.org.cn/downloads?models=wenet&version=multi_cn_unified_conformer_exp.tar.gz
EOF

  popd
}

function wenetspeech() {
  echo "wenetspeech"
  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/wenetspeech_u2pp_conformer_exp.tar.gz
  tar xvf wenetspeech_u2pp_conformer_exp.tar.gz
  rm -v wenetspeech_u2pp_conformer_exp.tar.gz

  pushd 20220506_u2pp_conformer_exp
  mkdir -p exp/20220506_u2pp_conformer_exp
  cp global_cmvn ./exp/20220506_u2pp_conformer_exp
  cp ../*.py .

  export WENET_URL="https://wenet.org.cn/downloads?models=wenet&version=wenetspeech_u2pp_conformer_exp.tar.gz"
  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/zh.wav
  soxi 0.wav

  echo "Test streaming"
  ./export-onnx-streaming.py
  ls -lh
  ./test-onnx-streaming.py

  echo "Test non-streaming"
  ./export-onnx.py
  ls -lh
  ./test-onnx.py

  cat > README.md <<EOF
# Introduction
This model is converted from https://wenet.org.cn/downloads?models=wenet&version=wenetspeech_u2pp_conformer_exp.tar.gz
EOF

  popd
}

function librispeech() {
  echo "librispeech"
  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/librispeech_u2pp_conformer_exp.tar.gz
  tar xvf librispeech_u2pp_conformer_exp.tar.gz
  rm -v librispeech_u2pp_conformer_exp.tar.gz

  pushd librispeech_u2pp_conformer_exp
  mkdir -p data/train_960
  cp global_cmvn ./data/train_960
  cp ../*.py .

  export WENET_URL="https://wenet.org.cn/downloads?models=wenet&version=librispeech_u2pp_conformer_exp.tar.gz"
  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/en.wav
  soxi 0.wav

  echo "Test streaming"
  ./export-onnx-streaming.py
  ls -lh
  ./test-onnx-streaming.py

  echo "Test non-streaming"
  ./export-onnx.py
  ls -lh
  ./test-onnx.py

  cat > README.md <<EOF
# Introduction
This model is converted from https://wenet.org.cn/downloads?models=wenet&version=librispeech_u2pp_conformer_exp.tar.gz
EOF

  popd
}

function gigaspeech() {
  echo "gigaspeech"
  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/gigaspeech_u2pp_conformer_exp.tar.gz
  tar xvf gigaspeech_u2pp_conformer_exp.tar.gz
  rm -v gigaspeech_u2pp_conformer_exp.tar.gz

  pushd 20210728_u2pp_conformer_exp
  mkdir -p data/gigaspeech_train_xl
  cp global_cmvn ./data/gigaspeech_train_xl
  cp ../*.py .

  export WENET_URL="https://wenet.org.cn/downloads?models=wenet&version=gigaspeech_u2pp_conformer_exp.tar.gz"
  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/en.wav
  soxi 0.wav

  echo "Test streaming"
  ./export-onnx-streaming.py
  ls -lh
  ./test-onnx-streaming.py

  echo "Test non-streaming"
  ./export-onnx.py
  ls -lh
  ./test-onnx.py

  cat > README.md <<EOF
# Introduction
This model is converted from https://wenet.org.cn/downloads?models=wenet&version=gigaspeech_u2pp_conformer_exp.tar.gz
EOF

  popd
}

install_dependencies

aishell

aishell2

multi_cn

wenetspeech

librispeech

gigaspeech

tree .
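Each dataset function above recreates the directory layout that the released train.yaml refers to (e.g. exp/20210601_u2++_conformer_exp) before copying global_cmvn into it, since the config references the cmvn file by a relative path. If you adapt the script to a different checkpoint, one way to locate the expected path is to scan the config for it (a minimal sketch; it simply collects every string value mentioning global_cmvn):

import yaml


def find_cmvn_paths(config_file: str = "./train.yaml") -> list:
    """Collect every string value in the config that references global_cmvn."""
    with open(config_file) as f:
        configs = yaml.safe_load(f)

    found = []

    def walk(node):
        if isinstance(node, dict):
            for value in node.values():
                walk(value)
        elif isinstance(node, list):
            for item in node:
                walk(item)
        elif isinstance(node, str) and "global_cmvn" in node:
            found.append(node)

    walk(configs)
    return found


print(find_cmvn_paths())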
scripts/wenet/test-onnx-streaming.py
0 → 100755
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

import kaldi_native_fbank as knf
import onnxruntime as ort
import torch
import torchaudio


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        self.left_chunks = int(meta["left_chunks"])
        self.num_blocks = int(meta["num_blocks"])
        self.chunk_size = int(meta["chunk_size"])
        self.head = int(meta["head"])
        self.output_size = int(meta["output_size"])
        self.cnn_module_kernel = int(meta["cnn_module_kernel"])
        self.right_context = int(meta["right_context"])
        self.subsampling_factor = int(meta["subsampling_factor"])

        self._init_cache()

    def _init_cache(self):
        required_cache_size = self.chunk_size * self.left_chunks

        self.attn_cache = torch.zeros(
            self.num_blocks,
            self.head,
            required_cache_size,
            self.output_size // self.head * 2,
            dtype=torch.float32,
        ).numpy()

        self.conv_cache = torch.zeros(
            self.num_blocks,
            1,
            self.output_size,
            self.cnn_module_kernel - 1,
            dtype=torch.float32,
        ).numpy()

        self.offset = torch.tensor([required_cache_size], dtype=torch.int64).numpy()

        self.required_cache_size = torch.tensor(
            [self.chunk_size * self.left_chunks], dtype=torch.int64
        ).numpy()

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (T, C)
        Returns:
          Return a 2-D tensor of shape (T, C) containing log_probs.
        """
        attn_mask = torch.ones(
            1, 1, int(self.required_cache_size + self.chunk_size), dtype=torch.bool
        )
        chunk_idx = self.offset // self.chunk_size - self.left_chunks
        if chunk_idx < self.left_chunks:
            attn_mask[
                :, :, : int(self.required_cache_size - chunk_idx * self.chunk_size)
            ] = False

        log_probs, new_attn_cache, new_conv_cache = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
                self.model.get_outputs()[2].name,
            ],
            {
                self.model.get_inputs()[0].name: x.unsqueeze(0).numpy(),
                self.model.get_inputs()[1].name: self.offset,
                self.model.get_inputs()[2].name: self.required_cache_size,
                self.model.get_inputs()[3].name: self.attn_cache,
                self.model.get_inputs()[4].name: self.conv_cache,
                self.model.get_inputs()[5].name: attn_mask.numpy(),
            },
        )

        self.attn_cache = new_attn_cache
        self.conv_cache = new_conv_cache

        log_probs = torch.from_numpy(log_probs)

        self.offset += log_probs.shape[1]

        return log_probs.squeeze(0)


def get_features(test_wav_filename):
    wave, sample_rate = torchaudio.load(test_wav_filename)
    audio = wave[0].contiguous()  # only use the first channel
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(
            audio, orig_freq=sample_rate, new_freq=16000
        )
    audio *= 32768  # scale float samples in [-1, 1] to the 16-bit range

    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.mel_opts.num_bins = 80
    opts.frame_opts.snip_edges = False
    opts.mel_opts.debug_mel = False

    fbank = knf.OnlineFbank(opts)
    fbank.accept_waveform(16000, audio.numpy())
    frames = []
    for i in range(fbank.num_frames_ready):
        frames.append(torch.from_numpy(fbank.get_frame(i)))
    frames = torch.stack(frames)
    return frames


def main():
    model_filename = "./model-streaming.onnx"
    model = OnnxModel(model_filename)

    filename = "./0.wav"
    x = get_features(filename)

    # pad the tail with zero feature frames so the last chunk is complete
    padding = torch.zeros(int(16000 * 0.5), 80)
    x = torch.cat([x, padding], dim=0)

    chunk_length = (
        (model.chunk_size - 1) * model.subsampling_factor + model.right_context + 1
    )
    chunk_length = int(chunk_length)
    chunk_shift = int(model.required_cache_size)
    print(chunk_length, chunk_shift)

    num_frames = x.shape[0]
    n = (num_frames - chunk_length) // chunk_shift + 1
    tokens = []
    for i in range(n):
        start = i * chunk_shift
        end = start + chunk_length
        frames = x[start:end, :]
        log_probs = model(frames)

        indexes = log_probs.argmax(dim=1)
        indexes = torch.unique_consecutive(indexes)
        indexes = indexes[indexes != 0].tolist()
        if indexes:
            tokens.extend(indexes)

    id2word = dict()
    with open("./units.txt", encoding="utf-8") as f:
        for line in f:
            word, idx = line.strip().split()
            id2word[int(idx)] = word
    text = "".join([id2word[i] for i in tokens])
    print(text)


if __name__ == "__main__":
    main()
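The chunk arithmetic in main() maps the model's output-frame chunk size back to input feature frames: each output frame consumes subsampling_factor input frames, plus the encoder's right context. For example (the concrete values below are assumptions for illustration; subsampling_factor = 4 and right_context = 6 are typical for wenet's Conv2d subsampling, and the real values come from the model metadata):

# Assumed values for illustration only; read them from the model metadata in practice.
chunk_size = 16
left_chunks = 4
subsampling_factor = 4
right_context = 6

chunk_length = (chunk_size - 1) * subsampling_factor + right_context + 1  # 67 input frames
chunk_shift = chunk_size * left_chunks  # 64; equals required_cache_size here
print(chunk_length, chunk_shift)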
scripts/wenet/test-onnx.py
0 → 100755
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

import kaldi_native_fbank as knf
import onnxruntime as ort
import torch
import torchaudio


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C)
          x_lens:
            A 1-D tensor of shape (N,). Its dtype is torch.int64
        Returns:
          Return a 3-D tensor of shape (N, T, C) containing log_probs.
        """
        log_probs, log_probs_lens = self.model.run(
            [self.model.get_outputs()[0].name, self.model.get_outputs()[1].name],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(log_probs), torch.from_numpy(log_probs_lens)


def get_features(test_wav_filename):
    wave, sample_rate = torchaudio.load(test_wav_filename)
    audio = wave[0].contiguous()  # only use the first channel
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(
            audio, orig_freq=sample_rate, new_freq=16000
        )
    audio *= 32768  # scale float samples in [-1, 1] to the 16-bit range

    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.mel_opts.num_bins = 80
    opts.frame_opts.snip_edges = False
    opts.mel_opts.debug_mel = False

    fbank = knf.OnlineFbank(opts)
    fbank.accept_waveform(16000, audio.numpy())
    frames = []
    for i in range(fbank.num_frames_ready):
        frames.append(torch.from_numpy(fbank.get_frame(i)))
    frames = torch.stack(frames)
    return frames


def main():
    model_filename = "./model.onnx"
    model = OnnxModel(model_filename)

    filename = "./0.wav"
    x = get_features(filename)
    x = x.unsqueeze(0)

    # Note: It supports only batch size == 1
    x_lens = torch.tensor([x.shape[1]], dtype=torch.int64)

    print(x.shape, x_lens)

    log_probs, log_probs_lens = model(x, x_lens)
    log_probs = log_probs[0]
    print(log_probs.shape)

    indexes = log_probs.argmax(dim=1)
    print(indexes)
    indexes = torch.unique_consecutive(indexes)
    indexes = indexes[indexes != 0].tolist()

    id2word = dict()
    with open("./units.txt", encoding="utf-8") as f:
        for line in f:
            word, idx = line.strip().split()
            id2word[int(idx)] = word
    text = "".join([id2word[i] for i in indexes])
    print(text)


if __name__ == "__main__":
    main()
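Both test scripts decode with the same greedy CTC rule: take the argmax token per frame, merge consecutive repeats, and drop the blank token (id 0, matching the units.txt convention used above). As a standalone sketch:

import torch


def ctc_greedy_decode(log_probs: torch.Tensor, blank: int = 0) -> list:
    """Greedy CTC decoding over a (T, vocab_size) tensor of log probabilities."""
    indexes = log_probs.argmax(dim=1)            # best token per frame
    indexes = torch.unique_consecutive(indexes)  # merge repeated tokens
    return indexes[indexes != blank].tolist()    # remove blanks


# Call-shape example with random scores:
# tokens = ctc_greedy_decode(torch.randn(100, 5000).log_softmax(dim=-1))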