Committed by
GitHub
export telespeech ctc models to sherpa-onnx (#968)
正在显示
6 个修改的文件
包含
342 行增加
和
0 行删除
.github/workflows/export-telespeech-ctc.yaml
0 → 100644
| 1 | +name: export-telespeech-ctc-to-onnx | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + workflow_dispatch: | ||
| 5 | + | ||
| 6 | +concurrency: | ||
| 7 | + group: export-telespeech-ctc-to-onnx-${{ github.ref }} | ||
| 8 | + cancel-in-progress: true | ||
| 9 | + | ||
| 10 | +jobs: | ||
| 11 | + export-telespeech-ctc-to-onnx: | ||
| 12 | + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
| 13 | + name: telespeech | ||
| 14 | + runs-on: ${{ matrix.os }} | ||
| 15 | + strategy: | ||
| 16 | + fail-fast: false | ||
| 17 | + matrix: | ||
| 18 | + os: [ubuntu-latest] | ||
| 19 | + python-version: ["3.10"] | ||
| 20 | + | ||
| 21 | + steps: | ||
| 22 | + - uses: actions/checkout@v4 | ||
| 23 | + | ||
| 24 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 25 | + uses: actions/setup-python@v5 | ||
| 26 | + with: | ||
| 27 | + python-version: ${{ matrix.python-version }} | ||
| 28 | + | ||
| 29 | + - name: Install Python dependencies | ||
| 30 | + shell: bash | ||
| 31 | + run: | | ||
| 32 | + pip install onnx onnxruntime soundfile librosa numpy kaldi-native-fbank | ||
| 33 | + | ||
| 34 | + - name: Run | ||
| 35 | + shell: bash | ||
| 36 | + run: | | ||
| 37 | + cd scripts/tele-speech | ||
| 38 | + ./run.sh | ||
| 39 | + | ||
| 40 | + ./test.py | ||
| 41 | + | ||
| 42 | + - name: Release | ||
| 43 | + uses: svenstaro/upload-release-action@v2 | ||
| 44 | + with: | ||
| 45 | + file_glob: true | ||
| 46 | + file: ./*.tar.bz2 | ||
| 47 | + overwrite: true | ||
| 48 | + repo_name: k2-fsa/sherpa-onnx | ||
| 49 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 50 | + tag: asr-models |
scripts/tele-speech/.gitignore
0 → 100644
| 1 | +*.json |
scripts/tele-speech/README.md
0 → 100644
| 1 | +# Introduction | ||
| 2 | + | ||
| 3 | +This folder contains scripts about adding metadata to | ||
| 4 | +onnx models from | ||
| 5 | +https://hf-mirror.com/lovemefan/telespeech/tree/main | ||
| 6 | + | ||
| 7 | +Please see | ||
| 8 | + | ||
| 9 | + - https://github.com/Tele-AI/TeleSpeech-ASR | ||
| 10 | + - https://github.com/lovemefan/telespeech-asr-python | ||
| 11 | + - [TeleSpeech模型社区许可协议.pdf](https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/TeleSpeech%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) | ||
| 12 | + | ||
| 13 | +for more details. |
scripts/tele-speech/add-metadata.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +import json | ||
| 4 | +from typing import Dict | ||
| 5 | + | ||
| 6 | +import onnx | ||
| 7 | +from onnxruntime.quantization import QuantType, quantize_dynamic | ||
| 8 | + | ||
| 9 | + | ||
| 10 | +def add_meta_data(filename: str, meta_data: Dict[str, str]): | ||
| 11 | + """Add meta data to an ONNX model. It is changed in-place. | ||
| 12 | + | ||
| 13 | + Args: | ||
| 14 | + filename: | ||
| 15 | + Filename of the ONNX model to be changed. | ||
| 16 | + meta_data: | ||
| 17 | + Key-value pairs. | ||
| 18 | + """ | ||
| 19 | + model = onnx.load(filename) | ||
| 20 | + | ||
| 21 | + while len(model.metadata_props): | ||
| 22 | + model.metadata_props.pop() | ||
| 23 | + | ||
| 24 | + for key, value in meta_data.items(): | ||
| 25 | + meta = model.metadata_props.add() | ||
| 26 | + meta.key = key | ||
| 27 | + meta.value = value | ||
| 28 | + | ||
| 29 | + onnx.save(model, filename) | ||
| 30 | + | ||
| 31 | + | ||
| 32 | +def main(): | ||
| 33 | + with open("./vocab.json", "r", encoding="utf-8") as f: | ||
| 34 | + tokens = json.load(f) | ||
| 35 | + | ||
| 36 | + vocab_size = len(tokens) | ||
| 37 | + with open("tokens.txt", "w", encoding="utf-8") as f: | ||
| 38 | + for token, idx in tokens.items(): | ||
| 39 | + if idx == 0: | ||
| 40 | + f.write("<blk> 0\n") | ||
| 41 | + else: | ||
| 42 | + f.write(f"{token} {idx}\n") | ||
| 43 | + | ||
| 44 | + filename = "model.onnx" | ||
| 45 | + meta_data = { | ||
| 46 | + "model_type": "telespeech_ctc", | ||
| 47 | + "version": "1", | ||
| 48 | + "model_author": "Tele-AI", | ||
| 49 | + "comment": "See also https://github.com/lovemefan/telespeech-asr-python", | ||
| 50 | + "license": "https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/TeleSpeech%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf", | ||
| 51 | + "url": "https://github.com/Tele-AI/TeleSpeech-ASR", | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + add_meta_data(filename, meta_data) | ||
| 55 | + | ||
| 56 | + filename_int8 = f"model.int8.onnx" | ||
| 57 | + quantize_dynamic( | ||
| 58 | + model_input=filename, | ||
| 59 | + model_output=filename_int8, | ||
| 60 | + op_types_to_quantize=["MatMul"], | ||
| 61 | + weight_type=QuantType.QInt8, | ||
| 62 | + ) | ||
| 63 | + | ||
| 64 | + # filename_uint8 = f"model.uint8.onnx" | ||
| 65 | + # quantize_dynamic( | ||
| 66 | + # model_input=filename, | ||
| 67 | + # model_output=filename_uint8, | ||
| 68 | + # op_types_to_quantize=["MatMul"], | ||
| 69 | + # weight_type=QuantType.QUInt8, | ||
| 70 | + # ) | ||
| 71 | + | ||
| 72 | + | ||
| 73 | +if __name__ == "__main__": | ||
| 74 | + main() |
scripts/tele-speech/run.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +curl -SL -O https://hf-mirror.com/lovemefan/telespeech/resolve/main/model_export.onnx | ||
| 4 | + | ||
| 5 | +mv model_export.onnx model.onnx | ||
| 6 | + | ||
| 7 | +curl -SL -O https://hf-mirror.com/lovemefan/telespeech/resolve/main/vocab.json | ||
| 8 | + | ||
| 9 | +curl -SL -O https://github.com/csukuangfj/models/releases/download/a/TeleSpeech.pdf | ||
| 10 | +curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/resolve/main/test_wavs/3-sichuan.wav | ||
| 11 | +curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/resolve/main/test_wavs/4-tianjin.wav | ||
| 12 | +curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/resolve/main/test_wavs/5-henan.wav | ||
| 13 | + | ||
| 14 | +ls -lh | ||
| 15 | + | ||
| 16 | +./add-metadata.py | ||
| 17 | + | ||
| 18 | +dst=sherpa-onnx-telespeech-ctc-zh-2024-06-04 | ||
| 19 | +mkdir $dst | ||
| 20 | +mkdir $dst/test_wavs | ||
| 21 | +cp -v model.onnx $dst/ | ||
| 22 | +cp -v tokens.txt $dst | ||
| 23 | +cp -v *.wav $dst/test_wavs | ||
| 24 | +cp -v *.pdf $dst | ||
| 25 | +cp -v README.md $dst | ||
| 26 | +cp -v *.py $dst | ||
| 27 | + | ||
| 28 | +ls -lh $dst | ||
| 29 | + | ||
| 30 | +tar cvjfv ${dst}.tar.bz2 $dst | ||
| 31 | + | ||
| 32 | +dst=sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04 | ||
| 33 | +mkdir $dst | ||
| 34 | +mkdir $dst/test_wavs | ||
| 35 | +cp -v model.int8.onnx $dst/ | ||
| 36 | +cp -v tokens.txt $dst | ||
| 37 | +cp -v *.wav $dst/test_wavs | ||
| 38 | +cp -v *.pdf $dst | ||
| 39 | +cp -v README.md $dst | ||
| 40 | +cp -v *.py $dst | ||
| 41 | + | ||
| 42 | +ls -lh $dst | ||
| 43 | + | ||
| 44 | +tar cvjfv ${dst}.tar.bz2 $dst | ||
| 45 | + | ||
| 46 | +cp -v *.tar.bz2 ../.. | ||
| 47 | + | ||
| 48 | +ls -lh ../../ |
scripts/tele-speech/test.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +from typing import Tuple | ||
| 5 | + | ||
| 6 | +import kaldi_native_fbank as knf | ||
| 7 | +import numpy as np | ||
| 8 | +import onnxruntime as ort | ||
| 9 | +import soundfile as sf | ||
| 10 | + | ||
| 11 | +""" | ||
| 12 | +NodeArg(name='feats', type='tensor(float)', shape=[1, 'T', 40]) | ||
| 13 | +----- | ||
| 14 | +NodeArg(name='logits', type='tensor(float)', shape=['Addlogits_dim_0', 1, 7535]) | ||
| 15 | +""" | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +class OnnxModel: | ||
| 19 | + def __init__( | ||
| 20 | + self, | ||
| 21 | + filename: str, | ||
| 22 | + ): | ||
| 23 | + session_opts = ort.SessionOptions() | ||
| 24 | + session_opts.inter_op_num_threads = 1 | ||
| 25 | + session_opts.intra_op_num_threads = 1 | ||
| 26 | + | ||
| 27 | + self.session_opts = session_opts | ||
| 28 | + | ||
| 29 | + self.model = ort.InferenceSession( | ||
| 30 | + filename, | ||
| 31 | + sess_options=self.session_opts, | ||
| 32 | + providers=["CPUExecutionProvider"], | ||
| 33 | + ) | ||
| 34 | + | ||
| 35 | + self.show() | ||
| 36 | + | ||
| 37 | + def show(self): | ||
| 38 | + for i in self.model.get_inputs(): | ||
| 39 | + print(i) | ||
| 40 | + | ||
| 41 | + print("-----") | ||
| 42 | + | ||
| 43 | + for i in self.model.get_outputs(): | ||
| 44 | + print(i) | ||
| 45 | + | ||
| 46 | + def __call__(self, x): | ||
| 47 | + """ | ||
| 48 | + Args: | ||
| 49 | + x: a float32 tensor of shape (N, T, C) | ||
| 50 | + """ | ||
| 51 | + logits = self.model.run( | ||
| 52 | + [ | ||
| 53 | + self.model.get_outputs()[0].name, | ||
| 54 | + ], | ||
| 55 | + { | ||
| 56 | + self.model.get_inputs()[0].name: x, | ||
| 57 | + }, | ||
| 58 | + )[0] | ||
| 59 | + | ||
| 60 | + return logits | ||
| 61 | + | ||
| 62 | + | ||
| 63 | +def load_audio(filename: str) -> Tuple[np.ndarray, int]: | ||
| 64 | + data, sample_rate = sf.read( | ||
| 65 | + filename, | ||
| 66 | + always_2d=True, | ||
| 67 | + dtype="float32", | ||
| 68 | + ) | ||
| 69 | + data = data[:, 0] # use only the first channel | ||
| 70 | + samples = np.ascontiguousarray(data) | ||
| 71 | + return samples, sample_rate | ||
| 72 | + | ||
| 73 | + | ||
| 74 | +def get_features(test_wav_filename): | ||
| 75 | + samples, sample_rate = load_audio(test_wav_filename) | ||
| 76 | + | ||
| 77 | + if sample_rate != 16000: | ||
| 78 | + import librosa | ||
| 79 | + | ||
| 80 | + samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000) | ||
| 81 | + sample_rate = 16000 | ||
| 82 | + | ||
| 83 | + samples *= 372768 | ||
| 84 | + | ||
| 85 | + opts = knf.MfccOptions() | ||
| 86 | + # See https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/mfcc_hires.conf | ||
| 87 | + opts.frame_opts.dither = 0 | ||
| 88 | + | ||
| 89 | + opts.num_ceps = 40 | ||
| 90 | + opts.use_energy = False | ||
| 91 | + | ||
| 92 | + opts.mel_opts.num_bins = 40 | ||
| 93 | + opts.mel_opts.low_freq = 40 | ||
| 94 | + opts.mel_opts.high_freq = -200 | ||
| 95 | + | ||
| 96 | + mfcc = knf.OnlineMfcc(opts) | ||
| 97 | + mfcc.accept_waveform(16000, samples) | ||
| 98 | + frames = [] | ||
| 99 | + for i in range(mfcc.num_frames_ready): | ||
| 100 | + frames.append(mfcc.get_frame(i)) | ||
| 101 | + | ||
| 102 | + frames = np.stack(frames, axis=0) | ||
| 103 | + return frames | ||
| 104 | + | ||
| 105 | + | ||
| 106 | +def cmvn(features): | ||
| 107 | + # See https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/wenet_representation/conf/train_d2v2_ark_conformer.yaml#L70 | ||
| 108 | + # https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/wenet_representation/wenet/dataset/dataset.py#L184 | ||
| 109 | + # https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/wenet_representation/wenet/dataset/processor.py#L278 | ||
| 110 | + mean = features.mean(axis=0, keepdims=True) | ||
| 111 | + std = features.std(axis=0, keepdims=True) | ||
| 112 | + return (features - mean) / (std + 1e-5) | ||
| 113 | + | ||
| 114 | + | ||
| 115 | +def main(): | ||
| 116 | + # Please download the test data from | ||
| 117 | + # https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/tree/main/test_wavs | ||
| 118 | + test_wav_filename = "./3-sichuan.wav" | ||
| 119 | + test_wav_filename = "./4-tianjin.wav" | ||
| 120 | + test_wav_filename = "./5-henan.wav" | ||
| 121 | + | ||
| 122 | + features = get_features(test_wav_filename) | ||
| 123 | + | ||
| 124 | + features = cmvn(features) | ||
| 125 | + | ||
| 126 | + features = np.expand_dims(features, axis=0) # (T, C) -> (N, T, C) | ||
| 127 | + | ||
| 128 | + model_filename = "./model.int8.onnx" | ||
| 129 | + model = OnnxModel(model_filename) | ||
| 130 | + logits = model(features) | ||
| 131 | + logits = logits.squeeze(axis=1) # remove batch axis | ||
| 132 | + ids = logits.argmax(axis=-1) | ||
| 133 | + | ||
| 134 | + id2token = dict() | ||
| 135 | + with open("./tokens.txt", encoding="utf-8") as f: | ||
| 136 | + for line in f: | ||
| 137 | + t, idx = line.split() | ||
| 138 | + id2token[int(idx)] = t | ||
| 139 | + | ||
| 140 | + tokens = [] | ||
| 141 | + | ||
| 142 | + blank = 0 | ||
| 143 | + prev = -1 | ||
| 144 | + | ||
| 145 | + for k in ids: | ||
| 146 | + if k != blank and k != prev: | ||
| 147 | + tokens.append(k) | ||
| 148 | + prev = k | ||
| 149 | + | ||
| 150 | + tokens = [id2token[i] for i in tokens] | ||
| 151 | + text = "".join(tokens) | ||
| 152 | + print(text) | ||
| 153 | + | ||
| 154 | + | ||
| 155 | +if __name__ == "__main__": | ||
| 156 | + main() |
-
请 注册 或 登录 后发表评论