Fangjun Kuang
Committed by GitHub

Export NeMo FastConformer Hybrid Transducer-CTC Large Streaming to ONNX. (#843)

name: export-nemo-fast-conformer-hybrid-transducer-ctc-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-nemo-fast-conformer-hybrid-transducer-ctc-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-fast-conformer-hybrid-transducer-ctc-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export NeMo fast conformer
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install NeMo
        shell: bash
        run: |
          BRANCH='main'
          pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]
          pip install onnxruntime
          pip install kaldi-native-fbank
          pip install soundfile librosa

      - name: Run
        shell: bash
        run: |
          cd scripts/nemo/fast-conformer-hybrid-transducer-ctc
          ./run-ctc.sh

          mv -v sherpa-onnx-nemo* ../../..

      - name: Download test waves
        shell: bash
        run: |
          mkdir test_wavs
          pushd test_wavs
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/0.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/1.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/8k.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt
          popd

          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms

          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms
          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms
          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
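The workflow above installs NeMo from its main branch and runs on macOS with Python 3.10. If reproducing the export locally, a quick sanity check (a sketch, not part of this PR) is to confirm that the packages pulled in by the "Install NeMo" step are importable before invoking ./run-ctc.sh:

# Sketch: verify the packages installed by the workflow's "Install NeMo" step
# are importable in the local environment before running ./run-ctc.sh.
import importlib

for mod in ["nemo.collections.asr", "onnx", "onnxruntime", "kaldi_native_fbank", "soundfile", "librosa"]:
    importlib.import_module(mod)
    print("ok:", mod)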
# Introduction

This folder contains scripts for exporting models from

 - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_80ms
 - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_480ms
 - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_1040ms

to `sherpa-onnx`.
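The export script below (export-onnx-ctc.py, invoked by run-ctc.sh further down) downloads one of the three checkpoints, writes tokens.txt, exports the cache-aware streaming encoder with its CTC head to model.onnx, and records the streaming configuration as ONNX metadata. A minimal sketch for inspecting the result after the export has been run; it uses only onnxruntime calls that test-onnx-ctc.py also relies on:

# Sketch: inspect an exported model.onnx (assumes export-onnx-ctc.py has been run).
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Key/value pairs written by add_meta_data() in the export script:
# window_size, chunk_shift, cache dimensions, vocab_size, etc.
print(sess.get_modelmeta().custom_metadata_map)

# The streaming model takes the features plus three cache tensors and
# returns log-probs plus updated caches.
print([i.name for i in sess.get_inputs()])
print([o.name for o in sess.get_outputs()])

The full export script follows.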
#!/usr/bin/env python3
import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr
import onnx
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=["80", "480", "1040"],
    )
    return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


@torch.no_grad()
def main():
    args = get_args()
    model_name = f"stt_en_fastconformer_hybrid_large_streaming_{args.model}ms"

    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)

    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(asr_model.joint.vocabulary):
            f.write(f"{s} {i}\n")
        f.write(f"<blk> {i+1}\n")
        print("Saved to tokens.txt")

    decoder_type = "ctc"
    asr_model.change_decoding_strategy(decoder_type=decoder_type)
    asr_model.eval()

    assert asr_model.encoder.streaming_cfg is not None
    if isinstance(asr_model.encoder.streaming_cfg.chunk_size, list):
        chunk_size = asr_model.encoder.streaming_cfg.chunk_size[1]
    else:
        chunk_size = asr_model.encoder.streaming_cfg.chunk_size

    if isinstance(asr_model.encoder.streaming_cfg.pre_encode_cache_size, list):
        pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1]
    else:
        pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size
    window_size = chunk_size + pre_encode_cache_size

    print("chunk_size", chunk_size)
    print("pre_encode_cache_size", pre_encode_cache_size)
    print("window_size", window_size)

    chunk_shift = chunk_size

    # cache_last_channel: (batch_size, dim1, dim2, dim3)
    cache_last_channel_dim1 = len(asr_model.encoder.layers)
    cache_last_channel_dim2 = asr_model.encoder.streaming_cfg.last_channel_cache_size
    cache_last_channel_dim3 = asr_model.encoder.d_model

    # cache_last_time: (batch_size, dim1, dim2, dim3)
    cache_last_time_dim1 = len(asr_model.encoder.layers)
    cache_last_time_dim2 = asr_model.encoder.d_model
    cache_last_time_dim3 = asr_model.encoder.conv_context_size[0]

    asr_model.set_export_config({"decoder_type": "ctc", "cache_support": True})

    filename = "model.onnx"

    asr_model.export(filename)

    meta_data = {
        "vocab_size": asr_model.decoder.vocab_size,
        "window_size": window_size,
        "chunk_shift": chunk_shift,
        "normalize_type": "None",
        "cache_last_channel_dim1": cache_last_channel_dim1,
        "cache_last_channel_dim2": cache_last_channel_dim2,
        "cache_last_channel_dim3": cache_last_channel_dim3,
        "cache_last_time_dim1": cache_last_time_dim1,
        "cache_last_time_dim2": cache_last_time_dim2,
        "cache_last_time_dim3": cache_last_time_dim3,
        "subsampling_factor": 8,
        "model_type": "EncDecHybridRNNTCTCBPEModel",
        "version": "1",
        "model_author": "NeMo",
        "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}",
        "comment": "Only the CTC branch is exported",
    }
    add_meta_data(filename, meta_data)

    print(meta_data)


if __name__ == "__main__":
    main()
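The window_size and chunk_shift values stored in the metadata drive the chunking in test-onnx-ctc.py: each call to the model consumes window_size feature frames (chunk plus pre-encode cache) and the window then advances by chunk_shift frames. A small sketch of that arithmetic with made-up numbers; the real values depend on the model's streaming_cfg and are printed during export. The driver script run-ctc.sh, which the workflow invokes, follows.

# Sketch of the chunking arithmetic used by test-onnx-ctc.py.
# window_size and chunk_shift below are hypothetical; the real values are
# read from the ONNX metadata written by export-onnx-ctc.py.
num_frames = 500   # feature frames in an utterance (hypothetical)
window_size = 70   # frames fed to the model per call (hypothetical)
chunk_shift = 60   # frames the window advances per call (hypothetical)

num_chunks = (num_frames - window_size) // chunk_shift + 1
for i in range(num_chunks):
    start = i * chunk_shift
    print(f"chunk {i}: frames [{start}, {start + window_size})")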
#!/usr/bin/env bash

set -ex

if [ ! -e ./0.wav ]; then
  # curl -SL -O https://hf-mirror.com/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
  curl -SL -O https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
fi

ms=(
80
480
1040
)

for m in ${ms[@]}; do
  ./export-onnx-ctc.py --model $m
  d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms
  if [ ! -f $d/model.onnx ]; then
    mkdir -p $d
    mv -v model.onnx $d/
    mv -v tokens.txt $d/
    ls -lh $d
  fi
done

# Now test the exported models

for m in ${ms[@]}; do
  d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms
  python3 ./test-onnx-ctc.py \
    --model $d/model.onnx \
    --tokens $d/tokens.txt \
    --wav ./0.wav
done
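The test script below (test-onnx-ctc.py) computes 80-dimensional log-mel features with kaldi-native-fbank (no dither, no DC-offset removal, Hann window, librosa-style mel bank), presumably to mirror NeMo's feature extraction, and it expects 16 kHz mono audio, resampling with librosa otherwise. A minimal sketch that exercises the same feature configuration on one second of synthetic 16 kHz audio:

# Sketch: run the same fbank configuration on one second of synthetic 16 kHz
# audio and print the resulting feature shape (num_frames x 80).
import kaldi_native_fbank as knf
import numpy as np

opts = knf.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.remove_dc_offset = False
opts.frame_opts.window_type = "hann"
opts.mel_opts.low_freq = 0
opts.mel_opts.num_bins = 80
opts.mel_opts.is_librosa = True

fbank = knf.OnlineFbank(opts)

audio = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000).astype(np.float32)
fbank.accept_waveform(16000, audio)

frames = np.stack([np.array(fbank.get_frame(i)) for i in range(fbank.num_frames_ready)])
print(frames.shape)  # roughly (98, 80) with the default 25 ms / 10 ms framing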
#!/usr/bin/env python3

import argparse
from pathlib import Path

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import torch
import soundfile as sf
import librosa


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="Path to model.onnx")

    parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")

    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")

    return parser.parse_args()


def create_fbank():
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.remove_dc_offset = False
    opts.frame_opts.window_type = "hann"

    opts.mel_opts.low_freq = 0
    opts.mel_opts.num_bins = 80

    opts.mel_opts.is_librosa = True

    fbank = knf.OnlineFbank(opts)
    return fbank


def compute_features(audio, fbank):
    assert len(audio.shape) == 1, audio.shape
    fbank.accept_waveform(16000, audio)
    ans = []
    processed = 0
    while processed < fbank.num_frames_ready:
        ans.append(np.array(fbank.get_frame(processed)))
        processed += 1
    ans = np.stack(ans)
    return ans


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        print(meta)

        self.window_size = int(meta["window_size"])
        self.chunk_shift = int(meta["chunk_shift"])

        self.cache_last_channel_dim1 = int(meta["cache_last_channel_dim1"])
        self.cache_last_channel_dim2 = int(meta["cache_last_channel_dim2"])
        self.cache_last_channel_dim3 = int(meta["cache_last_channel_dim3"])

        self.cache_last_time_dim1 = int(meta["cache_last_time_dim1"])
        self.cache_last_time_dim2 = int(meta["cache_last_time_dim2"])
        self.cache_last_time_dim3 = int(meta["cache_last_time_dim3"])

        self.init_cache_state()

    def init_cache_state(self):
        self.cache_last_channel = torch.zeros(
            1,
            self.cache_last_channel_dim1,
            self.cache_last_channel_dim2,
            self.cache_last_channel_dim3,
            dtype=torch.float32,
        ).numpy()

        self.cache_last_time = torch.zeros(
            1,
            self.cache_last_time_dim1,
            self.cache_last_time_dim2,
            self.cache_last_time_dim3,
            dtype=torch.float32,
        ).numpy()

        self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy()

    def __call__(self, x: np.ndarray):
        # x: (T, C)
        x = torch.from_numpy(x)
        x = x.t().unsqueeze(0)
        # x: [1, C, T]
        x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64)

        (
            log_probs,
            log_probs_len,
            cache_last_channel_next,
            cache_last_time_next,
            cache_last_channel_len_next,
        ) = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
                self.model.get_outputs()[2].name,
                self.model.get_outputs()[3].name,
                self.model.get_outputs()[4].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
                self.model.get_inputs()[2].name: self.cache_last_channel,
                self.model.get_inputs()[3].name: self.cache_last_time,
                self.model.get_inputs()[4].name: self.cache_last_channel_len,
            },
        )
        self.cache_last_channel = cache_last_channel_next
        self.cache_last_time = cache_last_time_next
        self.cache_last_channel_len = cache_last_channel_len_next

        # [T, vocab_size]
        return torch.from_numpy(log_probs).squeeze(0)


def main():
    args = get_args()
    assert Path(args.model).is_file(), args.model
    assert Path(args.tokens).is_file(), args.tokens
    assert Path(args.wav).is_file(), args.wav

    print(vars(args))

    model = OnnxModel(args.model)

    id2token = dict()
    with open(args.tokens, encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    fbank = create_fbank()
    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000

    window_size = model.window_size
    chunk_shift = model.chunk_shift

    blank = len(id2token) - 1
    prev = -1
    ans = []

    features = compute_features(audio, fbank)
    num_chunks = (features.shape[0] - window_size) // chunk_shift + 1
    for i in range(num_chunks):
        start = i * chunk_shift
        end = start + window_size
        chunk = features[start:end, :]

        log_probs = model(chunk)
        ids = torch.argmax(log_probs, dim=1).tolist()
        for i in ids:
            if i != blank and i != prev:
                ans.append(i)
            prev = i

    tokens = [id2token[i] for i in ans]
    underline = "▁"
    # underline = b"\xe2\x96\x81".decode()
    text = "".join(tokens).replace(underline, " ").strip()
    print(args.wav)
    print(text)


main()
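The decoding above is plain greedy CTC: per-frame argmax, collapse consecutive repeats, drop the blank (appended as the last token by the export script, hence blank = len(id2token) - 1). Note that prev is deliberately not reset between chunks, so repeats across chunk boundaries are also collapsed. A self-contained toy example of the same rule, with a made-up log-prob matrix:

# Toy illustration of the greedy CTC rule used above: argmax per frame,
# collapse repeats, drop blanks. The probabilities here are made up.
import numpy as np

log_probs = np.log(np.array([
    [0.1, 0.8, 0.1],   # frame 0 -> token 1
    [0.1, 0.8, 0.1],   # frame 1 -> token 1 (repeat, collapsed)
    [0.1, 0.1, 0.8],   # frame 2 -> token 2 (blank, dropped)
    [0.7, 0.2, 0.1],   # frame 3 -> token 0
]))
blank = 2

prev = -1
ans = []
for i in np.argmax(log_probs, axis=1).tolist():
    if i != blank and i != prev:
        ans.append(i)
    prev = i

print(ans)  # [1, 0]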