Committed by
GitHub
Export NeMo FastConformer Hybrid Transducer-CTC Large Streaming to ONNX. (#843)
正在显示
5 个修改的文件
包含
431 行增加
和
0 行删除
| 1 | +name: export-nemo-speaker-verification-to-onnx | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + workflow_dispatch: | ||
| 5 | + | ||
| 6 | +concurrency: | ||
| 7 | + group: export-nemo-fast-conformer-hybrid-transducer-ctc-to-onnx-${{ github.ref }} | ||
| 8 | + cancel-in-progress: true | ||
| 9 | + | ||
| 10 | +jobs: | ||
| 11 | + export-nemo-fast-conformer-hybrid-transducer-ctc-to-onnx: | ||
| 12 | + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
| 13 | + name: export NeMo fast conformer | ||
| 14 | + runs-on: ${{ matrix.os }} | ||
| 15 | + strategy: | ||
| 16 | + fail-fast: false | ||
| 17 | + matrix: | ||
| 18 | + os: [macos-latest] | ||
| 19 | + python-version: ["3.10"] | ||
| 20 | + | ||
| 21 | + steps: | ||
| 22 | + - uses: actions/checkout@v4 | ||
| 23 | + | ||
| 24 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 25 | + uses: actions/setup-python@v5 | ||
| 26 | + with: | ||
| 27 | + python-version: ${{ matrix.python-version }} | ||
| 28 | + | ||
| 29 | + - name: Install NeMo | ||
| 30 | + shell: bash | ||
| 31 | + run: | | ||
| 32 | + BRANCH='main' | ||
| 33 | + pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr] | ||
| 34 | + pip install onnxruntime | ||
| 35 | + pip install kaldi-native-fbank | ||
| 36 | + pip install soundfile librosa | ||
| 37 | + | ||
| 38 | + - name: Run | ||
| 39 | + shell: bash | ||
| 40 | + run: | | ||
| 41 | + cd scripts/nemo/fast-conformer-hybrid-transducer-ctc | ||
| 42 | + ./run-ctc.sh | ||
| 43 | + | ||
| 44 | + mv -v sherpa-onnx-nemo* ../../.. | ||
| 45 | + | ||
| 46 | + - name: Download test waves | ||
| 47 | + shell: bash | ||
| 48 | + run: | | ||
| 49 | + mkdir test_wavs | ||
| 50 | + pushd test_wavs | ||
| 51 | + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/0.wav | ||
| 52 | + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/1.wav | ||
| 53 | + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/8k.wav | ||
| 54 | + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt | ||
| 55 | + popd | ||
| 56 | + | ||
| 57 | + cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms | ||
| 58 | + cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms | ||
| 59 | + cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms | ||
| 60 | + | ||
| 61 | + tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms | ||
| 62 | + tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms | ||
| 63 | + tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms | ||
| 64 | + | ||
| 65 | + - name: Release | ||
| 66 | + uses: svenstaro/upload-release-action@v2 | ||
| 67 | + with: | ||
| 68 | + file_glob: true | ||
| 69 | + file: ./*.tar.bz2 | ||
| 70 | + overwrite: true | ||
| 71 | + repo_name: k2-fsa/sherpa-onnx | ||
| 72 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 73 | + tag: asr-models |
| 1 | +# Introduction | ||
| 2 | + | ||
| 3 | +This folder contains scripts for exporting models from | ||
| 4 | + | ||
| 5 | + - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_80ms | ||
| 6 | + - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_480ms | ||
| 7 | + - https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_1040ms | ||
| 8 | + | ||
| 9 | +to `sherpa-onnx`. |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +import argparse | ||
| 3 | +from typing import Dict | ||
| 4 | + | ||
| 5 | +import nemo.collections.asr as nemo_asr | ||
| 6 | +import onnx | ||
| 7 | +import torch | ||
| 8 | + | ||
| 9 | + | ||
| 10 | +def get_args(): | ||
| 11 | + parser = argparse.ArgumentParser() | ||
| 12 | + parser.add_argument( | ||
| 13 | + "--model", | ||
| 14 | + type=str, | ||
| 15 | + required=True, | ||
| 16 | + choices=["80", "480", "1040"], | ||
| 17 | + ) | ||
| 18 | + return parser.parse_args() | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +def add_meta_data(filename: str, meta_data: Dict[str, str]): | ||
| 22 | + """Add meta data to an ONNX model. It is changed in-place. | ||
| 23 | + | ||
| 24 | + Args: | ||
| 25 | + filename: | ||
| 26 | + Filename of the ONNX model to be changed. | ||
| 27 | + meta_data: | ||
| 28 | + Key-value pairs. | ||
| 29 | + """ | ||
| 30 | + model = onnx.load(filename) | ||
| 31 | + while len(model.metadata_props): | ||
| 32 | + model.metadata_props.pop() | ||
| 33 | + | ||
| 34 | + for key, value in meta_data.items(): | ||
| 35 | + meta = model.metadata_props.add() | ||
| 36 | + meta.key = key | ||
| 37 | + meta.value = str(value) | ||
| 38 | + | ||
| 39 | + onnx.save(model, filename) | ||
| 40 | + | ||
| 41 | + | ||
| 42 | +@torch.no_grad() | ||
| 43 | +def main(): | ||
| 44 | + args = get_args() | ||
| 45 | + model_name = f"stt_en_fastconformer_hybrid_large_streaming_{args.model}ms" | ||
| 46 | + | ||
| 47 | + asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name) | ||
| 48 | + | ||
| 49 | + with open("./tokens.txt", "w", encoding="utf-8") as f: | ||
| 50 | + for i, s in enumerate(asr_model.joint.vocabulary): | ||
| 51 | + f.write(f"{s} {i}\n") | ||
| 52 | + f.write(f"<blk> {i+1}\n") | ||
| 53 | + print("Saved to tokens.txt") | ||
| 54 | + | ||
| 55 | + decoder_type = "ctc" | ||
| 56 | + asr_model.change_decoding_strategy(decoder_type=decoder_type) | ||
| 57 | + asr_model.eval() | ||
| 58 | + | ||
| 59 | + assert asr_model.encoder.streaming_cfg is not None | ||
| 60 | + if isinstance(asr_model.encoder.streaming_cfg.chunk_size, list): | ||
| 61 | + chunk_size = asr_model.encoder.streaming_cfg.chunk_size[1] | ||
| 62 | + else: | ||
| 63 | + chunk_size = asr_model.encoder.streaming_cfg.chunk_size | ||
| 64 | + | ||
| 65 | + if isinstance(asr_model.encoder.streaming_cfg.pre_encode_cache_size, list): | ||
| 66 | + pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1] | ||
| 67 | + else: | ||
| 68 | + pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size | ||
| 69 | + window_size = chunk_size + pre_encode_cache_size | ||
| 70 | + | ||
| 71 | + print("chunk_size", chunk_size) | ||
| 72 | + print("pre_encode_cache_size", pre_encode_cache_size) | ||
| 73 | + print("window_size", window_size) | ||
| 74 | + | ||
| 75 | + chunk_shift = chunk_size | ||
| 76 | + | ||
| 77 | + # cache_last_channel: (batch_size, dim1, dim2, dim3) | ||
| 78 | + cache_last_channel_dim1 = len(asr_model.encoder.layers) | ||
| 79 | + cache_last_channel_dim2 = asr_model.encoder.streaming_cfg.last_channel_cache_size | ||
| 80 | + cache_last_channel_dim3 = asr_model.encoder.d_model | ||
| 81 | + | ||
| 82 | + # cache_last_time: (batch_size, dim1, dim2, dim3) | ||
| 83 | + cache_last_time_dim1 = len(asr_model.encoder.layers) | ||
| 84 | + cache_last_time_dim2 = asr_model.encoder.d_model | ||
| 85 | + cache_last_time_dim3 = asr_model.encoder.conv_context_size[0] | ||
| 86 | + | ||
| 87 | + asr_model.set_export_config({"decoder_type": "ctc", "cache_support": True}) | ||
| 88 | + | ||
| 89 | + filename = "model.onnx" | ||
| 90 | + | ||
| 91 | + asr_model.export(filename) | ||
| 92 | + | ||
| 93 | + meta_data = { | ||
| 94 | + "vocab_size": asr_model.decoder.vocab_size, | ||
| 95 | + "window_size": window_size, | ||
| 96 | + "chunk_shift": chunk_shift, | ||
| 97 | + "normalize_type": "None", | ||
| 98 | + "cache_last_channel_dim1": cache_last_channel_dim1, | ||
| 99 | + "cache_last_channel_dim2": cache_last_channel_dim2, | ||
| 100 | + "cache_last_channel_dim3": cache_last_channel_dim3, | ||
| 101 | + "cache_last_time_dim1": cache_last_time_dim1, | ||
| 102 | + "cache_last_time_dim2": cache_last_time_dim2, | ||
| 103 | + "cache_last_time_dim3": cache_last_time_dim3, | ||
| 104 | + "subsampling_factor": 8, | ||
| 105 | + "model_type": "EncDecHybridRNNTCTCBPEModel", | ||
| 106 | + "version": "1", | ||
| 107 | + "model_author": "NeMo", | ||
| 108 | + "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}", | ||
| 109 | + "comment": "Only the CTC branch is exported", | ||
| 110 | + } | ||
| 111 | + add_meta_data(filename, meta_data) | ||
| 112 | + | ||
| 113 | + print(meta_data) | ||
| 114 | + | ||
| 115 | + | ||
| 116 | +if __name__ == "__main__": | ||
| 117 | + main() |
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -ex | ||
| 4 | + | ||
| 5 | +if [ ! -e ./0.wav ]; then | ||
| 6 | + # curl -SL -O https://hf-mirror.com/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav | ||
| 7 | + curl -SL -O https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav | ||
| 8 | +fi | ||
| 9 | + | ||
| 10 | +ms=( | ||
| 11 | +80 | ||
| 12 | +480 | ||
| 13 | +1040 | ||
| 14 | +) | ||
| 15 | + | ||
| 16 | +for m in ${ms[@]}; do | ||
| 17 | + ./export-onnx-ctc.py --model $m | ||
| 18 | + d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms | ||
| 19 | + if [ ! -f $d/model.onnx ]; then | ||
| 20 | + mkdir -p $d | ||
| 21 | + mv -v model.onnx $d/ | ||
| 22 | + mv -v tokens.txt $d/ | ||
| 23 | + ls -lh $d | ||
| 24 | + fi | ||
| 25 | +done | ||
| 26 | + | ||
| 27 | +# Now test the exported models | ||
| 28 | + | ||
| 29 | +for m in ${ms[@]}; do | ||
| 30 | + d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms | ||
| 31 | + python3 ./test-onnx-ctc.py \ | ||
| 32 | + --model $d/model.onnx \ | ||
| 33 | + --tokens $d/tokens.txt \ | ||
| 34 | + --wav ./0.wav | ||
| 35 | +done |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +import argparse | ||
| 4 | +from pathlib import Path | ||
| 5 | + | ||
| 6 | +import kaldi_native_fbank as knf | ||
| 7 | +import numpy as np | ||
| 8 | +import onnxruntime as ort | ||
| 9 | +import torch | ||
| 10 | +import soundfile as sf | ||
| 11 | +import librosa | ||
| 12 | + | ||
| 13 | + | ||
| 14 | +def get_args(): | ||
| 15 | + parser = argparse.ArgumentParser() | ||
| 16 | + parser.add_argument("--model", type=str, required=True, help="Path to model.onnx") | ||
| 17 | + | ||
| 18 | + parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt") | ||
| 19 | + | ||
| 20 | + parser.add_argument("--wav", type=str, required=True, help="Path to test.wav") | ||
| 21 | + | ||
| 22 | + return parser.parse_args() | ||
| 23 | + | ||
| 24 | + | ||
| 25 | +def create_fbank(): | ||
| 26 | + opts = knf.FbankOptions() | ||
| 27 | + opts.frame_opts.dither = 0 | ||
| 28 | + opts.frame_opts.remove_dc_offset = False | ||
| 29 | + opts.frame_opts.window_type = "hann" | ||
| 30 | + | ||
| 31 | + opts.mel_opts.low_freq = 0 | ||
| 32 | + opts.mel_opts.num_bins = 80 | ||
| 33 | + | ||
| 34 | + opts.mel_opts.is_librosa = True | ||
| 35 | + | ||
| 36 | + fbank = knf.OnlineFbank(opts) | ||
| 37 | + return fbank | ||
| 38 | + | ||
| 39 | + | ||
| 40 | +def compute_features(audio, fbank): | ||
| 41 | + assert len(audio.shape) == 1, audio.shape | ||
| 42 | + fbank.accept_waveform(16000, audio) | ||
| 43 | + ans = [] | ||
| 44 | + processed = 0 | ||
| 45 | + while processed < fbank.num_frames_ready: | ||
| 46 | + ans.append(np.array(fbank.get_frame(processed))) | ||
| 47 | + processed += 1 | ||
| 48 | + ans = np.stack(ans) | ||
| 49 | + return ans | ||
| 50 | + | ||
| 51 | + | ||
| 52 | +class OnnxModel: | ||
| 53 | + def __init__( | ||
| 54 | + self, | ||
| 55 | + filename: str, | ||
| 56 | + ): | ||
| 57 | + session_opts = ort.SessionOptions() | ||
| 58 | + session_opts.inter_op_num_threads = 1 | ||
| 59 | + session_opts.intra_op_num_threads = 1 | ||
| 60 | + | ||
| 61 | + self.session_opts = session_opts | ||
| 62 | + | ||
| 63 | + self.model = ort.InferenceSession( | ||
| 64 | + filename, | ||
| 65 | + sess_options=self.session_opts, | ||
| 66 | + providers=["CPUExecutionProvider"], | ||
| 67 | + ) | ||
| 68 | + | ||
| 69 | + meta = self.model.get_modelmeta().custom_metadata_map | ||
| 70 | + print(meta) | ||
| 71 | + | ||
| 72 | + self.window_size = int(meta["window_size"]) | ||
| 73 | + self.chunk_shift = int(meta["chunk_shift"]) | ||
| 74 | + | ||
| 75 | + self.cache_last_channel_dim1 = int(meta["cache_last_channel_dim1"]) | ||
| 76 | + self.cache_last_channel_dim2 = int(meta["cache_last_channel_dim2"]) | ||
| 77 | + self.cache_last_channel_dim3 = int(meta["cache_last_channel_dim3"]) | ||
| 78 | + | ||
| 79 | + self.cache_last_time_dim1 = int(meta["cache_last_time_dim1"]) | ||
| 80 | + self.cache_last_time_dim2 = int(meta["cache_last_time_dim2"]) | ||
| 81 | + self.cache_last_time_dim3 = int(meta["cache_last_time_dim3"]) | ||
| 82 | + | ||
| 83 | + self.init_cache_state() | ||
| 84 | + | ||
| 85 | + def init_cache_state(self): | ||
| 86 | + self.cache_last_channel = torch.zeros( | ||
| 87 | + 1, | ||
| 88 | + self.cache_last_channel_dim1, | ||
| 89 | + self.cache_last_channel_dim2, | ||
| 90 | + self.cache_last_channel_dim3, | ||
| 91 | + dtype=torch.float32, | ||
| 92 | + ).numpy() | ||
| 93 | + | ||
| 94 | + self.cache_last_time = torch.zeros( | ||
| 95 | + 1, | ||
| 96 | + self.cache_last_time_dim1, | ||
| 97 | + self.cache_last_time_dim2, | ||
| 98 | + self.cache_last_time_dim3, | ||
| 99 | + dtype=torch.float32, | ||
| 100 | + ).numpy() | ||
| 101 | + | ||
| 102 | + self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy() | ||
| 103 | + | ||
| 104 | + def __call__(self, x: np.ndarray): | ||
| 105 | + # x: (T, C) | ||
| 106 | + x = torch.from_numpy(x) | ||
| 107 | + x = x.t().unsqueeze(0) | ||
| 108 | + # x: [1, C, T] | ||
| 109 | + x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64) | ||
| 110 | + | ||
| 111 | + ( | ||
| 112 | + log_probs, | ||
| 113 | + log_probs_len, | ||
| 114 | + cache_last_channel_next, | ||
| 115 | + cache_last_time_next, | ||
| 116 | + cache_last_channel_len_next, | ||
| 117 | + ) = self.model.run( | ||
| 118 | + [ | ||
| 119 | + self.model.get_outputs()[0].name, | ||
| 120 | + self.model.get_outputs()[1].name, | ||
| 121 | + self.model.get_outputs()[2].name, | ||
| 122 | + self.model.get_outputs()[3].name, | ||
| 123 | + self.model.get_outputs()[4].name, | ||
| 124 | + ], | ||
| 125 | + { | ||
| 126 | + self.model.get_inputs()[0].name: x.numpy(), | ||
| 127 | + self.model.get_inputs()[1].name: x_lens.numpy(), | ||
| 128 | + self.model.get_inputs()[2].name: self.cache_last_channel, | ||
| 129 | + self.model.get_inputs()[3].name: self.cache_last_time, | ||
| 130 | + self.model.get_inputs()[4].name: self.cache_last_channel_len, | ||
| 131 | + }, | ||
| 132 | + ) | ||
| 133 | + self.cache_last_channel = cache_last_channel_next | ||
| 134 | + self.cache_last_time = cache_last_time_next | ||
| 135 | + self.cache_last_channel_len = cache_last_channel_len_next | ||
| 136 | + | ||
| 137 | + # [T, vocab_size] | ||
| 138 | + return torch.from_numpy(log_probs).squeeze(0) | ||
| 139 | + | ||
| 140 | + | ||
| 141 | +def main(): | ||
| 142 | + args = get_args() | ||
| 143 | + assert Path(args.model).is_file(), args.model | ||
| 144 | + assert Path(args.tokens).is_file(), args.tokens | ||
| 145 | + assert Path(args.wav).is_file(), args.wav | ||
| 146 | + | ||
| 147 | + print(vars(args)) | ||
| 148 | + | ||
| 149 | + model = OnnxModel(args.model) | ||
| 150 | + | ||
| 151 | + id2token = dict() | ||
| 152 | + with open(args.tokens, encoding="utf-8") as f: | ||
| 153 | + for line in f: | ||
| 154 | + t, idx = line.split() | ||
| 155 | + id2token[int(idx)] = t | ||
| 156 | + | ||
| 157 | + fbank = create_fbank() | ||
| 158 | + audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True) | ||
| 159 | + audio = audio[:, 0] # only use the first channel | ||
| 160 | + if sample_rate != 16000: | ||
| 161 | + audio = librosa.resample( | ||
| 162 | + audio, | ||
| 163 | + orig_sr=sample_rate, | ||
| 164 | + target_sr=16000, | ||
| 165 | + ) | ||
| 166 | + sample_rate = 16000 | ||
| 167 | + | ||
| 168 | + window_size = model.window_size | ||
| 169 | + chunk_shift = model.chunk_shift | ||
| 170 | + | ||
| 171 | + blank = len(id2token) - 1 | ||
| 172 | + prev = -1 | ||
| 173 | + ans = [] | ||
| 174 | + | ||
| 175 | + features = compute_features(audio, fbank) | ||
| 176 | + num_chunks = (features.shape[0] - window_size) // chunk_shift + 1 | ||
| 177 | + for i in range(num_chunks): | ||
| 178 | + start = i * chunk_shift | ||
| 179 | + end = start + window_size | ||
| 180 | + chunk = features[start:end, :] | ||
| 181 | + | ||
| 182 | + log_probs = model(chunk) | ||
| 183 | + ids = torch.argmax(log_probs, dim=1).tolist() | ||
| 184 | + for i in ids: | ||
| 185 | + if i != blank and i != prev: | ||
| 186 | + ans.append(i) | ||
| 187 | + prev = i | ||
| 188 | + | ||
| 189 | + tokens = [id2token[i] for i in ans] | ||
| 190 | + underline = "▁" | ||
| 191 | + # underline = b"\xe2\x96\x81".decode() | ||
| 192 | + text = "".join(tokens).replace(underline, " ").strip() | ||
| 193 | + print(args.wav) | ||
| 194 | + print(text) | ||
| 195 | + | ||
| 196 | + | ||
| 197 | +main() |
-
请 注册 或 登录 后发表评论