Committed by GitHub
Export nvidia/canary-180m-flash to sherpa-onnx (#2272)
Showing 4 changed files with 851 additions and 0 deletions.
name: export-nemo-canary-180m-flash

on:
  push:
    branches:
      - export-nemo-canary
  workflow_dispatch:

concurrency:
  group: export-nemo-canary-180m-flash-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-canary-180m-flash:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: nemo canary 180m flash
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run
        shell: bash
        run: |
          cd scripts/nemo/canary
          ./run_180m_flash.sh

          ls -lh *.onnx
          mv -v *.onnx ../../..
          mv -v tokens.txt ../../..
          mv de.wav ../../../
          mv en.wav ../../../

      - name: Collect files (fp32)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
          mkdir -p $d
          cp encoder.onnx $d
          cp decoder.onnx $d
          cp tokens.txt $d

          mkdir $d/test_wavs
          cp de.wav $d/test_wavs
          cp en.wav $d/test_wavs

          tar cjfv $d.tar.bz2 $d
      - name: Collect files (int8)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
          mkdir -p $d
          cp encoder.int8.onnx $d
          # The decoder is not quantized to int8 since the accuracy drops
          # (see export_onnx_180m_flash.py); its fp16 export is used instead.
          cp decoder.fp16.onnx $d
          cp tokens.txt $d

          mkdir $d/test_wavs
          cp de.wav $d/test_wavs
          cp en.wav $d/test_wavs

          tar cjfv $d.tar.bz2 $d

      - name: Collect files (fp16)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
          mkdir -p $d
          cp encoder.fp16.onnx $d
          cp decoder.fp16.onnx $d
          cp tokens.txt $d

          mkdir $d/test_wavs
          cp de.wav $d/test_wavs
          cp en.wav $d/test_wavs

          tar cjfv $d.tar.bz2 $d

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            models=(
              sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
              sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
              sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
            )

            for m in ${models[@]}; do
              rm -rf huggingface
              export GIT_LFS_SKIP_SMUDGE=1
              export GIT_CLONE_PROTECTION_ACTIVE=false
              git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m huggingface
              cp -av $m/* huggingface
              cd huggingface
              git lfs track "*.onnx"
              git lfs track "*.wav"
              git status
              git add .
              git status
              git commit -m "first commit"
              git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m main
              cd ..
            done

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
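The three archives end up as assets on the asr-models release tag. For reference, a minimal sketch of fetching and unpacking one of them (illustrative, not part of this commit; it assumes the workflow has already run and published the asset):

    import tarfile
    import urllib.request

    # Asset name as produced by the "Collect files (fp32)" step above.
    name = "sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr.tar.bz2"
    url = (
        "https://github.com/k2-fsa/sherpa-onnx/releases/download/"
        f"asr-models/{name}"
    )
    urllib.request.urlretrieve(url, name)
    with tarfile.open(name) as tar:
        # unpacks encoder.onnx, decoder.onnx, tokens.txt, test_wavs/
        tar.extractall()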
scripts/nemo/canary/export_onnx_180m_flash.py
0 → 100755

#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import os
from typing import Tuple

import nemo
import onnxmltools
import torch
from nemo.collections.common.parts import NEG_INF
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic

"""
NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED :
Could not find an implementation for Trilu(14) node with name '/Trilu'

See also https://github.com/microsoft/onnxruntime/issues/16189#issuecomment-1722219631

So we use fixed_form_attention_mask() to replace
the original form_attention_mask().
"""
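# Background for the workaround above: torch.tril(torch.ones(..., dtype=torch.bool))
# exports as a Trilu(14) node with a bool input, and the onnxruntime build pinned
# by run_180m_flash.sh (onnxruntime==1.17.1) has no bool/int32 Trilu kernel, so
# loading the exported model fails with NOT_IMPLEMENTED. Running tril() over
# int64 and casting back to bool keeps the graph on a supported dtype.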


def fixed_form_attention_mask(input_mask, diagonal=None):
    """
    Fixed version of form_attention_mask(): build an attention mask,
    optionally masking future tokens that must not be attended to
    (e.g., as in a Transformer decoder).

    Args:
        input_mask: binary mask of size B x L with 1s corresponding to valid
            tokens and 0s corresponding to padding tokens
        diagonal: diagonal where the triangular future mask starts
            None -- do not mask anything
            0 -- regular translation or language-modeling future masking
            1 -- query-stream masking as in the XLNet architecture
    Returns:
        attention_mask: mask of size B x 1 x L x L with 0s corresponding to
        tokens we plan to attend to and -10000 otherwise
    """

    if input_mask is None:
        return None
    attn_shape = (1, input_mask.shape[1], input_mask.shape[1])
    attn_mask = input_mask.to(dtype=bool).unsqueeze(1)
    if diagonal is not None:
        future_mask = torch.tril(
            torch.ones(
                attn_shape,
                # The upstream code uses torch.bool here, but onnxruntime
                # does not support torch.bool or torch.int32 for the exported
                # Trilu node, so we build the mask in int64 and cast back.
                dtype=torch.int64,
                device=input_mask.device,
            ),
            diagonal,
        ).bool()
        attn_mask = attn_mask & future_mask
    attention_mask = (1 - attn_mask.to(torch.float)) * NEG_INF
    return attention_mask.unsqueeze(1)


nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask

# Import the model class only after the monkey-patch above so that the export
# picks up fixed_form_attention_mask.
from nemo.collections.asr.models import EncDecMultiTaskModel


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)


def lens_to_mask(lens, max_length):
    """
    Create a mask from a tensor of lengths.
    """
    batch_size = lens.shape[0]
    arange = torch.arange(max_length, device=lens.device)
    mask = arange.expand(batch_size, max_length) < lens.unsqueeze(1)
    return mask
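# A worked illustration (not part of the export):
#   lens = torch.tensor([2, 4]), max_length = 4
#   mask = [[True, True, False, False],
#           [True, True, True,  True ]]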


class EncoderWrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.encoder = m.encoder
        self.encoder_decoder_proj = m.encoder_decoder_proj

    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x: (N, T, C)
          x_len: (N,)
        Returns:
          - enc_states: (N, T, C)
          - encoded_len: (N,)
          - enc_mask: (N, T)
        """
        x = x.permute(0, 2, 1)
        # x: (N, C, T)
        encoded, encoded_len = self.encoder(audio_signal=x, length=x_len)

        enc_states = encoded.permute(0, 2, 1)

        enc_states = self.encoder_decoder_proj(enc_states)

        enc_mask = lens_to_mask(encoded_len, enc_states.shape[1])

        return enc_states, encoded_len, enc_mask


class DecoderWrapper(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.decoder = m.transf_decoder
        self.log_softmax = m.log_softmax

        # We use only greedy search, so there is no need to compute log_softmax
        self.log_softmax.mlp.log_softmax = False

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_mems_list_0: torch.Tensor,
        decoder_mems_list_1: torch.Tensor,
        decoder_mems_list_2: torch.Tensor,
        decoder_mems_list_3: torch.Tensor,
        decoder_mems_list_4: torch.Tensor,
        decoder_mems_list_5: torch.Tensor,
        enc_states: torch.Tensor,
        enc_mask: torch.Tensor,
    ):
        """
        Args:
          decoder_input_ids: (N, num_tokens), torch.int32
          decoder_mems_list_i: (N, num_tokens, 1024)
          enc_states: (N, T, 1024)
          enc_mask: (N, T)
        Returns:
          - logits: (N, 1, vocab_size)
          - decoder_mems_list_i: (N, num_tokens_2, 1024)
        """
        # The last element of decoder_input_ids is not a real token: it carries
        # the current decoding position and is consumed as start_pos for the
        # positional embedding (see how test_180m_flash.py builds its inputs).
        pos = decoder_input_ids[0][-1].item()
        decoder_input_ids = decoder_input_ids[:, :-1]

        decoder_hidden_states = self.decoder.embedding.forward(
            decoder_input_ids, start_pos=pos
        )
        decoder_input_mask = torch.ones_like(decoder_input_ids).float()

        decoder_mems_list = self.decoder.decoder.forward(
            decoder_hidden_states,
            decoder_input_mask,
            enc_states,
            enc_mask,
            [
                decoder_mems_list_0,
                decoder_mems_list_1,
                decoder_mems_list_2,
                decoder_mems_list_3,
                decoder_mems_list_4,
                decoder_mems_list_5,
            ],
            return_mems=True,
        )
        logits = self.log_softmax(hidden_states=decoder_mems_list[-1][:, -1:])

        return logits, decoder_mems_list

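# Call protocol implied by DecoderWrapper (mirrored in test_180m_flash.py):
#   first call: ids = [prompt tokens..., 0]     (trailing 0 = start position)
#   step i:     ids = [last_emitted_token, i]   (cache grows via the mems list)
# Each call returns logits for the newest position only.
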
def export_encoder(canary_model):
    encoder = EncoderWrapper(canary_model)
    x = torch.rand(1, 4000, 128)
    x_lens = torch.tensor([x.shape[1]], dtype=torch.int64)

    encoder_filename = "encoder.onnx"
    torch.onnx.export(
        encoder,
        (x, x_lens),
        encoder_filename,
        input_names=["x", "x_len"],
        output_names=["enc_states", "enc_len", "enc_mask"],
        opset_version=14,
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_len": {0: "N"},
            "enc_states": {0: "N", 1: "T"},
            "enc_len": {0: "N"},
            "enc_mask": {0: "N", 1: "T"},
        },
    )

def export_decoder(canary_model):
    decoder = DecoderWrapper(canary_model)
    decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)

    decoder_mems_list_0 = torch.zeros(1, 1, 1024)
    decoder_mems_list_1 = torch.zeros(1, 1, 1024)
    decoder_mems_list_2 = torch.zeros(1, 1, 1024)
    decoder_mems_list_3 = torch.zeros(1, 1, 1024)
    decoder_mems_list_4 = torch.zeros(1, 1, 1024)
    decoder_mems_list_5 = torch.zeros(1, 1, 1024)

    enc_states = torch.zeros(1, 1000, 1024)
    enc_mask = torch.ones(1, 1000).bool()

    torch.onnx.export(
        decoder,
        (
            decoder_input_ids,
            decoder_mems_list_0,
            decoder_mems_list_1,
            decoder_mems_list_2,
            decoder_mems_list_3,
            decoder_mems_list_4,
            decoder_mems_list_5,
            enc_states,
            enc_mask,
        ),
        "decoder.onnx",
        opset_version=14,
        input_names=[
            "decoder_input_ids",
            "decoder_mems_list_0",
            "decoder_mems_list_1",
            "decoder_mems_list_2",
            "decoder_mems_list_3",
            "decoder_mems_list_4",
            "decoder_mems_list_5",
            "enc_states",
            "enc_mask",
        ],
        output_names=[
            "logits",
            "next_decoder_mem_list_0",
            "next_decoder_mem_list_1",
            "next_decoder_mem_list_2",
            "next_decoder_mem_list_3",
            "next_decoder_mem_list_4",
            "next_decoder_mem_list_5",
        ],
        dynamic_axes={
            "decoder_input_ids": {1: "num_tokens"},
            "decoder_mems_list_0": {1: "num_tokens"},
            "decoder_mems_list_1": {1: "num_tokens"},
            "decoder_mems_list_2": {1: "num_tokens"},
            "decoder_mems_list_3": {1: "num_tokens"},
            "decoder_mems_list_4": {1: "num_tokens"},
            "decoder_mems_list_5": {1: "num_tokens"},
            "enc_states": {1: "T"},
            "enc_mask": {1: "T"},
        },
    )

def export_tokens(canary_model):
    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i in range(canary_model.tokenizer.vocab_size):
            s = canary_model.tokenizer.ids_to_text([i])
            f.write(f"{s} {i}\n")
    print("Saved to tokens.txt")

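# Note: each line of tokens.txt is "<token> <id>". Some tokens begin with or
# consist of a space, so a line may start with a space or contain only an id;
# the parser in test_180m_flash.py handles both cases.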

@torch.no_grad()
def main():
    canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
    export_tokens(canary_model)
    export_encoder(canary_model)
    export_decoder(canary_model)

    for m in ["encoder", "decoder"]:
        if m == "encoder":
            # we don't quantize the decoder with int8 since the accuracy drops
            quantize_dynamic(
                model_input=f"./{m}.onnx",
                model_output=f"./{m}.int8.onnx",
                weight_type=QuantType.QUInt8,
            )

        export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")

    os.system("ls -lh *.onnx")


if __name__ == "__main__":
    main()
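A quick numeric sanity check of the int8/fp16 encoders against the fp32 baseline could look like the following sketch (not part of the commit; it assumes the three encoder files produced by the script above are in the current directory, and uses the exported input/output names):

    import numpy as np
    import onnxruntime as ort

    # Random log-mel-shaped input; good enough for an output-drift check.
    x = np.random.rand(1, 1000, 128).astype(np.float32)
    x_len = np.array([1000], dtype=np.int64)

    def encode(path):
        sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
        return sess.run(["enc_states"], {"x": x, "x_len": x_len})[0]

    ref = encode("encoder.onnx")
    for path in ["encoder.int8.onnx", "encoder.fp16.onnx"]:
        print(path, "max |diff| vs fp32:", np.abs(encode(path) - ref).max())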
scripts/nemo/canary/run_180m_flash.sh
0 → 100755
#!/usr/bin/env bash
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

set -ex

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/de.wav
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/en.wav

pip install \
  nemo_toolkit['asr'] \
  "numpy<2" \
  ipython \
  kaldi-native-fbank \
  librosa \
  onnx==1.17.0 \
  onnxmltools \
  onnxruntime==1.17.1 \
  soundfile

python3 ./export_onnx_180m_flash.py
ls -lh *.onnx


log "-----fp32------"

python3 ./test_180m_flash.py \
  --encoder ./encoder.onnx \
  --decoder ./decoder.onnx \
  --source-lang en \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.onnx \
  --decoder ./decoder.onnx \
  --source-lang en \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.onnx \
  --decoder ./decoder.onnx \
  --source-lang de \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./de.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.onnx \
  --decoder ./decoder.onnx \
  --source-lang de \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./de.wav


log "-----int8------"

python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./de.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./de.wav

log "-----fp16------"

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang en \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./en.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang de \
  --tokens ./tokens.txt \
  --wav ./de.wav

python3 ./test_180m_flash.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --source-lang de \
  --target-lang en \
  --tokens ./tokens.txt \
  --wav ./de.wav
scripts/nemo/canary/test_180m_flash.py
0 → 100755
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
import time
from pathlib import Path
from typing import List

import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder", type=str, required=True, help="Path to encoder.onnx"
    )
    parser.add_argument(
        "--decoder", type=str, required=True, help="Path to decoder.onnx"
    )

    parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")

    parser.add_argument(
        "--source-lang",
        type=str,
        help="Language of the input wav. Valid values are: en, de, es, fr",
    )
    parser.add_argument(
        "--target-lang",
        type=str,
        help="Language of the recognition result. Valid values are: en, de, es, fr",
    )
    parser.add_argument(
        "--use-pnc",
        type=int,
        default=1,
        help="1 to enable casing and punctuation. 0 to disable them",
    )

    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()

def display(sess, model):
    print(f"=========={model} Input==========")
    for i in sess.get_inputs():
        print(i)
    print(f"=========={model} Output==========")
    for i in sess.get_outputs():
        print(i)


class OnnxModel:
    def __init__(
        self,
        encoder: str,
        decoder: str,
    ):
        self.init_encoder(encoder)
        display(self.encoder, "encoder")

        self.init_decoder(decoder)
        display(self.decoder, "decoder")

    def init_encoder(self, encoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.encoder = ort.InferenceSession(
            encoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map
        # The export does not write normalize_type into the model metadata,
        # so it is hard-coded here instead of being read from meta.
        # self.normalize_type = meta["normalize_type"]
        self.normalize_type = "per_feature"
        print(meta)

    def init_decoder(self, decoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.decoder = ort.InferenceSession(
            decoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def run_encoder(self, x: np.ndarray, x_lens: np.ndarray):
        """
        Args:
          x: (N, T, C), np.float32
          x_lens: (N,), np.int64
        Returns:
          enc_states: (N, T, C)
          enc_lens: (N,), np.int64
          enc_masks: (N, T), np.bool
        """
        enc_states, enc_lens, enc_masks = self.encoder.run(
            [
                self.encoder.get_outputs()[0].name,
                self.encoder.get_outputs()[1].name,
                self.encoder.get_outputs()[2].name,
            ],
            {
                self.encoder.get_inputs()[0].name: x,
                self.encoder.get_inputs()[1].name: x_lens,
            },
        )
        return enc_states, enc_lens, enc_masks

    def run_decoder(
        self,
        decoder_input_ids: np.ndarray,
        decoder_mems_list: List[np.ndarray],
        enc_states: np.ndarray,
        enc_mask: np.ndarray,
    ):
        """
        Args:
          decoder_input_ids: (N, num_tokens), int32
          decoder_mems_list: a list of tensors, each of which is (N, num_tokens, C)
          enc_states: (N, T, C), float
          enc_mask: (N, T), bool
        Returns:
          logits: (1, 1, vocab_size), float
          new_decoder_mems_list: the updated per-layer caches
        """
        (logits, *new_decoder_mems_list) = self.decoder.run(
            [
                self.decoder.get_outputs()[0].name,
                self.decoder.get_outputs()[1].name,
                self.decoder.get_outputs()[2].name,
                self.decoder.get_outputs()[3].name,
                self.decoder.get_outputs()[4].name,
                self.decoder.get_outputs()[5].name,
                self.decoder.get_outputs()[6].name,
            ],
            {
                self.decoder.get_inputs()[0].name: decoder_input_ids,
                self.decoder.get_inputs()[1].name: decoder_mems_list[0],
                self.decoder.get_inputs()[2].name: decoder_mems_list[1],
                self.decoder.get_inputs()[3].name: decoder_mems_list[2],
                self.decoder.get_inputs()[4].name: decoder_mems_list[3],
                self.decoder.get_inputs()[5].name: decoder_mems_list[4],
                self.decoder.get_inputs()[6].name: decoder_mems_list[5],
                self.decoder.get_inputs()[7].name: enc_states,
                self.decoder.get_inputs()[8].name: enc_mask,
            },
        )
        return logits, new_decoder_mems_list


def create_fbank():
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.remove_dc_offset = False
    opts.frame_opts.window_type = "hann"

    opts.mel_opts.low_freq = 0
    opts.mel_opts.num_bins = 128

    opts.mel_opts.is_librosa = True

    fbank = knf.OnlineFbank(opts)
    return fbank
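# Note: 128 librosa-style mel bins match the exported encoder's 128-dim
# feature input (see export_onnx_180m_flash.py), and dither=0 keeps the
# feature extraction deterministic.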


def compute_features(audio, fbank):
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(16000, audio)
    ans = []
    processed = 0
    while processed < fbank.num_frames_ready:
        ans.append(np.array(fbank.get_frame(processed)))
        processed += 1
    ans = np.stack(ans)
    return ans


def main():
    args = get_args()
    assert Path(args.encoder).is_file(), args.encoder
    assert Path(args.decoder).is_file(), args.decoder
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    id2token = dict()
    token2id = dict()
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if len(fields) == 2:
                t, idx = fields[0], int(fields[1])
                if line[0] == " ":
                    # the token itself starts with a space
                    t = " " + t
            else:
                # the line holds only an id, so the token is a bare space
                t = " "
                idx = int(fields[0])

            id2token[idx] = t
            token2id[t] = idx

    model = OnnxModel(args.encoder, args.decoder)

    fbank = create_fbank()

    start = time.time()
    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    features = compute_features(audio, fbank)
    if model.normalize_type != "":
        assert model.normalize_type == "per_feature", model.normalize_type
        mean = features.mean(axis=1, keepdims=True)
        stddev = features.std(axis=1, keepdims=True) + 1e-5
        features = (features - mean) / stddev

    features = np.expand_dims(features, axis=0)
    # features.shape: (1, 291, 128)

    features_len = np.array([features.shape[1]], dtype=np.int64)

    enc_states, _, enc_masks = model.run_encoder(features, features_len)

    decoder_input_ids = []
    decoder_input_ids.append(token2id["<|startofcontext|>"])
    decoder_input_ids.append(token2id["<|startoftranscript|>"])
    decoder_input_ids.append(token2id["<|emo:undefined|>"])
    if args.source_lang in ("en", "es", "de", "fr"):
        decoder_input_ids.append(token2id[f"<|{args.source_lang}|>"])
    else:
        decoder_input_ids.append(token2id["<|en|>"])

    if args.target_lang in ("en", "es", "de", "fr"):
        decoder_input_ids.append(token2id[f"<|{args.target_lang}|>"])
    else:
        decoder_input_ids.append(token2id["<|en|>"])

    if args.use_pnc:
        decoder_input_ids.append(token2id["<|pnc|>"])
    else:
        decoder_input_ids.append(token2id["<|nopnc|>"])

    decoder_input_ids.append(token2id["<|noitn|>"])
    decoder_input_ids.append(token2id["<|notimestamp|>"])
    decoder_input_ids.append(token2id["<|nodiarize|>"])

    # The trailing 0 is not a token: it is the start position consumed by
    # DecoderWrapper.forward() in the exported decoder.
    decoder_input_ids.append(0)

    decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]

    logits, decoder_mems_list = model.run_decoder(
        np.array([decoder_input_ids], dtype=np.int32),
        decoder_mems_list,
        enc_states,
        enc_masks,
    )
    tokens = [logits.argmax()]
    print("decoder_input_ids", decoder_input_ids)
    eos = token2id["<|endoftext|>"]

    for i in range(1, 200):
        # [last emitted token, current position] -- see DecoderWrapper
        decoder_input_ids = [tokens[-1], i]
        logits, decoder_mems_list = model.run_decoder(
            np.array([decoder_input_ids], dtype=np.int32),
            decoder_mems_list,
            enc_states,
            enc_masks,
        )
        t = logits.argmax()
        if t == eos:
            break
        tokens.append(t)
    print("len(tokens)", len(tokens))
    print("tokens", tokens)
    text = "".join([id2token[i] for i in tokens])
    print("text:", text)


if __name__ == "__main__":
    main()