Committed by
GitHub
Export NeMo FastConformer Hybrid Transducer Large Streaming to ONNX (#844)
正在显示
9 个修改的文件
包含
611 行增加
和
1 行删除
| 1 | +name: export-nemo-fast-conformer-transducer-to-onnx | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + workflow_dispatch: | ||
| 5 | + | ||
| 6 | +concurrency: | ||
| 7 | + group: export-nemo-fast-conformer-hybrid-transducer-to-onnx-${{ github.ref }} | ||
| 8 | + cancel-in-progress: true | ||
| 9 | + | ||
| 10 | +jobs: | ||
| 11 | + export-nemo-fast-conformer-hybrid-transducer-to-onnx: | ||
| 12 | + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
| 13 | + name: NeMo transducer | ||
| 14 | + runs-on: ${{ matrix.os }} | ||
| 15 | + strategy: | ||
| 16 | + fail-fast: false | ||
| 17 | + matrix: | ||
| 18 | + os: [macos-latest] | ||
| 19 | + python-version: ["3.10"] | ||
| 20 | + | ||
| 21 | + steps: | ||
| 22 | + - uses: actions/checkout@v4 | ||
| 23 | + | ||
| 24 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 25 | + uses: actions/setup-python@v5 | ||
| 26 | + with: | ||
| 27 | + python-version: ${{ matrix.python-version }} | ||
| 28 | + | ||
| 29 | + - name: Install NeMo | ||
| 30 | + shell: bash | ||
| 31 | + run: | | ||
| 32 | + BRANCH='main' | ||
| 33 | + pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr] | ||
| 34 | + pip install onnxruntime | ||
| 35 | + pip install kaldi-native-fbank | ||
| 36 | + pip install soundfile librosa | ||
| 37 | + | ||
| 38 | + - name: Run | ||
| 39 | + shell: bash | ||
| 40 | + run: | | ||
| 41 | + cd scripts/nemo/fast-conformer-hybrid-transducer-ctc | ||
| 42 | + ./run-transducer.sh | ||
| 43 | + | ||
| 44 | + mv -v sherpa-onnx-nemo* ../../.. | ||
| 45 | + | ||
| 46 | + - name: Download test waves | ||
| 47 | + shell: bash | ||
| 48 | + run: | | ||
| 49 | + mkdir test_wavs | ||
| 50 | + pushd test_wavs | ||
| 51 | + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/0.wav | ||
| 52 | + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/1.wav | ||
| 53 | + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/8k.wav | ||
| 54 | + curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt | ||
| 55 | + popd | ||
| 56 | + | ||
| 57 | + cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms | ||
| 58 | + cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms | ||
| 59 | + cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms | ||
| 60 | + | ||
| 61 | + tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-80ms | ||
| 62 | + tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-480ms | ||
| 63 | + tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-transducer-1040ms | ||
| 64 | + | ||
| 65 | + - name: Release | ||
| 66 | + uses: svenstaro/upload-release-action@v2 | ||
| 67 | + with: | ||
| 68 | + file_glob: true | ||
| 69 | + file: ./*.tar.bz2 | ||
| 70 | + overwrite: true | ||
| 71 | + repo_name: k2-fsa/sherpa-onnx | ||
| 72 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 73 | + tag: asr-models |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | +import argparse | ||
| 4 | +from typing import Dict | ||
| 5 | + | ||
| 6 | +import nemo.collections.asr as nemo_asr | ||
| 7 | +import onnx | ||
| 8 | +import torch | ||
| 9 | + | ||
| 10 | + | ||
| 11 | +def get_args(): | ||
| 12 | + parser = argparse.ArgumentParser() | ||
| 13 | + parser.add_argument( | ||
| 14 | + "--model", | ||
| 15 | + type=str, | ||
| 16 | + required=True, | ||
| 17 | + choices=["80", "480", "1040"], | ||
| 18 | + ) | ||
| 19 | + return parser.parse_args() | ||
| 20 | + | ||
| 21 | + | ||
| 22 | +def add_meta_data(filename: str, meta_data: Dict[str, str]): | ||
| 23 | + """Add meta data to an ONNX model. It is changed in-place. | ||
| 24 | + | ||
| 25 | + Args: | ||
| 26 | + filename: | ||
| 27 | + Filename of the ONNX model to be changed. | ||
| 28 | + meta_data: | ||
| 29 | + Key-value pairs. | ||
| 30 | + """ | ||
| 31 | + model = onnx.load(filename) | ||
| 32 | + while len(model.metadata_props): | ||
| 33 | + model.metadata_props.pop() | ||
| 34 | + | ||
| 35 | + for key, value in meta_data.items(): | ||
| 36 | + meta = model.metadata_props.add() | ||
| 37 | + meta.key = key | ||
| 38 | + meta.value = str(value) | ||
| 39 | + | ||
| 40 | + onnx.save(model, filename) | ||
| 41 | + | ||
| 42 | + | ||
| 43 | +@torch.no_grad() | ||
| 44 | +def main(): | ||
| 45 | + args = get_args() | ||
| 46 | + model_name = f"stt_en_fastconformer_hybrid_large_streaming_{args.model}ms" | ||
| 47 | + | ||
| 48 | + asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name) | ||
| 49 | + | ||
| 50 | + with open("./tokens.txt", "w", encoding="utf-8") as f: | ||
| 51 | + for i, s in enumerate(asr_model.joint.vocabulary): | ||
| 52 | + f.write(f"{s} {i}\n") | ||
| 53 | + f.write(f"<blk> {i+1}\n") | ||
| 54 | + print("Saved to tokens.txt") | ||
| 55 | + | ||
| 56 | + decoder_type = "rnnt" | ||
| 57 | + asr_model.change_decoding_strategy(decoder_type=decoder_type) | ||
| 58 | + asr_model.eval() | ||
| 59 | + | ||
| 60 | + assert asr_model.encoder.streaming_cfg is not None | ||
| 61 | + if isinstance(asr_model.encoder.streaming_cfg.chunk_size, list): | ||
| 62 | + chunk_size = asr_model.encoder.streaming_cfg.chunk_size[1] | ||
| 63 | + else: | ||
| 64 | + chunk_size = asr_model.encoder.streaming_cfg.chunk_size | ||
| 65 | + | ||
| 66 | + if isinstance(asr_model.encoder.streaming_cfg.pre_encode_cache_size, list): | ||
| 67 | + pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size[1] | ||
| 68 | + else: | ||
| 69 | + pre_encode_cache_size = asr_model.encoder.streaming_cfg.pre_encode_cache_size | ||
| 70 | + window_size = chunk_size + pre_encode_cache_size | ||
| 71 | + | ||
| 72 | + print("chunk_size", chunk_size) | ||
| 73 | + print("pre_encode_cache_size", pre_encode_cache_size) | ||
| 74 | + print("window_size", window_size) | ||
| 75 | + | ||
| 76 | + chunk_shift = chunk_size | ||
| 77 | + | ||
| 78 | + # cache_last_channel: (batch_size, dim1, dim2, dim3) | ||
| 79 | + cache_last_channel_dim1 = len(asr_model.encoder.layers) | ||
| 80 | + cache_last_channel_dim2 = asr_model.encoder.streaming_cfg.last_channel_cache_size | ||
| 81 | + cache_last_channel_dim3 = asr_model.encoder.d_model | ||
| 82 | + | ||
| 83 | + # cache_last_time: (batch_size, dim1, dim2, dim3) | ||
| 84 | + cache_last_time_dim1 = len(asr_model.encoder.layers) | ||
| 85 | + cache_last_time_dim2 = asr_model.encoder.d_model | ||
| 86 | + cache_last_time_dim3 = asr_model.encoder.conv_context_size[0] | ||
| 87 | + | ||
| 88 | + asr_model.set_export_config({"decoder_type": "rnnt", "cache_support": True}) | ||
| 89 | + | ||
| 90 | + # asr_model.export("model.onnx") | ||
| 91 | + asr_model.encoder.export("encoder.onnx") | ||
| 92 | + asr_model.decoder.export("decoder.onnx") | ||
| 93 | + asr_model.joint.export("joiner.onnx") | ||
| 94 | + # model.onnx is a suffix. | ||
| 95 | + # It will generate two files: | ||
| 96 | + # encoder-model.onnx | ||
| 97 | + # decoder_joint-model.onnx | ||
| 98 | + | ||
| 99 | + meta_data = { | ||
| 100 | + "vocab_size": asr_model.decoder.vocab_size, | ||
| 101 | + "window_size": window_size, | ||
| 102 | + "chunk_shift": chunk_shift, | ||
| 103 | + "normalize_type": "None", | ||
| 104 | + "cache_last_channel_dim1": cache_last_channel_dim1, | ||
| 105 | + "cache_last_channel_dim2": cache_last_channel_dim2, | ||
| 106 | + "cache_last_channel_dim3": cache_last_channel_dim3, | ||
| 107 | + "cache_last_time_dim1": cache_last_time_dim1, | ||
| 108 | + "cache_last_time_dim2": cache_last_time_dim2, | ||
| 109 | + "cache_last_time_dim3": cache_last_time_dim3, | ||
| 110 | + "pred_rnn_layers": asr_model.decoder.pred_rnn_layers, | ||
| 111 | + "pred_hidden": asr_model.decoder.pred_hidden, | ||
| 112 | + "subsampling_factor": 8, | ||
| 113 | + "model_type": "EncDecHybridRNNTCTCBPEModel", | ||
| 114 | + "version": "1", | ||
| 115 | + "model_author": "NeMo", | ||
| 116 | + "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}", | ||
| 117 | + "comment": "Only the transducer branch is exported", | ||
| 118 | + } | ||
| 119 | + add_meta_data("encoder.onnx", meta_data) | ||
| 120 | + | ||
| 121 | + print(meta_data) | ||
| 122 | + | ||
| 123 | + | ||
| 124 | +if __name__ == "__main__": | ||
| 125 | + main() |
| 1 | +#!/usr/bin/env bash | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +set -ex | ||
| 5 | + | ||
| 6 | +if [ ! -e ./0.wav ]; then | ||
| 7 | + # curl -SL -O https://hf-mirror.com/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav | ||
| 8 | + curl -SL -O https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav | ||
| 9 | +fi | ||
| 10 | + | ||
| 11 | +ms=( | ||
| 12 | +80 | ||
| 13 | +480 | ||
| 14 | +1040 | ||
| 15 | +) | ||
| 16 | + | ||
| 17 | +for m in ${ms[@]}; do | ||
| 18 | + ./export-onnx-transducer.py --model $m | ||
| 19 | + d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms | ||
| 20 | + if [ ! -f $d/encoder.onnx ]; then | ||
| 21 | + mkdir -p $d | ||
| 22 | + mv -v encoder.onnx $d/ | ||
| 23 | + mv -v decoder.onnx $d/ | ||
| 24 | + mv -v joiner.onnx $d/ | ||
| 25 | + mv -v tokens.txt $d/ | ||
| 26 | + ls -lh $d | ||
| 27 | + fi | ||
| 28 | +done | ||
| 29 | + | ||
| 30 | +# Now test the exported models | ||
| 31 | + | ||
| 32 | +for m in ${ms[@]}; do | ||
| 33 | + d=sherpa-onnx-nemo-streaming-fast-conformer-transducer-${m}ms | ||
| 34 | + python3 ./test-onnx-transducer.py \ | ||
| 35 | + --encoder $d/encoder.onnx \ | ||
| 36 | + --decoder $d/decoder.onnx \ | ||
| 37 | + --joiner $d/joiner.onnx \ | ||
| 38 | + --tokens $d/tokens.txt \ | ||
| 39 | + --wav ./0.wav | ||
| 40 | +done |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +import onnxruntime | ||
| 5 | + | ||
| 6 | + | ||
| 7 | +def show(filename): | ||
| 8 | + session_opts = onnxruntime.SessionOptions() | ||
| 9 | + session_opts.log_severity_level = 3 | ||
| 10 | + sess = onnxruntime.InferenceSession(filename, session_opts) | ||
| 11 | + for i in sess.get_inputs(): | ||
| 12 | + print(i) | ||
| 13 | + | ||
| 14 | + print("-----") | ||
| 15 | + | ||
| 16 | + for i in sess.get_outputs(): | ||
| 17 | + print(i) | ||
| 18 | + | ||
| 19 | + | ||
| 20 | +def main(): | ||
| 21 | + print("=========encoder==========") | ||
| 22 | + show("./encoder.onnx") | ||
| 23 | + | ||
| 24 | + print("=========decoder==========") | ||
| 25 | + show("./decoder.onnx") | ||
| 26 | + | ||
| 27 | + print("=========joiner==========") | ||
| 28 | + show("./joiner.onnx") | ||
| 29 | + | ||
| 30 | + | ||
| 31 | +if __name__ == "__main__": | ||
| 32 | + main() | ||
| 33 | + | ||
| 34 | +""" | ||
| 35 | +=========encoder========== | ||
| 36 | +NodeArg(name='audio_signal', type='tensor(float)', shape=['audio_signal_dynamic_axes_1', 80, 'audio_signal_dynamic_axes_2']) | ||
| 37 | +NodeArg(name='length', type='tensor(int64)', shape=['length_dynamic_axes_1']) | ||
| 38 | +NodeArg(name='cache_last_channel', type='tensor(float)', shape=['cache_last_channel_dynamic_axes_1', 17, 'cache_last_channel_dynamic_axes_2', 512]) | ||
| 39 | +NodeArg(name='cache_last_time', type='tensor(float)', shape=['cache_last_time_dynamic_axes_1', 17, 512, 'cache_last_time_dynamic_axes_2']) | ||
| 40 | +NodeArg(name='cache_last_channel_len', type='tensor(int64)', shape=['cache_last_channel_len_dynamic_axes_1']) | ||
| 41 | +----- | ||
| 42 | +NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 512, 'outputs_dynamic_axes_2']) | ||
| 43 | +NodeArg(name='encoded_lengths', type='tensor(int64)', shape=['encoded_lengths_dynamic_axes_1']) | ||
| 44 | +NodeArg(name='cache_last_channel_next', type='tensor(float)', shape=['cache_last_channel_next_dynamic_axes_1', 17, 'cache_last_channel_next_dynamic_axes_2', 512]) | ||
| 45 | +NodeArg(name='cache_last_time_next', type='tensor(float)', shape=['cache_last_time_next_dynamic_axes_1', 17, 512, 'cache_last_time_next_dynamic_axes_2']) | ||
| 46 | +NodeArg(name='cache_last_channel_next_len', type='tensor(int64)', shape=['cache_last_channel_next_len_dynamic_axes_1']) | ||
| 47 | +=========decoder========== | ||
| 48 | +NodeArg(name='targets', type='tensor(int32)', shape=['targets_dynamic_axes_1', 'targets_dynamic_axes_2']) | ||
| 49 | +NodeArg(name='target_length', type='tensor(int32)', shape=['target_length_dynamic_axes_1']) | ||
| 50 | +NodeArg(name='states.1', type='tensor(float)', shape=[1, 'states.1_dim_1', 640]) | ||
| 51 | +NodeArg(name='onnx::LSTM_3', type='tensor(float)', shape=[1, 1, 640]) | ||
| 52 | +----- | ||
| 53 | +NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 640, 'outputs_dynamic_axes_2']) | ||
| 54 | +NodeArg(name='prednet_lengths', type='tensor(int32)', shape=['prednet_lengths_dynamic_axes_1']) | ||
| 55 | +NodeArg(name='states', type='tensor(float)', shape=[1, 'states_dynamic_axes_1', 640]) | ||
| 56 | +NodeArg(name='74', type='tensor(float)', shape=[1, 'LSTM74_dim_1', 640]) | ||
| 57 | +=========joiner========== | ||
| 58 | +NodeArg(name='encoder_outputs', type='tensor(float)', shape=['encoder_outputs_dynamic_axes_1', 512, 'encoder_outputs_dynamic_axes_2']) | ||
| 59 | +NodeArg(name='decoder_outputs', type='tensor(float)', shape=['decoder_outputs_dynamic_axes_1', 640, 'decoder_outputs_dynamic_axes_2']) | ||
| 60 | +----- | ||
| 61 | +NodeArg(name='outputs', type='tensor(float)', shape=['outputs_dynamic_axes_1', 'outputs_dynamic_axes_2', 'outputs_dynamic_axes_3', 1025]) | ||
| 62 | + | ||
| 63 | +""" |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +import argparse | ||
| 5 | +from pathlib import Path | ||
| 6 | + | ||
| 7 | +import kaldi_native_fbank as knf | ||
| 8 | +import librosa | ||
| 9 | +import numpy as np | ||
| 10 | +import onnxruntime as ort | ||
| 11 | +import soundfile as sf | ||
| 12 | +import torch | ||
| 13 | + | ||
| 14 | + | ||
| 15 | +def get_args(): | ||
| 16 | + parser = argparse.ArgumentParser() | ||
| 17 | + parser.add_argument( | ||
| 18 | + "--encoder", type=str, required=True, help="Path to encoder.onnx" | ||
| 19 | + ) | ||
| 20 | + parser.add_argument( | ||
| 21 | + "--decoder", type=str, required=True, help="Path to decoder.onnx" | ||
| 22 | + ) | ||
| 23 | + parser.add_argument("--joiner", type=str, required=True, help="Path to joiner.onnx") | ||
| 24 | + | ||
| 25 | + parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt") | ||
| 26 | + | ||
| 27 | + parser.add_argument("--wav", type=str, required=True, help="Path to test.wav") | ||
| 28 | + | ||
| 29 | + return parser.parse_args() | ||
| 30 | + | ||
| 31 | + | ||
| 32 | +def create_fbank(): | ||
| 33 | + opts = knf.FbankOptions() | ||
| 34 | + opts.frame_opts.dither = 0 | ||
| 35 | + opts.frame_opts.remove_dc_offset = False | ||
| 36 | + opts.frame_opts.window_type = "hann" | ||
| 37 | + | ||
| 38 | + opts.mel_opts.low_freq = 0 | ||
| 39 | + opts.mel_opts.num_bins = 80 | ||
| 40 | + | ||
| 41 | + opts.mel_opts.is_librosa = True | ||
| 42 | + | ||
| 43 | + fbank = knf.OnlineFbank(opts) | ||
| 44 | + return fbank | ||
| 45 | + | ||
| 46 | + | ||
| 47 | +def compute_features(audio, fbank): | ||
| 48 | + assert len(audio.shape) == 1, audio.shape | ||
| 49 | + fbank.accept_waveform(16000, audio) | ||
| 50 | + ans = [] | ||
| 51 | + processed = 0 | ||
| 52 | + while processed < fbank.num_frames_ready: | ||
| 53 | + ans.append(np.array(fbank.get_frame(processed))) | ||
| 54 | + processed += 1 | ||
| 55 | + ans = np.stack(ans) | ||
| 56 | + return ans | ||
| 57 | + | ||
| 58 | + | ||
| 59 | +class OnnxModel: | ||
| 60 | + def __init__( | ||
| 61 | + self, | ||
| 62 | + encoder: str, | ||
| 63 | + decoder: str, | ||
| 64 | + joiner: str, | ||
| 65 | + ): | ||
| 66 | + self.init_encoder(encoder) | ||
| 67 | + self.init_decoder(decoder) | ||
| 68 | + self.init_joiner(joiner) | ||
| 69 | + | ||
| 70 | + def init_encoder(self, encoder): | ||
| 71 | + session_opts = ort.SessionOptions() | ||
| 72 | + session_opts.inter_op_num_threads = 1 | ||
| 73 | + session_opts.intra_op_num_threads = 1 | ||
| 74 | + | ||
| 75 | + self.encoder = ort.InferenceSession( | ||
| 76 | + encoder, | ||
| 77 | + sess_options=session_opts, | ||
| 78 | + providers=["CPUExecutionProvider"], | ||
| 79 | + ) | ||
| 80 | + | ||
| 81 | + meta = self.encoder.get_modelmeta().custom_metadata_map | ||
| 82 | + print(meta) | ||
| 83 | + | ||
| 84 | + self.window_size = int(meta["window_size"]) | ||
| 85 | + self.chunk_shift = int(meta["chunk_shift"]) | ||
| 86 | + | ||
| 87 | + self.cache_last_channel_dim1 = int(meta["cache_last_channel_dim1"]) | ||
| 88 | + self.cache_last_channel_dim2 = int(meta["cache_last_channel_dim2"]) | ||
| 89 | + self.cache_last_channel_dim3 = int(meta["cache_last_channel_dim3"]) | ||
| 90 | + | ||
| 91 | + self.cache_last_time_dim1 = int(meta["cache_last_time_dim1"]) | ||
| 92 | + self.cache_last_time_dim2 = int(meta["cache_last_time_dim2"]) | ||
| 93 | + self.cache_last_time_dim3 = int(meta["cache_last_time_dim3"]) | ||
| 94 | + | ||
| 95 | + self.pred_rnn_layers = int(meta["pred_rnn_layers"]) | ||
| 96 | + self.pred_hidden = int(meta["pred_hidden"]) | ||
| 97 | + | ||
| 98 | + self.init_cache_state() | ||
| 99 | + | ||
| 100 | + def init_decoder(self, decoder): | ||
| 101 | + session_opts = ort.SessionOptions() | ||
| 102 | + session_opts.inter_op_num_threads = 1 | ||
| 103 | + session_opts.intra_op_num_threads = 1 | ||
| 104 | + | ||
| 105 | + self.decoder = ort.InferenceSession( | ||
| 106 | + decoder, | ||
| 107 | + sess_options=session_opts, | ||
| 108 | + providers=["CPUExecutionProvider"], | ||
| 109 | + ) | ||
| 110 | + | ||
| 111 | + def init_joiner(self, joiner): | ||
| 112 | + session_opts = ort.SessionOptions() | ||
| 113 | + session_opts.inter_op_num_threads = 1 | ||
| 114 | + session_opts.intra_op_num_threads = 1 | ||
| 115 | + | ||
| 116 | + self.joiner = ort.InferenceSession( | ||
| 117 | + joiner, | ||
| 118 | + sess_options=session_opts, | ||
| 119 | + providers=["CPUExecutionProvider"], | ||
| 120 | + ) | ||
| 121 | + | ||
| 122 | + def get_decoder_state(self): | ||
| 123 | + batch_size = 1 | ||
| 124 | + state0 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy() | ||
| 125 | + state1 = torch.zeros(self.pred_rnn_layers, batch_size, self.pred_hidden).numpy() | ||
| 126 | + return state0, state1 | ||
| 127 | + | ||
| 128 | + def init_cache_state(self): | ||
| 129 | + self.cache_last_channel = torch.zeros( | ||
| 130 | + 1, | ||
| 131 | + self.cache_last_channel_dim1, | ||
| 132 | + self.cache_last_channel_dim2, | ||
| 133 | + self.cache_last_channel_dim3, | ||
| 134 | + dtype=torch.float32, | ||
| 135 | + ).numpy() | ||
| 136 | + | ||
| 137 | + self.cache_last_time = torch.zeros( | ||
| 138 | + 1, | ||
| 139 | + self.cache_last_time_dim1, | ||
| 140 | + self.cache_last_time_dim2, | ||
| 141 | + self.cache_last_time_dim3, | ||
| 142 | + dtype=torch.float32, | ||
| 143 | + ).numpy() | ||
| 144 | + | ||
| 145 | + self.cache_last_channel_len = torch.ones([1], dtype=torch.int64).numpy() | ||
| 146 | + | ||
| 147 | + def run_encoder(self, x: np.ndarray): | ||
| 148 | + # x: (T, C) | ||
| 149 | + x = torch.from_numpy(x) | ||
| 150 | + x = x.t().unsqueeze(0) | ||
| 151 | + # x: [1, C, T] | ||
| 152 | + x_lens = torch.tensor([x.shape[-1]], dtype=torch.int64) | ||
| 153 | + | ||
| 154 | + ( | ||
| 155 | + encoder_out, | ||
| 156 | + out_len, | ||
| 157 | + cache_last_channel_next, | ||
| 158 | + cache_last_time_next, | ||
| 159 | + cache_last_channel_len_next, | ||
| 160 | + ) = self.encoder.run( | ||
| 161 | + [ | ||
| 162 | + self.encoder.get_outputs()[0].name, | ||
| 163 | + self.encoder.get_outputs()[1].name, | ||
| 164 | + self.encoder.get_outputs()[2].name, | ||
| 165 | + self.encoder.get_outputs()[3].name, | ||
| 166 | + self.encoder.get_outputs()[4].name, | ||
| 167 | + ], | ||
| 168 | + { | ||
| 169 | + self.encoder.get_inputs()[0].name: x.numpy(), | ||
| 170 | + self.encoder.get_inputs()[1].name: x_lens.numpy(), | ||
| 171 | + self.encoder.get_inputs()[2].name: self.cache_last_channel, | ||
| 172 | + self.encoder.get_inputs()[3].name: self.cache_last_time, | ||
| 173 | + self.encoder.get_inputs()[4].name: self.cache_last_channel_len, | ||
| 174 | + }, | ||
| 175 | + ) | ||
| 176 | + self.cache_last_channel = cache_last_channel_next | ||
| 177 | + self.cache_last_time = cache_last_time_next | ||
| 178 | + self.cache_last_channel_len = cache_last_channel_len_next | ||
| 179 | + | ||
| 180 | + # [batch_size, dim, T] | ||
| 181 | + return encoder_out | ||
| 182 | + | ||
| 183 | + def run_decoder( | ||
| 184 | + self, | ||
| 185 | + token: int, | ||
| 186 | + state0: np.ndarray, | ||
| 187 | + state1: np.ndarray, | ||
| 188 | + ): | ||
| 189 | + target = torch.tensor([[token]], dtype=torch.int32).numpy() | ||
| 190 | + target_len = torch.tensor([1], dtype=torch.int32).numpy() | ||
| 191 | + | ||
| 192 | + ( | ||
| 193 | + decoder_out, | ||
| 194 | + decoder_out_length, | ||
| 195 | + state0_next, | ||
| 196 | + state1_next, | ||
| 197 | + ) = self.decoder.run( | ||
| 198 | + [ | ||
| 199 | + self.decoder.get_outputs()[0].name, | ||
| 200 | + self.decoder.get_outputs()[1].name, | ||
| 201 | + self.decoder.get_outputs()[2].name, | ||
| 202 | + self.decoder.get_outputs()[3].name, | ||
| 203 | + ], | ||
| 204 | + { | ||
| 205 | + self.decoder.get_inputs()[0].name: target, | ||
| 206 | + self.decoder.get_inputs()[1].name: target_len, | ||
| 207 | + self.decoder.get_inputs()[2].name: state0, | ||
| 208 | + self.decoder.get_inputs()[3].name: state1, | ||
| 209 | + }, | ||
| 210 | + ) | ||
| 211 | + return decoder_out, state0_next, state1_next | ||
| 212 | + | ||
| 213 | + def run_joiner( | ||
| 214 | + self, | ||
| 215 | + encoder_out: np.ndarray, | ||
| 216 | + decoder_out: np.ndarray, | ||
| 217 | + ): | ||
| 218 | + # encoder_out: [batch_size, dim, 1] | ||
| 219 | + # decoder_out: [batch_size, dim, 1] | ||
| 220 | + logit = self.joiner.run( | ||
| 221 | + [ | ||
| 222 | + self.joiner.get_outputs()[0].name, | ||
| 223 | + ], | ||
| 224 | + { | ||
| 225 | + self.joiner.get_inputs()[0].name: encoder_out, | ||
| 226 | + self.joiner.get_inputs()[1].name: decoder_out, | ||
| 227 | + }, | ||
| 228 | + )[0] | ||
| 229 | + # logit: [batch_size, 1, 1, vocab_size] | ||
| 230 | + return logit | ||
| 231 | + | ||
| 232 | + | ||
| 233 | +def main(): | ||
| 234 | + args = get_args() | ||
| 235 | + assert Path(args.encoder).is_file(), args.encoder | ||
| 236 | + assert Path(args.decoder).is_file(), args.decoder | ||
| 237 | + assert Path(args.joiner).is_file(), args.joiner | ||
| 238 | + assert Path(args.tokens).is_file(), args.tokens | ||
| 239 | + assert Path(args.wav).is_file(), args.wav | ||
| 240 | + | ||
| 241 | + print(vars(args)) | ||
| 242 | + | ||
| 243 | + model = OnnxModel(args.encoder, args.decoder, args.joiner) | ||
| 244 | + | ||
| 245 | + id2token = dict() | ||
| 246 | + with open(args.tokens, encoding="utf-8") as f: | ||
| 247 | + for line in f: | ||
| 248 | + t, idx = line.split() | ||
| 249 | + id2token[int(idx)] = t | ||
| 250 | + | ||
| 251 | + fbank = create_fbank() | ||
| 252 | + audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True) | ||
| 253 | + audio = audio[:, 0] # only use the first channel | ||
| 254 | + if sample_rate != 16000: | ||
| 255 | + audio = librosa.resample( | ||
| 256 | + audio, | ||
| 257 | + orig_sr=sample_rate, | ||
| 258 | + target_sr=16000, | ||
| 259 | + ) | ||
| 260 | + sample_rate = 16000 | ||
| 261 | + | ||
| 262 | + tail_padding = np.zeros(sample_rate * 2) | ||
| 263 | + | ||
| 264 | + audio = np.concatenate([audio, tail_padding]) | ||
| 265 | + | ||
| 266 | + window_size = model.window_size | ||
| 267 | + chunk_shift = model.chunk_shift | ||
| 268 | + | ||
| 269 | + blank = len(id2token) - 1 | ||
| 270 | + ans = [blank] | ||
| 271 | + state0, state1 = model.get_decoder_state() | ||
| 272 | + decoder_out, state0_next, state1_next = model.run_decoder(ans[-1], state0, state1) | ||
| 273 | + | ||
| 274 | + features = compute_features(audio, fbank) | ||
| 275 | + num_chunks = (features.shape[0] - window_size) // chunk_shift + 1 | ||
| 276 | + for i in range(num_chunks): | ||
| 277 | + start = i * chunk_shift | ||
| 278 | + end = start + window_size | ||
| 279 | + chunk = features[start:end, :] | ||
| 280 | + | ||
| 281 | + encoder_out = model.run_encoder(chunk) | ||
| 282 | + # encoder_out:[batch_size, dim, T) | ||
| 283 | + for t in range(encoder_out.shape[2]): | ||
| 284 | + encoder_out_t = encoder_out[:, :, t : t + 1] | ||
| 285 | + logits = model.run_joiner(encoder_out_t, decoder_out) | ||
| 286 | + logits = torch.from_numpy(logits) | ||
| 287 | + logits = logits.squeeze() | ||
| 288 | + idx = torch.argmax(logits, dim=-1).item() | ||
| 289 | + if idx != blank: | ||
| 290 | + ans.append(idx) | ||
| 291 | + state0 = state0_next | ||
| 292 | + state1 = state1_next | ||
| 293 | + decoder_out, state0_next, state1_next = model.run_decoder( | ||
| 294 | + ans[-1], state0, state1 | ||
| 295 | + ) | ||
| 296 | + | ||
| 297 | + ans = ans[1:] # remove the first blank | ||
| 298 | + tokens = [id2token[i] for i in ans] | ||
| 299 | + underline = "▁" | ||
| 300 | + # underline = b"\xe2\x96\x81".decode() | ||
| 301 | + text = "".join(tokens).replace(underline, " ").strip() | ||
| 302 | + print(args.wav) | ||
| 303 | + print(text) | ||
| 304 | + | ||
| 305 | + | ||
| 306 | +main() |
-
请 注册 或 登录 后发表评论