Fangjun Kuang

Export nvidia/canary-180m-flash to sherpa-onnx (#2272)

name: export-nemo-canary-180m-flash

on:
  push:
    branches:
      - export-nemo-canary
  workflow_dispatch:

concurrency:
  group: export-nemo-canary-180m-flash-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-canary-180m-flash:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: nemo canary 180m flash
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run
        shell: bash
        run: |
          cd scripts/nemo/canary
          ./run_180m_flash.sh

          ls -lh *.onnx
          mv -v *.onnx ../../..
          mv -v tokens.txt ../../..
          mv de.wav ../../../
          mv en.wav ../../../

      - name: Collect files (fp32)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
          mkdir -p $d
          cp encoder.onnx $d
          cp decoder.onnx $d
          cp tokens.txt $d

          mkdir $d/test_wavs
          cp de.wav $d/test_wavs
          cp en.wav $d/test_wavs

          tar cjfv $d.tar.bz2 $d
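      # Note: there is no int8 decoder. The int8 package below pairs the int8 encoder
      # with the fp16 decoder, because int8 quantization of the decoder hurts accuracy
      # (see the comment in export_onnx_180m_flash.py).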
      - name: Collect files (int8)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
          mkdir -p $d
          cp encoder.int8.onnx $d
          cp decoder.fp16.onnx $d
          cp tokens.txt $d

          mkdir $d/test_wavs
          cp de.wav $d/test_wavs
          cp en.wav $d/test_wavs

          tar cjfv $d.tar.bz2 $d

      - name: Collect files (fp16)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
          mkdir -p $d
          cp encoder.fp16.onnx $d
          cp decoder.fp16.onnx $d
          cp tokens.txt $d

          mkdir $d/test_wavs
          cp de.wav $d/test_wavs
          cp en.wav $d/test_wavs

          tar cjfv $d.tar.bz2 $d

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            models=(
              sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
              sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
              sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
            )

            for m in ${models[@]}; do
              rm -rf huggingface
              export GIT_LFS_SKIP_SMUDGE=1
              export GIT_CLONE_PROTECTION_ACTIVE=false
              git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m huggingface
              cp -av $m/* huggingface
              cd huggingface
              git lfs track "*.onnx"
              git lfs track "*.wav"
              git status
              git add .
              git status
              git commit -m "first commit"
              git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m main
              cd ..
            done

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
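The Release step above publishes the three archives to the asr-models tag of k2-fsa/sherpa-onnx. A minimal sketch of fetching one of them afterwards, assuming the uploaded asset keeps the archive name produced in the Collect steps (same URL pattern as the test wavs downloaded in run_180m_flash.sh below):

  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr.tar.bz2
  tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr.tar.bz2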
scripts/nemo/canary/export_onnx_180m_flash.py:

#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import os
from typing import Tuple

import nemo
import onnxmltools
import torch
from nemo.collections.common.parts import NEG_INF
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic
  14 +"""
  15 +NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED :
  16 +Could not find an implementation for Trilu(14) node with name '/Trilu'
  17 +
  18 +See also https://github.com/microsoft/onnxruntime/issues/16189#issuecomment-1722219631
  19 +
  20 +So we use fixed_form_attention_mask() to replace
  21 +the original form_attention_mask()
  22 +"""
  23 +
  24 +
  25 +def fixed_form_attention_mask(input_mask, diagonal=None):
  26 + """
  27 + Fixed: Build attention mask with optional masking of future tokens we forbid
  28 + to attend to (e.g. as it is in Transformer decoder).
  29 +
  30 + Args:
  31 + input_mask: binary mask of size B x L with 1s corresponding to valid
  32 + tokens and 0s corresponding to padding tokens
  33 + diagonal: diagonal where triangular future mask starts
  34 + None -- do not mask anything
  35 + 0 -- regular translation or language modeling future masking
  36 + 1 -- query stream masking as in XLNet architecture
  37 + Returns:
  38 + attention_mask: mask of size B x 1 x L x L with 0s corresponding to
  39 + tokens we plan to attend to and -10000 otherwise
  40 + """
  41 +
  42 + if input_mask is None:
  43 + return None
  44 + attn_shape = (1, input_mask.shape[1], input_mask.shape[1])
  45 + attn_mask = input_mask.to(dtype=bool).unsqueeze(1)
  46 + if diagonal is not None:
  47 + future_mask = torch.tril(
  48 + torch.ones(
  49 + attn_shape,
  50 + dtype=torch.int64, # it was torch.bool
  51 + # but onnxruntime does not support torch.int32 or torch.bool
  52 + # in torch.tril
  53 + device=input_mask.device,
  54 + ),
  55 + diagonal,
  56 + ).bool()
  57 + attn_mask = attn_mask & future_mask
  58 + attention_mask = (1 - attn_mask.to(torch.float)) * NEG_INF
  59 + return attention_mask.unsqueeze(1)
  60 +
  61 +
  62 +nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
  63 +
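# The import below is deliberately placed after the patch above, so that NeMo's
# transformer modules pick up fixed_form_attention_mask when they are first loaded.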
from nemo.collections.asr.models import EncDecMultiTaskModel


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)


def lens_to_mask(lens, max_length):
    """
    Create a mask from a tensor of lengths.
    """
    batch_size = lens.shape[0]
    arange = torch.arange(max_length, device=lens.device)
    mask = arange.expand(batch_size, max_length) < lens.unsqueeze(1)
    return mask


class EncoderWrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.encoder = m.encoder
        self.encoder_decoder_proj = m.encoder_decoder_proj

    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x: (N, T, C)
          x_len: (N,)
        Returns:
          - enc_states: (N, T, C)
          - encoded_len: (N,)
          - enc_mask: (N, T)
        """
        x = x.permute(0, 2, 1)
        # x: (N, C, T)
        encoded, encoded_len = self.encoder(audio_signal=x, length=x_len)

        enc_states = encoded.permute(0, 2, 1)

        enc_states = self.encoder_decoder_proj(enc_states)

        enc_mask = lens_to_mask(encoded_len, enc_states.shape[1])

        return enc_states, encoded_len, enc_mask


class DecoderWrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.decoder = m.transf_decoder
        self.log_softmax = m.log_softmax

        # We use only greedy search, so there is no need to compute log_softmax
        self.log_softmax.mlp.log_softmax = False

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_mems_list_0: torch.Tensor,
        decoder_mems_list_1: torch.Tensor,
        decoder_mems_list_2: torch.Tensor,
        decoder_mems_list_3: torch.Tensor,
        decoder_mems_list_4: torch.Tensor,
        decoder_mems_list_5: torch.Tensor,
        enc_states: torch.Tensor,
        enc_mask: torch.Tensor,
    ):
        """
        Args:
          decoder_input_ids: (N, num_tokens), torch.int32
          decoder_mems_list_i: (N, num_tokens, 1024)
          enc_states: (N, T, 1024)
          enc_mask: (N, T)
        Returns:
          - logits: (N, 1, vocab_size)
          - decoder_mems_list_i: (N, num_tokens_2, 1024)
        """
        pos = decoder_input_ids[0][-1].item()
        decoder_input_ids = decoder_input_ids[:, :-1]

        decoder_hidden_states = self.decoder.embedding.forward(
            decoder_input_ids, start_pos=pos
        )
        decoder_input_mask = torch.ones_like(decoder_input_ids).float()

        decoder_mems_list = self.decoder.decoder.forward(
            decoder_hidden_states,
            decoder_input_mask,
            enc_states,
            enc_mask,
            [
                decoder_mems_list_0,
                decoder_mems_list_1,
                decoder_mems_list_2,
                decoder_mems_list_3,
                decoder_mems_list_4,
                decoder_mems_list_5,
            ],
            return_mems=True,
        )
        logits = self.log_softmax(hidden_states=decoder_mems_list[-1][:, -1:])

        return logits, decoder_mems_list


def export_encoder(canary_model):
    encoder = EncoderWrapper(canary_model)
    x = torch.rand(1, 4000, 128)
    x_lens = torch.tensor([x.shape[1]], dtype=torch.int64)

    encoder_filename = "encoder.onnx"
    torch.onnx.export(
        encoder,
        (x, x_lens),
        encoder_filename,
        input_names=["x", "x_len"],
        output_names=["enc_states", "enc_len", "enc_mask"],
        opset_version=14,
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_len": {0: "N"},
            "enc_states": {0: "N", 1: "T"},
            "enc_len": {0: "N"},
            "enc_mask": {0: "N", 1: "T"},
        },
    )


def export_decoder(canary_model):
    decoder = DecoderWrapper(canary_model)
    decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)

    decoder_mems_list_0 = torch.zeros(1, 1, 1024)
    decoder_mems_list_1 = torch.zeros(1, 1, 1024)
    decoder_mems_list_2 = torch.zeros(1, 1, 1024)
    decoder_mems_list_3 = torch.zeros(1, 1, 1024)
    decoder_mems_list_4 = torch.zeros(1, 1, 1024)
    decoder_mems_list_5 = torch.zeros(1, 1, 1024)

    enc_states = torch.zeros(1, 1000, 1024)
    enc_mask = torch.ones(1, 1000).bool()

    torch.onnx.export(
        decoder,
        (
            decoder_input_ids,
            decoder_mems_list_0,
            decoder_mems_list_1,
            decoder_mems_list_2,
            decoder_mems_list_3,
            decoder_mems_list_4,
            decoder_mems_list_5,
            enc_states,
            enc_mask,
        ),
        "decoder.onnx",
        opset_version=14,
        input_names=[
            "decoder_input_ids",
            "decoder_mems_list_0",
            "decoder_mems_list_1",
            "decoder_mems_list_2",
            "decoder_mems_list_3",
            "decoder_mems_list_4",
            "decoder_mems_list_5",
            "enc_states",
            "enc_mask",
        ],
        output_names=[
            "logits",
            "next_decoder_mem_list_0",
            "next_decoder_mem_list_1",
            "next_decoder_mem_list_2",
            "next_decoder_mem_list_3",
            "next_decoder_mem_list_4",
            "next_decoder_mem_list_5",
        ],
        dynamic_axes={
            "decoder_input_ids": {1: "num_tokens"},
            "decoder_mems_list_0": {1: "num_tokens"},
            "decoder_mems_list_1": {1: "num_tokens"},
            "decoder_mems_list_2": {1: "num_tokens"},
            "decoder_mems_list_3": {1: "num_tokens"},
            "decoder_mems_list_4": {1: "num_tokens"},
            "decoder_mems_list_5": {1: "num_tokens"},
            "enc_states": {1: "T"},
            "enc_mask": {1: "T"},
        },
    )


def export_tokens(canary_model):
    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i in range(canary_model.tokenizer.vocab_size):
            s = canary_model.tokenizer.ids_to_text([i])
            f.write(f"{s} {i}\n")
    print("Saved to tokens.txt")


@torch.no_grad()
def main():
    canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
    export_tokens(canary_model)
    export_encoder(canary_model)
    export_decoder(canary_model)

    for m in ["encoder", "decoder"]:
        if m == "encoder":
            # we don't quantize the decoder with int8 since the accuracy drops
            quantize_dynamic(
                model_input=f"./{m}.onnx",
                model_output=f"./{m}.int8.onnx",
                weight_type=QuantType.QUInt8,
            )

        export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")

    os.system("ls -lh *.onnx")


if __name__ == "__main__":
    main()
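For reference, the export script writes encoder.onnx, decoder.onnx, encoder.int8.onnx, encoder.fp16.onnx, decoder.fp16.onnx, and tokens.txt into the current directory; these are the files packaged by the workflow above and exercised by run_180m_flash.sh below.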
scripts/nemo/canary/run_180m_flash.sh:

#!/usr/bin/env bash
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

set -ex

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/de.wav
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/en.wav

pip install \
  nemo_toolkit['asr'] \
  "numpy<2" \
  ipython \
  kaldi-native-fbank \
  librosa \
  onnx==1.17.0 \
  onnxmltools \
  onnxruntime==1.17.1 \
  soundfile

python3 ./export_onnx_180m_flash.py
ls -lh *.onnx
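# For each precision, the exported model is exercised in four modes: transcription
# (en -> en, de -> de) and speech translation (en -> de, de -> en).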
log "-----fp32------"

python3 ./test_180m_flash.py \
  --encoder ./encoder.onnx \
  --decoder ./decoder.onnx \
  --source-lang en \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.onnx \
  --decoder ./decoder.onnx \
  --source-lang en \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.onnx \
  --decoder ./decoder.onnx \
  --source-lang de \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./de.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.onnx \
  --decoder ./decoder.onnx \
  --source-lang de \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./de.wav


log "-----int8------"

python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./de.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./de.wav

log "-----fp16------"

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./de.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./de.wav
scripts/nemo/canary/test_180m_flash.py:

#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
import time
from pathlib import Path
from typing import List

import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder", type=str, required=True, help="Path to encoder.onnx"
    )
    parser.add_argument(
        "--decoder", type=str, required=True, help="Path to decoder.onnx"
    )

    parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")

    parser.add_argument(
        "--source-lang",
        type=str,
        help="Language of the input wav. Valid values are: en, de, es, fr",
    )
    parser.add_argument(
        "--target-lang",
        type=str,
        help="Language of the recognition result. Valid values are: en, de, es, fr",
    )
    parser.add_argument(
        "--use-pnc",
        type=int,
        default=1,
        help="1 to enable casing and punctuation. 0 to disable them",
    )

    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()


def display(sess, model):
    print(f"=========={model} Input==========")
    for i in sess.get_inputs():
        print(i)
    print(f"=========={model} Output==========")
    for i in sess.get_outputs():
        print(i)
class OnnxModel:
    def __init__(
        self,
        encoder: str,
        decoder: str,
    ):
        self.init_encoder(encoder)
        display(self.encoder, "encoder")

        self.init_decoder(decoder)
        display(self.decoder, "decoder")

    def init_encoder(self, encoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.encoder = ort.InferenceSession(
            encoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map
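        # Assumption: the exported model does not yet carry a normalize_type entry in
        # its metadata, hence the commented-out lookup and the hard-coded "per_feature".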
        # self.normalize_type = meta["normalize_type"]
        self.normalize_type = "per_feature"
        print(meta)

    def init_decoder(self, decoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.decoder = ort.InferenceSession(
            decoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def run_encoder(self, x: np.ndarray, x_lens: np.ndarray):
        """
        Args:
          x: (N, T, C), np.float
          x_lens: (N,), np.int64
        Returns:
          enc_states: (N, T, C)
          enc_lens: (N,), np.int64
          enc_masks: (N, T), np.bool
        """
        enc_states, enc_lens, enc_masks = self.encoder.run(
            [
                self.encoder.get_outputs()[0].name,
                self.encoder.get_outputs()[1].name,
                self.encoder.get_outputs()[2].name,
            ],
            {
                self.encoder.get_inputs()[0].name: x,
                self.encoder.get_inputs()[1].name: x_lens,
            },
        )
        return enc_states, enc_lens, enc_masks

    def run_decoder(
        self,
        decoder_input_ids: np.ndarray,
        decoder_mems_list: List[np.ndarray],
        enc_states: np.ndarray,
        enc_mask: np.ndarray,
    ):
        """
        Args:
          decoder_input_ids: (N, num_tokens), int32
          decoder_mems_list: a list of tensors, each of which is (N, num_tokens, C)
          enc_states: (N, T, C), float
          enc_mask: (N, T), bool
        Returns:
          logits: (1, 1, vocab_size), float
          new_decoder_mems_list:
        """
        (logits, *new_decoder_mems_list) = self.decoder.run(
            [
                self.decoder.get_outputs()[0].name,
                self.decoder.get_outputs()[1].name,
                self.decoder.get_outputs()[2].name,
                self.decoder.get_outputs()[3].name,
                self.decoder.get_outputs()[4].name,
                self.decoder.get_outputs()[5].name,
                self.decoder.get_outputs()[6].name,
            ],
            {
                self.decoder.get_inputs()[0].name: decoder_input_ids,
                self.decoder.get_inputs()[1].name: decoder_mems_list[0],
                self.decoder.get_inputs()[2].name: decoder_mems_list[1],
                self.decoder.get_inputs()[3].name: decoder_mems_list[2],
                self.decoder.get_inputs()[4].name: decoder_mems_list[3],
                self.decoder.get_inputs()[5].name: decoder_mems_list[4],
                self.decoder.get_inputs()[6].name: decoder_mems_list[5],
                self.decoder.get_inputs()[7].name: enc_states,
                self.decoder.get_inputs()[8].name: enc_mask,
            },
        )
        return logits, new_decoder_mems_list


def create_fbank():
    opts = knf.FbankOptions()
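    # These options are chosen to mirror the preprocessor expected by the exported
    # encoder: 128 librosa-style mel bins, Hann window, no dither, no DC-offset removal.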
    opts.frame_opts.dither = 0
    opts.frame_opts.remove_dc_offset = False
    opts.frame_opts.window_type = "hann"

    opts.mel_opts.low_freq = 0
    opts.mel_opts.num_bins = 128

    opts.mel_opts.is_librosa = True

    fbank = knf.OnlineFbank(opts)
    return fbank


def compute_features(audio, fbank):
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(16000, audio)
    ans = []
    processed = 0
    while processed < fbank.num_frames_ready:
        ans.append(np.array(fbank.get_frame(processed)))
        processed += 1
    ans = np.stack(ans)
    return ans


def main():
    args = get_args()
    assert Path(args.encoder).is_file(), args.encoder
    assert Path(args.decoder).is_file(), args.decoder
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    id2token = dict()
    token2id = dict()
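    # tokens.txt (written by export_tokens() above) has one "<token> <id>" pair per
    # line. A line whose token field disappears after split() is the bare space token;
    # tokens whose line starts with a space get that leading space restored here.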
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if len(fields) == 2:
                t, idx = fields[0], int(fields[1])
                if line[0] == " ":
                    t = " " + t
            else:
                t = " "
                idx = int(fields[0])

            id2token[idx] = t
            token2id[t] = idx

    model = OnnxModel(args.encoder, args.decoder)

    fbank = create_fbank()

    start = time.time()
    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    features = compute_features(audio, fbank)
    if model.normalize_type != "":
        assert model.normalize_type == "per_feature", model.normalize_type
        mean = features.mean(axis=1, keepdims=True)
        stddev = features.std(axis=1, keepdims=True) + 1e-5
        features = (features - mean) / stddev

    features = np.expand_dims(features, axis=0)
    # features.shape: (1, 291, 128)

    features_len = np.array([features.shape[1]], dtype=np.int64)

    enc_states, _, enc_masks = model.run_encoder(features, features_len)
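    # Build the Canary decoder prompt: start-of-context and start-of-transcript tags,
    # emotion tag, source language, target language, punctuation/casing flag, inverse
    # text normalization flag, timestamp flag, and diarization flag, in that order.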
    decoder_input_ids = []
    decoder_input_ids.append(token2id["<|startofcontext|>"])
    decoder_input_ids.append(token2id["<|startoftranscript|>"])
    decoder_input_ids.append(token2id["<|emo:undefined|>"])
    if args.source_lang in ("en", "es", "de", "fr"):
        decoder_input_ids.append(token2id[f"<|{args.source_lang}|>"])
    else:
        decoder_input_ids.append(token2id["<|en|>"])

    if args.target_lang in ("en", "es", "de", "fr"):
        decoder_input_ids.append(token2id[f"<|{args.target_lang}|>"])
    else:
        decoder_input_ids.append(token2id["<|en|>"])

    if args.use_pnc:
        decoder_input_ids.append(token2id["<|pnc|>"])
    else:
        decoder_input_ids.append(token2id["<|nopnc|>"])

    decoder_input_ids.append(token2id["<|noitn|>"])
    decoder_input_ids.append(token2id["<|notimestamp|>"])
    decoder_input_ids.append(token2id["<|nodiarize|>"])
    decoder_input_ids.append(0)

    decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]

    logits, decoder_mems_list = model.run_decoder(
        np.array([decoder_input_ids], dtype=np.int32),
        decoder_mems_list,
        enc_states,
        enc_masks,
    )
    tokens = [logits.argmax()]
    print("decoder_input_ids", decoder_input_ids)
    eos = token2id["<|endoftext|>"]
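    # Greedy decoding: each step feeds only the newest token plus the running position
    # counter i; the per-layer decoder_mems_list carries the earlier context. 200 is an
    # upper bound on the number of generated tokens.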
    for i in range(1, 200):
        decoder_input_ids = [tokens[-1], i]
        logits, decoder_mems_list = model.run_decoder(
            np.array([decoder_input_ids], dtype=np.int32),
            decoder_mems_list,
            enc_states,
            enc_masks,
        )
        t = logits.argmax()
        if t == eos:
            break
        tokens.append(t)
    print("len(tokens)", len(tokens))
    print("tokens", tokens)
    text = "".join([id2token[i] for i in tokens])
    print("text:", text)


if __name__ == "__main__":
    main()