Fangjun Kuang
Committed by GitHub
  1 +name: export-wenet-to-onnx
  2 +
  3 +on:
  4 +  push:
  5 +    branches:
  6 +      - master
  7 +    paths:
  8 +      - 'scripts/wenet/**'
  9 +      - '.github/workflows/export-wenet-to-onnx.yaml'
 10 +  pull_request:
 11 +    paths:
 12 +      - 'scripts/wenet/**'
 13 +      - '.github/workflows/export-wenet-to-onnx.yaml'
 14 +
 15 +  workflow_dispatch:
 16 +
 17 +concurrency:
 18 +  group: export-wenet-to-onnx-${{ github.ref }}
 19 +  cancel-in-progress: true
 20 +
 21 +jobs:
 22 +  export-wenet-to-onnx:
 23 +    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
 24 +    name: export wenet
 25 +    runs-on: ${{ matrix.os }}
 26 +    strategy:
 27 +      fail-fast: false
 28 +      matrix:
 29 +        os: [ubuntu-latest]
 30 +        python-version: ["3.8"]
 31 +
 32 +    steps:
 33 +      - uses: actions/checkout@v4
 34 +
 35 +      - name: Setup Python ${{ matrix.python-version }}
 36 +        uses: actions/setup-python@v2
 37 +        with:
 38 +          python-version: ${{ matrix.python-version }}
 39 +
 40 +      - name: Run
 41 +        shell: bash
 42 +        run: |
 43 +          sudo apt-get install -y tree sox
 44 +          cd scripts/wenet
 45 +          ./run.sh
 46 +
 47 +      - name: Publish to huggingface (aishell)
 48 +        env:
 49 +          HF_TOKEN: ${{ secrets.HF_TOKEN }}
 50 +        uses: nick-fields/retry@v2
 51 +        with:
 52 +          max_attempts: 20
 53 +          timeout_seconds: 200
 54 +          shell: bash
 55 +          command: |
 56 +            git config --global user.email "csukuangfj@gmail.com"
 57 +            git config --global user.name "Fangjun Kuang"
 58 +
 59 +            rm -rf huggingface
 60 +            export GIT_LFS_SKIP_SMUDGE=1
 61 +
 62 +            git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-aishell huggingface
 63 +            cd huggingface
 64 +            git fetch
 65 +            git pull
 66 +
 67 +            cp -v ../scripts/wenet/aishell_u2pp_conformer_exp/*.onnx .
 68 +            cp -v ../scripts/wenet/aishell_u2pp_conformer_exp/units.txt tokens.txt
 69 +            cp -v ../scripts/wenet/aishell_u2pp_conformer_exp/README.md .
 70 +
 71 +            if [ ! -d test_wavs ]; then
 72 +              mkdir test_wavs
 73 +              cd test_wavs
 74 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/0.wav
 75 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/1.wav
 76 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/8k.wav
 77 +              cd ..
 78 +            fi
 79 +            git lfs track "*.onnx"
 80 +            git add .
 81 +
 82 +            git commit -m "add aishell models"
 83 +            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-aishell main || true
 84 +
 85 +            cd ..
 86 +            rm -rf huggingface
 87 +
 88 +      - name: Publish to huggingface (aishell2)
 89 +        env:
 90 +          HF_TOKEN: ${{ secrets.HF_TOKEN }}
 91 +        uses: nick-fields/retry@v2
 92 +        with:
 93 +          max_attempts: 20
 94 +          timeout_seconds: 200
 95 +          shell: bash
 96 +          command: |
 97 +            git config --global user.email "csukuangfj@gmail.com"
 98 +            git config --global user.name "Fangjun Kuang"
 99 +
100 +            rm -rf huggingface
101 +            export GIT_LFS_SKIP_SMUDGE=1
102 +
103 +            git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-aishell2 huggingface
104 +            cd huggingface
105 +            git fetch
106 +            git pull
107 +
108 +            cp -v ../scripts/wenet/aishell2_u2pp_conformer_exp/*.onnx .
109 +            cp -v ../scripts/wenet/aishell2_u2pp_conformer_exp/units.txt tokens.txt
110 +            cp -v ../scripts/wenet/aishell2_u2pp_conformer_exp/README.md .
111 +
112 +            if [ ! -d test_wavs ]; then
113 +              mkdir test_wavs
114 +              cd test_wavs
115 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/0.wav
116 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/1.wav
117 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/8k.wav
118 +              cd ..
119 +            fi
120 +            git lfs track "*.onnx"
121 +            git add .
122 +
123 +            git commit -m "add aishell2 models"
124 +            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-aishell2 main || true
125 +
126 +            cd ..
127 +            rm -rf huggingface
128 +
129 +      - name: Publish to huggingface (multi_cn)
130 +        env:
131 +          HF_TOKEN: ${{ secrets.HF_TOKEN }}
132 +        uses: nick-fields/retry@v2
133 +        with:
134 +          max_attempts: 20
135 +          timeout_seconds: 200
136 +          shell: bash
137 +          command: |
138 +            git config --global user.email "csukuangfj@gmail.com"
139 +            git config --global user.name "Fangjun Kuang"
140 +
141 +            rm -rf huggingface
142 +            export GIT_LFS_SKIP_SMUDGE=1
143 +
144 +            git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-multi-cn huggingface
145 +            cd huggingface
146 +            git fetch
147 +            git pull
148 +
149 +            cp -v ../scripts/wenet/multi_cn_unified_conformer_exp/*.onnx .
150 +            cp -v ../scripts/wenet/multi_cn_unified_conformer_exp/units.txt tokens.txt
151 +            cp -v ../scripts/wenet/multi_cn_unified_conformer_exp/README.md .
152 +
153 +            if [ ! -d test_wavs ]; then
154 +              mkdir test_wavs
155 +              cd test_wavs
156 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/0.wav
157 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/1.wav
158 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/8k.wav
159 +              cd ..
160 +            fi
161 +            git lfs track "*.onnx"
162 +            git add .
163 +
164 +            git commit -m "add multi_cn models"
165 +            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-multi-cn main || true
166 +
167 +            cd ..
168 +            rm -rf huggingface
169 +
170 +      - name: Publish to huggingface (wenetspeech)
171 +        env:
172 +          HF_TOKEN: ${{ secrets.HF_TOKEN }}
173 +        uses: nick-fields/retry@v2
174 +        with:
175 +          max_attempts: 20
176 +          timeout_seconds: 200
177 +          shell: bash
178 +          command: |
179 +            git config --global user.email "csukuangfj@gmail.com"
180 +            git config --global user.name "Fangjun Kuang"
181 +
182 +            rm -rf huggingface
183 +            export GIT_LFS_SKIP_SMUDGE=1
184 +
185 +            git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech huggingface
186 +            cd huggingface
187 +            git fetch
188 +            git pull
189 +
190 +            cp -v ../scripts/wenet/20220506_u2pp_conformer_exp/*.onnx .
191 +            cp -v ../scripts/wenet/20220506_u2pp_conformer_exp/units.txt tokens.txt
192 +            cp -v ../scripts/wenet/20220506_u2pp_conformer_exp/README.md .
193 +
194 +            if [ ! -d test_wavs ]; then
195 +              mkdir test_wavs
196 +              cd test_wavs
197 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/0.wav
198 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/1.wav
199 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/resolve/main/test_wavs/8k.wav
200 +              cd ..
201 +            fi
202 +            git lfs track "*.onnx"
203 +            git add .
204 +
205 +            git commit -m "add wenetspeech models"
206 +            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech main || true
207 +
208 +            cd ..
209 +            rm -rf huggingface
210 +
211 +      - name: Publish to huggingface (librispeech)
212 +        env:
213 +          HF_TOKEN: ${{ secrets.HF_TOKEN }}
214 +        uses: nick-fields/retry@v2
215 +        with:
216 +          max_attempts: 20
217 +          timeout_seconds: 200
218 +          shell: bash
219 +          command: |
220 +            git config --global user.email "csukuangfj@gmail.com"
221 +            git config --global user.name "Fangjun Kuang"
222 +
223 +            rm -rf huggingface
224 +            export GIT_LFS_SKIP_SMUDGE=1
225 +
226 +            git clone https://huggingface.co/csukuangfj/sherpa-onnx-en-wenet-librispeech huggingface
227 +            cd huggingface
228 +            git fetch
229 +            git pull
230 +
231 +            cp -v ../scripts/wenet/librispeech_u2pp_conformer_exp/*.onnx .
232 +            cp -v ../scripts/wenet/librispeech_u2pp_conformer_exp/units.txt tokens.txt
233 +            cp -v ../scripts/wenet/librispeech_u2pp_conformer_exp/README.md .
234 +
235 +            if [ ! -d test_wavs ]; then
236 +              mkdir test_wavs
237 +              cd test_wavs
238 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/0.wav
239 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/1.wav
240 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/8k.wav
241 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/trans.txt
242 +              cd ..
243 +            fi
244 +            git lfs track "*.onnx"
245 +            git add .
246 +
247 +            git commit -m "add librispeech models"
248 +            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-en-wenet-librispeech main || true
249 +
250 +            cd ..
251 +            rm -rf huggingface
252 +
253 +      - name: Publish to huggingface (gigaspeech)
254 +        env:
255 +          HF_TOKEN: ${{ secrets.HF_TOKEN }}
256 +        uses: nick-fields/retry@v2
257 +        with:
258 +          max_attempts: 20
259 +          timeout_seconds: 200
260 +          shell: bash
261 +          command: |
262 +            git config --global user.email "csukuangfj@gmail.com"
263 +            git config --global user.name "Fangjun Kuang"
264 +
265 +            rm -rf huggingface
266 +            export GIT_LFS_SKIP_SMUDGE=1
267 +
268 +            git clone https://huggingface.co/csukuangfj/sherpa-onnx-en-wenet-gigaspeech huggingface
269 +            cd huggingface
270 +            git fetch
271 +            git pull
272 +
273 +            cp -v ../scripts/wenet/20210728_u2pp_conformer_exp/*.onnx .
274 +            cp -v ../scripts/wenet/20210728_u2pp_conformer_exp/units.txt tokens.txt
275 +            cp -v ../scripts/wenet/20210728_u2pp_conformer_exp/README.md .
276 +
277 +            if [ ! -d test_wavs ]; then
278 +              mkdir test_wavs
279 +              cd test_wavs
280 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/0.wav
281 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/1.wav
282 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/8k.wav
283 +              wget -q https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21/resolve/main/test_wavs/trans.txt
284 +              cd ..
285 +            fi
286 +            git lfs track "*.onnx"
287 +            git add .
288 +
289 +            git commit -m "add gigaspeech models"
290 +            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-en-wenet-gigaspeech main || true
291 +
292 +            cd ..
293 +            rm -rf huggingface
  1 +# Introduction
  2 +
  3 +This folder contains scripts for exporting models
  4 +from [wenet](https://github.com/wenet-e2e/wenet)
  5 +to onnx. You can use the exported models in sherpa-onnx.
  6 +
  7 +Note that both **streaming** and **non-streaming** models are supported.
  8 +
  9 +We only use the CTC branch. Rescoring with the attention decoder
 10 +is not supported, though decoding with H, HL, and HLG is supported.
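
The export scripts below write key/value metadata into each generated ONNX file (see add_meta_data), and the test scripts read it back through onnxruntime. A minimal sketch for inspecting that metadata, assuming an exported model.onnx in the current directory:

#!/usr/bin/env python3
# A minimal sketch, assuming "model.onnx" was produced by the export
# scripts in this commit. It prints the metadata written by add_meta_data().
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
meta = sess.get_modelmeta().custom_metadata_map
for key, value in meta.items():
    print(key, "=", value)
# Expect model_type=wenet-ctc plus version, model_author, comment, and url;
# streaming models additionally carry chunk_size, left_chunks, head,
# num_blocks, output_size, cnn_module_kernel, right_context, and
# subsampling_factor.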
  1 +#!/usr/bin/env python3
  2 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +# pip install git+https://github.com/wenet-e2e/wenet.git
  5 +# pip install onnxruntime onnx pyyaml
  6 +# cp -a ~/open-source/wenet/wenet/transducer/search .
  7 +# cp -a ~/open-source/wenet/wenet/e_branchformer .
  8 +# cp -a ~/open-source/wenet/wenet/ctl_model .
  9 +
 10 +import os
 11 +from typing import Dict
 12 +
 13 +import onnx
 14 +import torch
 15 +import yaml
 16 +from onnxruntime.quantization import QuantType, quantize_dynamic
 17 +
 18 +from wenet.utils.init_model import init_model
 19 +
 20 +
 21 +def add_meta_data(filename: str, meta_data: Dict[str, str]):
 22 +    """Add meta data to an ONNX model. It is changed in-place.
 23 +
 24 +    Args:
 25 +      filename:
 26 +        Filename of the ONNX model to be changed.
 27 +      meta_data:
 28 +        Key-value pairs.
 29 +    """
 30 +    model = onnx.load(filename)
 31 +    for key, value in meta_data.items():
 32 +        meta = model.metadata_props.add()
 33 +        meta.key = key
 34 +        meta.value = str(value)
 35 +
 36 +    onnx.save(model, filename)
 37 +
 38 +
 39 +class OnnxModel(torch.nn.Module):
 40 +    def __init__(self, encoder: torch.nn.Module, ctc: torch.nn.Module):
 41 +        super().__init__()
 42 +        self.encoder = encoder
 43 +        self.ctc = ctc
 44 +
 45 +    def forward(
 46 +        self,
 47 +        x: torch.Tensor,
 48 +        offset: torch.Tensor,
 49 +        required_cache_size: torch.Tensor,
 50 +        attn_cache: torch.Tensor,
 51 +        conv_cache: torch.Tensor,
 52 +        attn_mask: torch.Tensor,
 53 +    ):
 54 +        """
 55 +        Args:
 56 +          x:
 57 +            A 3-D float32 tensor of shape (N, T, C). It supports only N == 1.
 58 +          offset:
 59 +            A scalar of dtype torch.int64.
 60 +          required_cache_size:
 61 +            A scalar of dtype torch.int64.
 62 +          attn_cache:
 63 +            A 4-D float32 tensor of shape (num_blocks, head, required_cache_size, encoder_output_size / head * 2).
 64 +          conv_cache:
 65 +            A 4-D float32 tensor of shape (num_blocks, N, encoder_output_size, cnn_module_kernel - 1).
 66 +          attn_mask:
 67 +            A 3-D bool tensor of shape (N, 1, required_cache_size + chunk_size).
 68 +        Returns:
 69 +          Return a tuple of 3 tensors:
 70 +            - A 3-D float32 tensor of shape (N, T, C) containing log_probs
 71 +            - next_att_cache
 72 +            - next_conv_cache
 73 +        """
 74 +        encoder_out, next_att_cache, next_conv_cache = self.encoder.forward_chunk(
 75 +            xs=x,
 76 +            offset=offset,
 77 +            required_cache_size=required_cache_size,
 78 +            att_cache=attn_cache,
 79 +            cnn_cache=conv_cache,
 80 +            att_mask=attn_mask,
 81 +        )
 82 +        log_probs = self.ctc.log_softmax(encoder_out)
 83 +
 84 +        return log_probs, next_att_cache, next_conv_cache
 85 +
 86 +
 87 +class Foo:
 88 +    pass
 89 +
 90 +
 91 +@torch.no_grad()
 92 +def main():
 93 +    args = Foo()
 94 +    args.checkpoint = "./final.pt"
 95 +    config_file = "./train.yaml"
 96 +
 97 +    with open(config_file, "r") as fin:
 98 +        configs = yaml.load(fin, Loader=yaml.FullLoader)
 99 +    torch_model, configs = init_model(args, configs)
100 +    torch_model.eval()
101 +
102 +    head = configs["encoder_conf"]["attention_heads"]
103 +    num_blocks = configs["encoder_conf"]["num_blocks"]
104 +    output_size = configs["encoder_conf"]["output_size"]
105 +    cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1)
106 +
107 +    right_context = torch_model.right_context()
108 +    subsampling_factor = torch_model.encoder.embed.subsampling_rate
109 +    chunk_size = 16
110 +    left_chunks = 4
111 +
112 +    decoding_window = (chunk_size - 1) * subsampling_factor + right_context + 1
113 +
114 +    required_cache_size = chunk_size * left_chunks
115 +
116 +    offset = required_cache_size
117 +
118 +    attn_cache = torch.zeros(
119 +        num_blocks,
120 +        head,
121 +        required_cache_size,
122 +        output_size // head * 2,
123 +        dtype=torch.float32,
124 +    )
125 +
126 +    attn_mask = torch.ones(1, 1, required_cache_size + chunk_size, dtype=torch.bool)
127 +    attn_mask[:, :, :required_cache_size] = 0
128 +
129 +    conv_cache = torch.zeros(
130 +        num_blocks, 1, output_size, cnn_module_kernel - 1, dtype=torch.float32
131 +    )
132 +
133 +    sos = torch_model.sos_symbol()
134 +    eos = torch_model.eos_symbol()
135 +
136 +    onnx_model = OnnxModel(
137 +        encoder=torch_model.encoder,
138 +        ctc=torch_model.ctc,
139 +    )
140 +    filename = "model-streaming.onnx"
141 +
142 +    N = 1
143 +    T = decoding_window
144 +    C = 80
145 +    x = torch.rand(N, T, C, dtype=torch.float32)
146 +    offset = torch.tensor([offset], dtype=torch.int64)
147 +    required_cache_size = torch.tensor([required_cache_size], dtype=torch.int64)
148 +
149 +    opset_version = 13
150 +    torch.onnx.export(
151 +        onnx_model,
152 +        (x, offset, required_cache_size, attn_cache, conv_cache, attn_mask),
153 +        filename,
154 +        opset_version=opset_version,
155 +        input_names=[
156 +            "x",
157 +            "offset",
158 +            "required_cache_size",
159 +            "attn_cache",
160 +            "conv_cache",
161 +            "attn_mask",
162 +        ],
163 +        output_names=["log_probs", "next_att_cache", "next_conv_cache"],
164 +        dynamic_axes={
165 +            "x": {0: "N", 1: "T"},
166 +            "attn_cache": {2: "T"},
167 +            "log_probs": {0: "N"},
168 +            "next_att_cache": {2: "T"},
169 +        },
170 +    )
171 +
172 +    # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
173 +    url = os.environ.get("WENET_URL", "")
174 +    meta_data = {
175 +        "model_type": "wenet-ctc",
176 +        "version": "1",
177 +        "model_author": "wenet",
178 +        "comment": "streaming",
179 +        "url": url,
180 +        "chunk_size": chunk_size,
181 +        "left_chunks": left_chunks,
182 +        "head": head,
183 +        "num_blocks": num_blocks,
184 +        "output_size": output_size,
185 +        "cnn_module_kernel": cnn_module_kernel,
186 +        "right_context": right_context,
187 +        "subsampling_factor": subsampling_factor,
188 +    }
189 +    add_meta_data(filename=filename, meta_data=meta_data)
190 +
191 +    print("Generate int8 quantization models")
192 +
193 +    filename_int8 = "model-streaming.int8.onnx"
194 +    quantize_dynamic(
195 +        model_input=filename,
196 +        model_output=filename_int8,
197 +        op_types_to_quantize=["MatMul"],
198 +        weight_type=QuantType.QInt8,
199 +    )
200 +
201 +
202 +if __name__ == "__main__":
203 +    main()
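
Because the streaming graph threads attn_cache and conv_cache through every call, a cache-shape mismatch is the most common failure right after export. A minimal sanity check, assuming model-streaming.onnx was produced by the script above:

#!/usr/bin/env python3
# A minimal sketch, assuming "model-streaming.onnx" from the export script
# above. It lists each input/output name, shape, and type so a wrong cache
# layout shows up before any decoding is attempted.
import onnxruntime as ort

sess = ort.InferenceSession(
    "model-streaming.onnx", providers=["CPUExecutionProvider"]
)
for i in sess.get_inputs():
    print("input: ", i.name, i.shape, i.type)
for o in sess.get_outputs():
    print("output:", o.name, o.shape, o.type)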
  1 +#!/usr/bin/env python3
  2 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +# pip install git+https://github.com/wenet-e2e/wenet.git
  5 +# pip install onnxruntime onnx pyyaml
  6 +# cp -a ~/open-source/wenet/wenet/transducer/search .
  7 +# cp -a ~/open-source/wenet/wenet/e_branchformer .
  8 +# cp -a ~/open-source/wenet/wenet/ctl_model .
  9 +
 10 +import os
 11 +from typing import Dict
 12 +
 13 +import onnx
 14 +import torch
 15 +import yaml
 16 +from onnxruntime.quantization import QuantType, quantize_dynamic
 17 +
 18 +from wenet.utils.init_model import init_model
 19 +
 20 +
 21 +class Foo:
 22 +    pass
 23 +
 24 +
 25 +def add_meta_data(filename: str, meta_data: Dict[str, str]):
 26 +    """Add meta data to an ONNX model. It is changed in-place.
 27 +
 28 +    Args:
 29 +      filename:
 30 +        Filename of the ONNX model to be changed.
 31 +      meta_data:
 32 +        Key-value pairs.
 33 +    """
 34 +    model = onnx.load(filename)
 35 +    for key, value in meta_data.items():
 36 +        meta = model.metadata_props.add()
 37 +        meta.key = key
 38 +        meta.value = str(value)
 39 +
 40 +    onnx.save(model, filename)
 41 +
 42 +
 43 +class OnnxModel(torch.nn.Module):
 44 +    def __init__(self, encoder: torch.nn.Module, ctc: torch.nn.Module):
 45 +        super().__init__()
 46 +        self.encoder = encoder
 47 +        self.ctc = ctc
 48 +
 49 +    def forward(self, x, x_lens):
 50 +        """
 51 +        Args:
 52 +          x:
 53 +            A 3-D tensor of shape (N, T, C)
 54 +          x_lens:
 55 +            A 1-D tensor of shape (N,) containing valid lengths in x before
 56 +            padding. Its type is torch.int64
 57 +        """
 58 +        encoder_out, encoder_out_mask = self.encoder(
 59 +            x,
 60 +            x_lens,
 61 +            decoding_chunk_size=-1,
 62 +            num_decoding_left_chunks=-1,
 63 +        )
 64 +        log_probs = self.ctc.log_softmax(encoder_out)
 65 +        log_probs_lens = encoder_out_mask.int().squeeze(1).sum(1)
 66 +
 67 +        return log_probs, log_probs_lens
 68 +
 69 +
 70 +@torch.no_grad()
 71 +def main():
 72 +    args = Foo()
 73 +    args.checkpoint = "./final.pt"
 74 +    config_file = "./train.yaml"
 75 +
 76 +    with open(config_file, "r") as fin:
 77 +        configs = yaml.load(fin, Loader=yaml.FullLoader)
 78 +    torch_model, configs = init_model(args, configs)
 79 +    torch_model.eval()
 80 +
 81 +    onnx_model = OnnxModel(encoder=torch_model.encoder, ctc=torch_model.ctc)
 82 +    filename = "model.onnx"
 83 +
 84 +    N = 1
 85 +    T = 1000
 86 +    C = 80
 87 +    x = torch.rand(N, T, C, dtype=torch.float)
 88 +    x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)
 89 +
 90 +    opset_version = 13
 91 +    onnx_model = torch.jit.script(onnx_model)
 92 +    torch.onnx.export(
 93 +        onnx_model,
 94 +        (x, x_lens),
 95 +        filename,
 96 +        opset_version=opset_version,
 97 +        input_names=["x", "x_lens"],
 98 +        output_names=["log_probs", "log_probs_lens"],
 99 +        dynamic_axes={
100 +            "x": {0: "N", 1: "T"},
101 +            "x_lens": {0: "N"},
102 +            "log_probs": {0: "N", 1: "T"},
103 +            "log_probs_lens": {0: "N"},
104 +        },
105 +    )
106 +
107 +    # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
108 +    url = os.environ.get("WENET_URL", "")
109 +    meta_data = {
110 +        "model_type": "wenet-ctc",
111 +        "version": "1",
112 +        "model_author": "wenet",
113 +        "comment": "non-streaming",
114 +        "url": url,
115 +    }
116 +    add_meta_data(filename=filename, meta_data=meta_data)
117 +
118 +    print("Generate int8 quantization models")
119 +
120 +    filename_int8 = "model.int8.onnx"
121 +    quantize_dynamic(
122 +        model_input=filename,
123 +        model_output=filename_int8,
124 +        op_types_to_quantize=["MatMul"],
125 +        weight_type=QuantType.QInt8,
126 +    )
127 +
128 +
129 +if __name__ == "__main__":
130 +    main()
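
quantize_dynamic() above only rewrites MatMul weights to int8, so the quantized model should stay close to the float32 one on identical input. A minimal comparison sketch, assuming model.onnx and model.int8.onnx from the script above (the input names "x" and "x_lens" come from the export call):

#!/usr/bin/env python3
# A minimal sketch, assuming "model.onnx" and "model.int8.onnx" were
# produced by the export script above. It feeds the same random features
# to both models and reports the largest difference in log_probs.
import numpy as np
import onnxruntime as ort

x = np.random.rand(1, 1000, 80).astype(np.float32)
x_lens = np.array([1000], dtype=np.int64)

outs = []
for name in ["model.onnx", "model.int8.onnx"]:
    sess = ort.InferenceSession(name, providers=["CPUExecutionProvider"])
    log_probs, _ = sess.run(None, {"x": x, "x_lens": x_lens})
    outs.append(log_probs)

print("max abs diff:", np.abs(outs[0] - outs[1]).max())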
  1 +#!/usr/bin/env bash
  2 +#
  3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
  4 +#
  5 +# Please refer to
  6 +# https://github.com/wenet-e2e/wenet/blob/main/docs/pretrained_models.en.md
  7 +# for a table of pre-trained models.
  8 +# Please select the column "Checkpoint Model" for downloading.
  9 +
 10 +set -ex
 11 +
 12 +function install_dependencies() {
 13 +  pip install soundfile
 14 +  pip install torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
 15 +  pip install k2==1.24.4.dev20231022+cpu.torch2.1.0 -f https://k2-fsa.github.io/k2/cpu.html
 16 +
 17 +  pip install onnxruntime onnx kaldi-native-fbank pyyaml
 18 +
 19 +  pip install git+https://github.com/wenet-e2e/wenet.git
 20 +  wenet_dir=$(dirname $(python3 -c "import wenet; print(wenet.__file__)"))
 21 +  git clone https://github.com/wenet-e2e/wenet
 22 +  if [ ! -d $wenet_dir/transducer/search ]; then
 23 +    cp -av ./wenet/wenet/transducer/search $wenet_dir/transducer
 24 +  fi
 25 +
 26 +  if [ ! -d $wenet_dir/e_branchformer ]; then
 27 +    cp -a ./wenet/wenet/e_branchformer $wenet_dir
 28 +  fi
 29 +
 30 +  if [ ! -d $wenet_dir/ctl_model ]; then
 31 +    cp -a ./wenet/wenet/ctl_model $wenet_dir
 32 +  fi
 33 +
 34 +  rm -rf wenet
 35 +}
 36 +
 37 +function aishell() {
 38 +  echo "aishell"
 39 +  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/aishell_u2pp_conformer_exp.tar.gz
 40 +  tar xvf aishell_u2pp_conformer_exp.tar.gz
 41 +  rm -v aishell_u2pp_conformer_exp.tar.gz
 42 +
 43 +  pushd aishell_u2pp_conformer_exp
 44 +  mkdir -p exp/20210601_u2++_conformer_exp
 45 +  cp global_cmvn ./exp/20210601_u2++_conformer_exp
 46 +  cp ../*.py .
 47 +
 48 +  export WENET_URL='https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz'
 49 +  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/zh.wav
 50 +  soxi 0.wav
 51 +
 52 +  echo "Test streaming"
 53 +  ./export-onnx-streaming.py
 54 +  ls -lh
 55 +  ./test-onnx-streaming.py
 56 +
 57 +  echo "Test non-streaming"
 58 +  ./export-onnx.py
 59 +  ls -lh
 60 +  ./test-onnx.py
 61 +
 62 +  cat > README.md <<EOF
 63 +# Introduction
 64 +This model is converted from https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
 65 +EOF
 66 +
 67 +  popd
 68 +}
 69 +
 70 +function aishell2() {
 71 +  echo "aishell2"
 72 +  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/aishell2_u2pp_conformer_exp.tar.gz
 73 +  tar xvf aishell2_u2pp_conformer_exp.tar.gz
 74 +  rm -v aishell2_u2pp_conformer_exp.tar.gz
 75 +
 76 +  pushd aishell2_u2pp_conformer_exp
 77 +  mkdir -p exp/u2++_conformer
 78 +  cp global_cmvn ./exp/u2++_conformer
 79 +  cp ../*.py .
 80 +
 81 +  export WENET_URL='https://wenet.org.cn/downloads?models=wenet&version=aishell2_u2pp_conformer_exp.tar.gz'
 82 +  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/zh.wav
 83 +  soxi 0.wav
 84 +
 85 +  echo "Test streaming"
 86 +  ./export-onnx-streaming.py
 87 +  ls -lh
 88 +  ./test-onnx-streaming.py
 89 +
 90 +  echo "Test non-streaming"
 91 +  ./export-onnx.py
 92 +  ls -lh
 93 +  ./test-onnx.py
 94 +
 95 +  cat > README.md <<EOF
 96 +# Introduction
 97 +This model is converted from https://wenet.org.cn/downloads?models=wenet&version=aishell2_u2pp_conformer_exp.tar.gz
 98 +EOF
 99 +
100 +  popd
101 +}
102 +
103 +function multi_cn() {
104 +  echo "multi_cn"
105 +  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/multi_cn_unified_conformer_exp.tar.gz
106 +  tar xvf multi_cn_unified_conformer_exp.tar.gz
107 +  rm -v multi_cn_unified_conformer_exp.tar.gz
108 +
109 +  pushd multi_cn_unified_conformer_exp
110 +  mkdir -p exp/20210815_unified_conformer_exp
111 +  cp global_cmvn ./exp/20210815_unified_conformer_exp
112 +  cp ../*.py .
113 +
114 +  export WENET_URL='https://wenet.org.cn/downloads?models=wenet&version=multi_cn_unified_conformer_exp.tar.gz'
115 +  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/zh.wav
116 +  soxi 0.wav
117 +
118 +  echo "Test streaming"
119 +  ./export-onnx-streaming.py
120 +  ls -lh
121 +  ./test-onnx-streaming.py
122 +
123 +  echo "Test non-streaming"
124 +  ./export-onnx.py
125 +  ls -lh
126 +  ./test-onnx.py
127 +
128 +  cat > README.md <<EOF
129 +# Introduction
130 +This model is converted from https://wenet.org.cn/downloads?models=wenet&version=multi_cn_unified_conformer_exp.tar.gz
131 +EOF
132 +
133 +  popd
134 +}
135 +
136 +function wenetspeech() {
137 +  echo "wenetspeech"
138 +  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/wenetspeech_u2pp_conformer_exp.tar.gz
139 +  tar xvf wenetspeech_u2pp_conformer_exp.tar.gz
140 +  rm -v wenetspeech_u2pp_conformer_exp.tar.gz
141 +
142 +  pushd 20220506_u2pp_conformer_exp
143 +  mkdir -p exp/20220506_u2pp_conformer_exp
144 +  cp global_cmvn ./exp/20220506_u2pp_conformer_exp
145 +  cp ../*.py .
146 +
147 +  export WENET_URL='https://wenet.org.cn/downloads?models=wenet&version=wenetspeech_u2pp_conformer_exp.tar.gz'
148 +  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/zh.wav
149 +  soxi 0.wav
150 +
151 +  echo "Test streaming"
152 +  ./export-onnx-streaming.py
153 +  ls -lh
154 +  ./test-onnx-streaming.py
155 +
156 +  echo "Test non-streaming"
157 +  ./export-onnx.py
158 +  ls -lh
159 +  ./test-onnx.py
160 +
161 +  cat > README.md <<EOF
162 +# Introduction
163 +This model is converted from https://wenet.org.cn/downloads?models=wenet&version=wenetspeech_u2pp_conformer_exp.tar.gz
164 +EOF
165 +
166 +  popd
167 +}
168 +
169 +function librispeech() {
170 +  echo "librispeech"
171 +  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/librispeech_u2pp_conformer_exp.tar.gz
172 +  tar xvf librispeech_u2pp_conformer_exp.tar.gz
173 +  rm -v librispeech_u2pp_conformer_exp.tar.gz
174 +
175 +  pushd librispeech_u2pp_conformer_exp
176 +  mkdir -p data/train_960
177 +  cp global_cmvn ./data/train_960
178 +  cp ../*.py .
179 +
180 +  export WENET_URL='https://wenet.org.cn/downloads?models=wenet&version=librispeech_u2pp_conformer_exp.tar.gz'
181 +  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/en.wav
182 +  soxi 0.wav
183 +
184 +  echo "Test streaming"
185 +  ./export-onnx-streaming.py
186 +  ls -lh
187 +  ./test-onnx-streaming.py
188 +
189 +  echo "Test non-streaming"
190 +  ./export-onnx.py
191 +  ls -lh
192 +  ./test-onnx.py
193 +
194 +  cat > README.md <<EOF
195 +# Introduction
196 +This model is converted from https://wenet.org.cn/downloads?models=wenet&version=librispeech_u2pp_conformer_exp.tar.gz
197 +EOF
198 +
199 +  popd
200 +}
201 +
202 +function gigaspeech() {
203 +  echo "gigaspeech"
204 +  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/gigaspeech_u2pp_conformer_exp.tar.gz
205 +  tar xvf gigaspeech_u2pp_conformer_exp.tar.gz
206 +  rm -v gigaspeech_u2pp_conformer_exp.tar.gz
207 +
208 +  pushd 20210728_u2pp_conformer_exp
209 +  mkdir -p data/gigaspeech_train_xl
210 +  cp global_cmvn ./data/gigaspeech_train_xl
211 +  cp ../*.py .
212 +
213 +  export WENET_URL='https://wenet.org.cn/downloads?models=wenet&version=gigaspeech_u2pp_conformer_exp.tar.gz'
214 +  wget -O 0.wav https://huggingface.co/openspeech/wenet-models/resolve/main/en.wav
215 +  soxi 0.wav
216 +
217 +  echo "Test streaming"
218 +  ./export-onnx-streaming.py
219 +  ls -lh
220 +  ./test-onnx-streaming.py
221 +
222 +  echo "Test non-streaming"
223 +  ./export-onnx.py
224 +  ls -lh
225 +  ./test-onnx.py
226 +
227 +  cat > README.md <<EOF
228 +# Introduction
229 +This model is converted from https://wenet.org.cn/downloads?models=wenet&version=gigaspeech_u2pp_conformer_exp.tar.gz
230 +EOF
231 +
232 +  popd
233 +}
234 +
235 +install_dependencies
236 +
237 +aishell
238 +
239 +aishell2
240 +
241 +multi_cn
242 +
243 +wenetspeech
244 +
245 +librispeech
246 +
247 +gigaspeech
248 +
249 +tree .
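
Every recipe above follows the same pattern: download the "Checkpoint Model" tarball, copy global_cmvn to the directory the bundled train.yaml apparently refers to, point WENET_URL at the download page, and run the export and test scripts. A hedged sketch for adapting it to another checkpoint; MY_MODEL and the cmvn destination are placeholders that must be read off the wenet pretrained-models table:

function my_model() {
  # Hypothetical example; replace MY_MODEL_exp and the global_cmvn
  # destination with the values for the checkpoint you actually download.
  wget -q https://huggingface.co/openspeech/wenet-models/resolve/main/MY_MODEL_exp.tar.gz
  tar xvf MY_MODEL_exp.tar.gz
  rm -v MY_MODEL_exp.tar.gz

  pushd MY_MODEL_exp
  mkdir -p exp/MY_MODEL_exp
  cp global_cmvn ./exp/MY_MODEL_exp
  cp ../*.py .

  export WENET_URL='https://wenet.org.cn/downloads?models=wenet&version=MY_MODEL_exp.tar.gz'
  ./export-onnx-streaming.py
  ./export-onnx.py
  popd
}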
  1 +#!/usr/bin/env python3
  2 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +import kaldi_native_fbank as knf
  5 +import onnxruntime as ort
  6 +import torch
  7 +import torchaudio
  8 +from torch.nn.utils.rnn import pad_sequence
  9 +
 10 +
 11 +class OnnxModel:
 12 +    def __init__(
 13 +        self,
 14 +        filename: str,
 15 +    ):
 16 +        session_opts = ort.SessionOptions()
 17 +        session_opts.inter_op_num_threads = 1
 18 +        session_opts.intra_op_num_threads = 4
 19 +
 20 +        self.session_opts = session_opts
 21 +
 22 +        self.model = ort.InferenceSession(
 23 +            filename,
 24 +            sess_options=self.session_opts,
 25 +            providers=["CPUExecutionProvider"],
 26 +        )
 27 +
 28 +        meta = self.model.get_modelmeta().custom_metadata_map
 29 +        self.left_chunks = int(meta["left_chunks"])
 30 +        self.num_blocks = int(meta["num_blocks"])
 31 +        self.chunk_size = int(meta["chunk_size"])
 32 +        self.head = int(meta["head"])
 33 +        self.output_size = int(meta["output_size"])
 34 +        self.cnn_module_kernel = int(meta["cnn_module_kernel"])
 35 +        self.right_context = int(meta["right_context"])
 36 +        self.subsampling_factor = int(meta["subsampling_factor"])
 37 +
 38 +        self._init_cache()
 39 +
 40 +    def _init_cache(self):
 41 +        required_cache_size = self.chunk_size * self.left_chunks
 42 +
 43 +        self.attn_cache = torch.zeros(
 44 +            self.num_blocks,
 45 +            self.head,
 46 +            required_cache_size,
 47 +            self.output_size // self.head * 2,
 48 +            dtype=torch.float32,
 49 +        ).numpy()
 50 +
 51 +        self.conv_cache = torch.zeros(
 52 +            self.num_blocks,
 53 +            1,
 54 +            self.output_size,
 55 +            self.cnn_module_kernel - 1,
 56 +            dtype=torch.float32,
 57 +        ).numpy()
 58 +
 59 +        self.offset = torch.tensor([required_cache_size], dtype=torch.int64).numpy()
 60 +
 61 +        self.required_cache_size = torch.tensor(
 62 +            [self.chunk_size * self.left_chunks], dtype=torch.int64
 63 +        ).numpy()
 64 +
 65 +    def __call__(self, x: torch.Tensor) -> torch.Tensor:
 66 +        """
 67 +        Args:
 68 +          x:
 69 +            A 2-D tensor of shape (T, C)
 70 +        Returns:
 71 +          Return a 2-D tensor of shape (T, C) containing log_probs.
 72 +        """
 73 +        attn_mask = torch.ones(
 74 +            1, 1, int(self.required_cache_size + self.chunk_size), dtype=torch.bool
 75 +        )
 76 +        chunk_idx = self.offset // self.chunk_size - self.left_chunks
 77 +        if chunk_idx < self.left_chunks:
 78 +            attn_mask[
 79 +                :, :, : int(self.required_cache_size - chunk_idx * self.chunk_size)
 80 +            ] = False
 81 +
 82 +        log_probs, new_attn_cache, new_conv_cache = self.model.run(
 83 +            [
 84 +                self.model.get_outputs()[0].name,
 85 +                self.model.get_outputs()[1].name,
 86 +                self.model.get_outputs()[2].name,
 87 +            ],
 88 +            {
 89 +                self.model.get_inputs()[0].name: x.unsqueeze(0).numpy(),
 90 +                self.model.get_inputs()[1].name: self.offset,
 91 +                self.model.get_inputs()[2].name: self.required_cache_size,
 92 +                self.model.get_inputs()[3].name: self.attn_cache,
 93 +                self.model.get_inputs()[4].name: self.conv_cache,
 94 +                self.model.get_inputs()[5].name: attn_mask.numpy(),
 95 +            },
 96 +        )
 97 +
 98 +        self.attn_cache = new_attn_cache
 99 +        self.conv_cache = new_conv_cache
100 +
101 +        log_probs = torch.from_numpy(log_probs)
102 +
103 +        self.offset += log_probs.shape[1]
104 +
105 +        return log_probs.squeeze(0)
106 +
107 +
108 +def get_features(test_wav_filename):
109 +    wave, sample_rate = torchaudio.load(test_wav_filename)
110 +    audio = wave[0].contiguous()  # only use the first channel
111 +    if sample_rate != 16000:
112 +        audio = torchaudio.functional.resample(
113 +            audio, orig_freq=sample_rate, new_freq=16000
114 +        )
115 +    audio *= 32768  # scale to the 16-bit sample range expected by kaldi fbank
116 +
117 +    opts = knf.FbankOptions()
118 +    opts.frame_opts.dither = 0
119 +    opts.mel_opts.num_bins = 80
120 +    opts.frame_opts.snip_edges = False
121 +    opts.mel_opts.debug_mel = False
122 +
123 +    fbank = knf.OnlineFbank(opts)
124 +    fbank.accept_waveform(16000, audio.numpy())
125 +    frames = []
126 +    for i in range(fbank.num_frames_ready):
127 +        frames.append(torch.from_numpy(fbank.get_frame(i)))
128 +    frames = torch.stack(frames)
129 +    return frames
130 +
131 +
132 +def main():
133 +    model_filename = "./model-streaming.onnx"
134 +    model = OnnxModel(model_filename)
135 +
136 +    filename = "./0.wav"
137 +    x = get_features(filename)
138 +
139 +    padding = torch.zeros(int(16000 * 0.5), 80)
140 +    x = torch.cat([x, padding], dim=0)
141 +
142 +    chunk_length = (
143 +        (model.chunk_size - 1) * model.subsampling_factor + model.right_context + 1
144 +    )
145 +    chunk_length = int(chunk_length)
146 +    chunk_shift = int(model.required_cache_size)
147 +    print(chunk_length, chunk_shift)
148 +
149 +    num_frames = x.shape[0]
150 +    n = (num_frames - chunk_length) // chunk_shift + 1
151 +    tokens = []
152 +    for i in range(n):
153 +        start = i * chunk_shift
154 +        end = start + chunk_length
155 +        frames = x[start:end, :]
156 +        log_probs = model(frames)
157 +
158 +        indexes = log_probs.argmax(dim=1)
159 +        indexes = torch.unique_consecutive(indexes)
160 +        indexes = indexes[indexes != 0].tolist()
161 +        if indexes:
162 +            tokens.extend(indexes)
163 +
164 +    id2word = dict()
165 +    with open("./units.txt", encoding="utf-8") as f:
166 +        for line in f:
167 +            word, idx = line.strip().split()
168 +            id2word[int(idx)] = word
169 +    text = "".join([id2word[i] for i in tokens])
170 +    print(text)
171 +
172 +
173 +if __name__ == "__main__":
174 +    main()
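
The chunk arithmetic in main() is easier to see with concrete numbers. A worked example, assuming chunk_size=16 and left_chunks=4 as baked in by the export script, plus subsampling_factor=4 and right_context=6, which are typical for a wenet conformer but depend on the checkpoint (the real values come from the model metadata):

# Worked example of the chunking above; subsampling_factor and
# right_context are assumed values for a typical conformer checkpoint.
chunk_size = 16
left_chunks = 4
subsampling_factor = 4
right_context = 6

# Feature frames consumed per call and the hop between calls:
chunk_length = (chunk_size - 1) * subsampling_factor + right_context + 1  # 67
chunk_shift = chunk_size * left_chunks  # 64, i.e. required_cache_size

# Each call therefore reads 67 frames, advances by 64 frames, and yields
# chunk_size (= 16) subsampled frames of log_probs, since 64 / 4 == 16.
print(chunk_length, chunk_shift)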
  1 +#!/usr/bin/env python3
  2 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +import kaldi_native_fbank as knf
  5 +import onnxruntime as ort
  6 +import torch
  7 +import torchaudio
  8 +from torch.nn.utils.rnn import pad_sequence
  9 +
 10 +
 11 +class OnnxModel:
 12 +    def __init__(
 13 +        self,
 14 +        filename: str,
 15 +    ):
 16 +        session_opts = ort.SessionOptions()
 17 +        session_opts.inter_op_num_threads = 1
 18 +        session_opts.intra_op_num_threads = 4
 19 +
 20 +        self.session_opts = session_opts
 21 +
 22 +        self.model = ort.InferenceSession(
 23 +            filename,
 24 +            sess_options=self.session_opts,
 25 +            providers=["CPUExecutionProvider"],
 26 +        )
 27 +
 28 +    def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
 29 +        """
 30 +        Args:
 31 +          x:
 32 +            A 3-D tensor of shape (N, T, C)
 33 +          x_lens:
 34 +            A 1-D tensor of shape (N,). Its dtype is torch.int64
 35 +        Returns:
 36 +          Return a 3-D tensor of shape (N, T, C) containing log_probs.
 37 +        """
 38 +        log_probs, log_probs_lens = self.model.run(
 39 +            [self.model.get_outputs()[0].name, self.model.get_outputs()[1].name],
 40 +            {
 41 +                self.model.get_inputs()[0].name: x.numpy(),
 42 +                self.model.get_inputs()[1].name: x_lens.numpy(),
 43 +            },
 44 +        )
 45 +        return torch.from_numpy(log_probs), torch.from_numpy(log_probs_lens)
 46 +
 47 +
 48 +def get_features(test_wav_filename):
 49 +    wave, sample_rate = torchaudio.load(test_wav_filename)
 50 +    audio = wave[0].contiguous()  # only use the first channel
 51 +    if sample_rate != 16000:
 52 +        audio = torchaudio.functional.resample(
 53 +            audio, orig_freq=sample_rate, new_freq=16000
 54 +        )
 55 +    audio *= 32768  # scale to the 16-bit sample range expected by kaldi fbank
 56 +
 57 +    opts = knf.FbankOptions()
 58 +    opts.frame_opts.dither = 0
 59 +    opts.mel_opts.num_bins = 80
 60 +    opts.frame_opts.snip_edges = False
 61 +    opts.mel_opts.debug_mel = False
 62 +
 63 +    fbank = knf.OnlineFbank(opts)
 64 +    fbank.accept_waveform(16000, audio.numpy())
 65 +    frames = []
 66 +    for i in range(fbank.num_frames_ready):
 67 +        frames.append(torch.from_numpy(fbank.get_frame(i)))
 68 +    frames = torch.stack(frames)
 69 +    return frames
 70 +
 71 +
 72 +def main():
 73 +    model_filename = "./model.onnx"
 74 +    model = OnnxModel(model_filename)
 75 +
 76 +    filename = "./0.wav"
 77 +    x = get_features(filename)
 78 +    x = x.unsqueeze(0)
 79 +
 80 +    # Note: It supports only batch size == 1
 81 +    x_lens = torch.tensor([x.shape[1]], dtype=torch.int64)
 82 +
 83 +    print(x.shape, x_lens)
 84 +
 85 +    log_probs, log_probs_lens = model(x, x_lens)
 86 +    log_probs = log_probs[0]
 87 +    print(log_probs.shape)
 88 +
 89 +    indexes = log_probs.argmax(dim=1)
 90 +    print(indexes)
 91 +    indexes = torch.unique_consecutive(indexes)
 92 +    indexes = indexes[indexes != 0].tolist()
 93 +
 94 +    id2word = dict()
 95 +    with open("./units.txt", encoding="utf-8") as f:
 96 +        for line in f:
 97 +            word, idx = line.strip().split()
 98 +            id2word[int(idx)] = word
 99 +    text = "".join([id2word[i] for i in indexes])
100 +    print(text)
101 +
102 +
103 +if __name__ == "__main__":
104 +    main()
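
The comment in main() pins the batch size to 1, while the dynamic axes declared in export-onnx.py suggest the model may also accept padded batches; pad_sequence is already imported above for exactly this kind of use. A hedged sketch reusing get_features and OnnxModel from the script above, assuming batching actually works for the exported checkpoint (untested here) and that a second wave ./1.wav exists:

# A hedged sketch of batched inference; whether the exported model really
# supports N > 1 is an assumption, and "./1.wav" is a hypothetical input.
feats = [get_features("./0.wav"), get_features("./1.wav")]  # list of (T_i, 80)
x_lens = torch.tensor([f.shape[0] for f in feats], dtype=torch.int64)
x = pad_sequence(feats, batch_first=True)  # (N, max_T, 80)

model = OnnxModel("./model.onnx")
log_probs, log_probs_lens = model(x, x_lens)
print(log_probs.shape, log_probs_lens)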