Fangjun Kuang
Committed by GitHub

export parakeet-tdt-0.6b-v2 to sherpa-onnx (#2180)
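Adds a GitHub Actions workflow (below) and three scripts under scripts/nemo/parakeet-tdt-0.6b-v2 (export_onnx.py, run.sh, test_onnx.py) that export nvidia/parakeet-tdt-0.6b-v2 from NeMo to ONNX in fp32, int8, and fp16 variants, verify each variant on a test wave, and publish the resulting archives to Hugging Face and to the asr-models release of k2-fsa/sherpa-onnx.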

name: export-nemo-parakeet-tdt-0.6b-v2

on:
  push:
    branches:
      - export-nemo-parakeet-tdt-0.6b-v2
  workflow_dispatch:

concurrency:
  group: export-nemo-parakeet-tdt-0.6b-v2-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-parakeet-tdt-0_6b-v2:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: parakeet tdt 0.6b v2
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run
        shell: bash
        run: |
          cd scripts/nemo/parakeet-tdt-0.6b-v2
          ./run.sh

          ls -lh *.onnx
          mv -v *.onnx ../../..
          mv -v tokens.txt ../../..
          mv 2086-149220-0033.wav ../../../0.wav
      - name: Collect files (fp32)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-parakeet-tdt-0.6b-v2
          mkdir -p $d
          # Note: the int8 encoder is shipped even in this bundle, likely
          # because the fp32 encoder exceeds the 2 GB onnx protobuf limit
          # (see export_onnx_fp16_large_2gb in export_onnx.py).
          cp encoder.int8.onnx $d
          cp decoder.onnx $d
          cp joiner.onnx $d
          cp tokens.txt $d

          mkdir $d/test_wavs
          cp 0.wav $d/test_wavs

          tar cjfv $d.tar.bz2 $d
      - name: Collect files (int8)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-parakeet-tdt-0.6b-v2-int8
          mkdir -p $d
          cp encoder.int8.onnx $d
          cp decoder.int8.onnx $d
          cp joiner.int8.onnx $d
          cp tokens.txt $d

          mkdir $d/test_wavs
          cp 0.wav $d/test_wavs

          tar cjfv $d.tar.bz2 $d

      - name: Collect files (fp16)
        shell: bash
        run: |
          d=sherpa-onnx-nemo-parakeet-tdt-0.6b-v2-fp16
          mkdir -p $d
          cp encoder.fp16.onnx $d
          cp decoder.fp16.onnx $d
          cp joiner.fp16.onnx $d
          cp tokens.txt $d

          mkdir $d/test_wavs
          cp 0.wav $d/test_wavs

          tar cjfv $d.tar.bz2 $d

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            models=(
              sherpa-onnx-nemo-parakeet-tdt-0.6b-v2
              sherpa-onnx-nemo-parakeet-tdt-0.6b-v2-int8
              sherpa-onnx-nemo-parakeet-tdt-0.6b-v2-fp16
            )

            for m in ${models[@]}; do
              rm -rf huggingface
              export GIT_LFS_SKIP_SMUDGE=1
              export GIT_CLONE_PROTECTION_ACTIVE=false
              git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m huggingface
              cp -av $m/* huggingface
              cd huggingface
              git lfs track "*.onnx"
              git lfs track "*.wav"
              git status
              git add .
              git status
              git commit -m "first commit"
              git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m main
              cd ..
            done

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
#
# scripts/nemo/parakeet-tdt-0.6b-v2/export_onnx.py

import os
from typing import Dict

import nemo.collections.asr as nemo_asr
import onnx
import onnxmltools
import torch
from onnxmltools.utils.float16_converter import (
    convert_float_to_float16,
    convert_float_to_float16_model_path,
)
from onnxruntime.quantization import QuantType, quantize_dynamic


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    """Convert an fp32 ONNX model (smaller than 2 GB) to fp16.

    Input/output tensors are kept in fp32 (keep_io_types=True).
    """
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)


def export_onnx_fp16_large_2gb(onnx_fp32_path, onnx_fp16_path):
    """Like export_onnx_fp16(), but converts from the model path so that
    models larger than 2 GB (with external data) can be handled.
    """
    onnx_fp16_model = convert_float_to_float16_model_path(
        onnx_fp32_path, keep_io_types=True
    )
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add metadata to an ONNX model. The model file is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name="nvidia/parakeet-tdt-0.6b-v2"
    )

    asr_model.eval()

    # Write one "<piece> <id>" pair per line; the blank token goes last.
    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(asr_model.joint.vocabulary):
            f.write(f"{s} {i}\n")
        f.write(f"<blk> {i+1}\n")
    print("Saved to tokens.txt")

    asr_model.encoder.export("encoder.onnx")
    asr_model.decoder.export("decoder.onnx")
    asr_model.joint.export("joiner.onnx")
    os.system("ls -lh *.onnx")

    normalize_type = asr_model.cfg.preprocessor.normalize
    if normalize_type == "NA":
        normalize_type = ""

    meta_data = {
        "vocab_size": asr_model.decoder.vocab_size,
        "normalize_type": normalize_type,
        "pred_rnn_layers": asr_model.decoder.pred_rnn_layers,
        "pred_hidden": asr_model.decoder.pred_hidden,
        "subsampling_factor": 8,
        "model_type": "EncDecRNNTBPEModel",
        "version": "2",
        "model_author": "NeMo",
        "url": "https://huggingface.co/nvidia/parakeet-tdt-0.6b-v2",
        "comment": "Only the transducer branch is exported",
        "feat_dim": 128,
    }

    for m in ["encoder", "decoder", "joiner"]:
        quantize_dynamic(
            model_input=f"./{m}.onnx",
            model_output=f"./{m}.int8.onnx",
            weight_type=QuantType.QUInt8 if m == "encoder" else QuantType.QInt8,
        )
        os.system("ls -lh *.onnx")

        if m == "encoder":
            # The fp32 encoder is larger than 2 GB, so use the path-based converter.
            export_onnx_fp16_large_2gb(f"{m}.onnx", f"{m}.fp16.onnx")
        else:
            export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")

    add_meta_data("encoder.int8.onnx", meta_data)
    add_meta_data("encoder.fp16.onnx", meta_data)
    print("meta_data", meta_data)


if __name__ == "__main__":
    main()
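Note that the key-value pairs written by add_meta_data() are stored in the model's metadata_props, so consumers can recover them at runtime; test_onnx.py below does exactly this to obtain normalize_type, pred_rnn_layers, and pred_hidden. A minimal read-back sketch (assuming encoder.int8.onnx from the export step is in the current directory):

#!/usr/bin/env python3
# Minimal sketch: read back the metadata written by add_meta_data().
# Assumes encoder.int8.onnx produced by export_onnx.py is present.
import onnxruntime as ort

sess = ort.InferenceSession(
    "encoder.int8.onnx", providers=["CPUExecutionProvider"]
)
meta = sess.get_modelmeta().custom_metadata_map  # all values are strings
print(meta["normalize_type"])
print(int(meta["pred_rnn_layers"]), int(meta["pred_hidden"]))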
#!/usr/bin/env bash
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
#
# scripts/nemo/parakeet-tdt-0.6b-v2/run.sh

set -ex

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# Test wave for verifying the exported models
curl -SL -O https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav

pip install \
  "nemo_toolkit[asr]" \
  "numpy<2" \
  ipython \
  kaldi-native-fbank \
  librosa \
  onnx==1.17.0 \
  onnxmltools \
  onnxruntime==1.17.1 \
  soundfile

python3 ./export_onnx.py
ls -lh *.onnx

echo "---fp32----"
# Note: the fp32 test also uses the int8 encoder; the fp32 encoder is
# larger than 2 GB (see export_onnx_fp16_large_2gb in export_onnx.py).
python3 ./test_onnx.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.onnx \
  --joiner ./joiner.onnx \
  --tokens ./tokens.txt \
  --wav 2086-149220-0033.wav

echo "---int8----"
python3 ./test_onnx.py \
  --encoder ./encoder.int8.onnx \
  --decoder ./decoder.int8.onnx \
  --joiner ./joiner.int8.onnx \
  --tokens ./tokens.txt \
  --wav 2086-149220-0033.wav

echo "---fp16----"
python3 ./test_onnx.py \
  --encoder ./encoder.fp16.onnx \
  --decoder ./decoder.fp16.onnx \
  --joiner ./joiner.fp16.onnx \
  --tokens ./tokens.txt \
  --wav 2086-149220-0033.wav
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
#
# scripts/nemo/parakeet-tdt-0.6b-v2/test_onnx.py
import argparse
import time
from pathlib import Path

import kaldi_native_fbank as knf
import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder", type=str, required=True, help="Path to encoder.onnx"
    )
    parser.add_argument(
        "--decoder", type=str, required=True, help="Path to decoder.onnx"
    )
    parser.add_argument("--joiner", type=str, required=True, help="Path to joiner.onnx")

    parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")

    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()


def create_fbank():
    # 128-dim log-mel fbank, matching the NeMo preprocessor configuration
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.remove_dc_offset = False
    opts.frame_opts.window_type = "hann"

    opts.mel_opts.low_freq = 0
    opts.mel_opts.num_bins = 128

    opts.mel_opts.is_librosa = True

    fbank = knf.OnlineFbank(opts)
    return fbank


def compute_features(audio, fbank):
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(16000, audio)
    ans = []
    processed = 0
    while processed < fbank.num_frames_ready:
        ans.append(np.array(fbank.get_frame(processed)))
        processed += 1
    ans = np.stack(ans)
    # ans: (num_frames, 128)
    return ans


def display(sess, model):
    print(f"=========={model} Input==========")
    for i in sess.get_inputs():
        print(i)
    print(f"=========={model} Output==========")
    for i in sess.get_outputs():
        print(i)


class OnnxModel:
    def __init__(
        self,
        encoder: str,
        decoder: str,
        joiner: str,
    ):
        self.init_encoder(encoder)
        display(self.encoder, "encoder")
        self.init_decoder(decoder)
        display(self.decoder, "decoder")
        self.init_joiner(joiner)
        display(self.joiner, "joiner")

    def init_encoder(self, encoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.encoder = ort.InferenceSession(
            encoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map
        self.normalize_type = meta["normalize_type"]
        print(meta)

        self.pred_rnn_layers = int(meta["pred_rnn_layers"])
        self.pred_hidden = int(meta["pred_hidden"])

    def init_decoder(self, decoder):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.decoder = ort.InferenceSession(
            decoder,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_joiner(self, joiner):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.joiner = ort.InferenceSession(
            joiner,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def get_decoder_state(self):
        # Initial states of the prediction network (LSTM), all zeros
        batch_size = 1
        state0 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy()
        state1 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy()
        return state0, state1

    def run_encoder(self, x: np.ndarray):
        # x: (T, C)
        x = torch.from_numpy(x)
        x = x.t().unsqueeze(0)
        # x: (1, C, T)
        x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64)

        (encoder_out, out_len) = self.encoder.run(
            [
                self.encoder.get_outputs()[0].name,
                self.encoder.get_outputs()[1].name,
            ],
            {
                self.encoder.get_inputs()[0].name: x.numpy(),
                self.encoder.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        # encoder_out: (batch_size, dim, T)
        return encoder_out

    def run_decoder(
        self,
        token: int,
        state0: np.ndarray,
        state1: np.ndarray,
    ):
        target = torch.tensor([[token]], dtype=torch.int32).numpy()
        target_len = torch.tensor([1], dtype=torch.int32).numpy()

        (
            decoder_out,
            decoder_out_length,
            state0_next,
            state1_next,
        ) = self.decoder.run(
            [
                self.decoder.get_outputs()[0].name,
                self.decoder.get_outputs()[1].name,
                self.decoder.get_outputs()[2].name,
                self.decoder.get_outputs()[3].name,
            ],
            {
                self.decoder.get_inputs()[0].name: target,
                self.decoder.get_inputs()[1].name: target_len,
                self.decoder.get_inputs()[2].name: state0,
                self.decoder.get_inputs()[3].name: state1,
            },
        )
        return decoder_out, state0_next, state1_next

    def run_joiner(
        self,
        encoder_out: np.ndarray,
        decoder_out: np.ndarray,
    ):
        # encoder_out: (batch_size, dim, 1)
        # decoder_out: (batch_size, dim, 1)
        logit = self.joiner.run(
            [
                self.joiner.get_outputs()[0].name,
            ],
            {
                self.joiner.get_inputs()[0].name: encoder_out,
                self.joiner.get_inputs()[1].name: decoder_out,
            },
        )[0]
        # logit: (batch_size, 1, 1, vocab_size)
        return logit


def main():
    args = get_args()
    assert Path(args.encoder).is_file(), args.encoder
    assert Path(args.decoder).is_file(), args.decoder
    assert Path(args.joiner).is_file(), args.joiner
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    model = OnnxModel(args.encoder, args.decoder, args.joiner)

    id2token = dict()
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    start = time.time()
    fbank = create_fbank()
    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    # Append 2 seconds of silence so the tail of the utterance is not cut off
    tail_padding = np.zeros(sample_rate * 2)

    audio = np.concatenate([audio, tail_padding])

    blank = len(id2token) - 1  # <blk> is the last entry in tokens.txt
    ans = [blank]
    state0, state1 = model.get_decoder_state()
    decoder_out, state0_next, state1_next = model.run_decoder(ans[-1], state0, state1)

    features = compute_features(audio, fbank)
    if model.normalize_type != "":
        assert model.normalize_type == "per_feature", model.normalize_type
        features = torch.from_numpy(features)
        mean = features.mean(dim=1, keepdims=True)
        stddev = features.std(dim=1, keepdims=True) + 1e-5
        features = (features - mean) / stddev
        features = features.numpy()
    print(audio.shape)
    print("features.shape", features.shape)

    # Greedy search over the encoder output frames
    encoder_out = model.run_encoder(features)
    # encoder_out: (batch_size, dim, T)
    for t in range(encoder_out.shape[2]):
        encoder_out_t = encoder_out[:, :, t : t + 1]
        logits = model.run_joiner(encoder_out_t, decoder_out)
        logits = torch.from_numpy(logits)
        logits = logits.squeeze()
        idx = torch.argmax(logits, dim=-1).item()
        if idx != blank:
            ans.append(idx)
            state0 = state0_next
            state1 = state1_next
            decoder_out, state0_next, state1_next = model.run_decoder(
                ans[-1], state0, state1
            )

    end = time.time()

    elapsed_seconds = end - start
    audio_duration = audio.shape[0] / 16000
    real_time_factor = elapsed_seconds / audio_duration

    ans = ans[1:]  # remove the leading blank
    tokens = [id2token[i] for i in ans]
    underline = "▁"
    # underline = b"\xe2\x96\x81".decode()
    text = "".join(tokens).replace(underline, " ").strip()

    print(ans)
    print(args.wav)
    print(text)
    print(f"RTF: {real_time_factor}")


if __name__ == "__main__":
    main()
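For completeness, here is how the published model can be consumed from sherpa-onnx once the archives are released. This is a minimal sketch assuming the sherpa-onnx Python package (pip install sherpa-onnx); the exact keyword arguments of OfflineRecognizer.from_transducer should be checked against the sherpa-onnx documentation.

#!/usr/bin/env python3
# A minimal usage sketch, assuming the sherpa-onnx Python API.
# File names match the fp32 bundle produced by the workflow above;
# treat the keyword arguments as assumptions and verify them against
# the sherpa-onnx documentation.
import sherpa_onnx
import soundfile as sf

d = "./sherpa-onnx-nemo-parakeet-tdt-0.6b-v2"
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder=f"{d}/encoder.int8.onnx",
    decoder=f"{d}/decoder.onnx",
    joiner=f"{d}/joiner.onnx",
    tokens=f"{d}/tokens.txt",
    model_type="nemo_transducer",
)

audio, sample_rate = sf.read(f"{d}/test_wavs/0.wav", dtype="float32")
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
print(stream.result.text)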