Fangjun Kuang

export telespeech ctc models to sherpa-onnx (#968)

name: export-telespeech-ctc-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-telespeech-ctc-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-telespeech-ctc-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: telespeech
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          pip install onnx onnxruntime soundfile librosa numpy kaldi-native-fbank

      - name: Run
        shell: bash
        run: |
          cd scripts/tele-speech
          ./run.sh

          ./test.py

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
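The workflow above runs only on `workflow_dispatch`, so it has to be started manually, either from the GitHub web UI or through the REST API. Below is a minimal sketch of the API route; the workflow file name `export-telespeech-ctc-to-onnx.yaml` is an assumption based on the `name:` field (the actual path is not shown in this diff), and a token with `workflow` scope is assumed to be in `GITHUB_TOKEN`:

import os

import requests

# Hypothetical trigger script, not part of this commit.
resp = requests.post(
    "https://api.github.com/repos/k2-fsa/sherpa-onnx/actions/workflows/"
    "export-telespeech-ctc-to-onnx.yaml/dispatches",  # assumed file name
    headers={
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
        "Accept": "application/vnd.github+json",
    },
    json={"ref": "master"},  # branch on which to run the workflow
)
resp.raise_for_status()  # the API returns 204 No Content on success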
# Introduction

This folder contains scripts for adding metadata to the ONNX models from
https://hf-mirror.com/lovemefan/telespeech/tree/main

Please see

  - https://github.com/Tele-AI/TeleSpeech-ASR
  - https://github.com/lovemefan/telespeech-asr-python
  - [TeleSpeech模型社区许可协议.pdf](https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/TeleSpeech%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) (the TeleSpeech Model Community License Agreement)

for more details.
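As a quick check, the metadata written by `add-metadata.py` below can be read back through onnxruntime's standard `get_modelmeta()` API; a minimal sketch, assuming `model.onnx` is in the current directory:

import onnxruntime as ort

# Read back the custom metadata added by add-metadata.py.
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
for key, value in sess.get_modelmeta().custom_metadata_map.items():
    print(f"{key}={value}")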
#!/usr/bin/env python3

import json
from typing import Dict

import onnx
from onnxruntime.quantization import QuantType, quantize_dynamic


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)

    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


def main():
    with open("./vocab.json", "r", encoding="utf-8") as f:
        tokens = json.load(f)

    vocab_size = len(tokens)
    with open("tokens.txt", "w", encoding="utf-8") as f:
        for token, idx in tokens.items():
            if idx == 0:
                # Rename the token with index 0 to <blk>, the blank
                # symbol used for CTC decoding.
                f.write("<blk> 0\n")
            else:
                f.write(f"{token} {idx}\n")

    filename = "model.onnx"
    meta_data = {
        "model_type": "telespeech_ctc",
        "version": "1",
        "model_author": "Tele-AI",
        "comment": "See also https://github.com/lovemefan/telespeech-asr-python",
        "license": "https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/TeleSpeech%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf",
        "url": "https://github.com/Tele-AI/TeleSpeech-ASR",
    }

    add_meta_data(filename, meta_data)

    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    # filename_uint8 = "model.uint8.onnx"
    # quantize_dynamic(
    #     model_input=filename,
    #     model_output=filename_uint8,
    #     op_types_to_quantize=["MatMul"],
    #     weight_type=QuantType.QUInt8,
    # )


if __name__ == "__main__":
    main()
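After `quantize_dynamic` finishes, a quick sanity check is to confirm that both models still load and to compare their sizes; a small sketch, assuming both files exist in the current directory:

import os

import onnx

# Load both models to verify they are well-formed and compare file sizes.
for f in ("model.onnx", "model.int8.onnx"):
    model = onnx.load(f)
    onnx.checker.check_model(model)
    size_mb = os.path.getsize(f) / 1024 / 1024
    print(f"{f}: {size_mb:.1f} MB, {len(model.graph.node)} nodes")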
#!/usr/bin/env bash

set -ex  # exit on error and echo each command (helpful in CI logs)

curl -SL -O https://hf-mirror.com/lovemefan/telespeech/resolve/main/model_export.onnx

mv model_export.onnx model.onnx

curl -SL -O https://hf-mirror.com/lovemefan/telespeech/resolve/main/vocab.json

curl -SL -O https://github.com/csukuangfj/models/releases/download/a/TeleSpeech.pdf
curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/resolve/main/test_wavs/3-sichuan.wav
curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/resolve/main/test_wavs/4-tianjin.wav
curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/resolve/main/test_wavs/5-henan.wav

ls -lh

./add-metadata.py

dst=sherpa-onnx-telespeech-ctc-zh-2024-06-04
mkdir $dst
mkdir $dst/test_wavs
cp -v model.onnx $dst/
cp -v tokens.txt $dst/
cp -v *.wav $dst/test_wavs/
cp -v *.pdf $dst/
cp -v README.md $dst/
cp -v *.py $dst/

ls -lh $dst

tar cjvf ${dst}.tar.bz2 $dst

dst=sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04
mkdir $dst
mkdir $dst/test_wavs
cp -v model.int8.onnx $dst/
cp -v tokens.txt $dst/
cp -v *.wav $dst/test_wavs/
cp -v *.pdf $dst/
cp -v README.md $dst/
cp -v *.py $dst/

ls -lh $dst

tar cjvf ${dst}.tar.bz2 $dst

cp -v *.tar.bz2 ../..

ls -lh ../../
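Before the CI step uploads them, the archives produced by `run.sh` above can be inspected without extracting them; a sketch using the standard-library `tarfile` module:

import tarfile

# List the contents of the float32 package produced by run.sh.
with tarfile.open("sherpa-onnx-telespeech-ctc-zh-2024-06-04.tar.bz2", "r:bz2") as tf:
    for name in tf.getnames():
        print(name)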
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

from typing import Tuple

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf

"""
NodeArg(name='feats', type='tensor(float)', shape=[1, 'T', 40])
-----
NodeArg(name='logits', type='tensor(float)', shape=['Addlogits_dim_0', 1, 7535])
"""


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        self.show()

    def show(self):
        for i in self.model.get_inputs():
            print(i)

        print("-----")

        for i in self.model.get_outputs():
            print(i)

    def __call__(self, x):
        """
        Args:
          x: a float32 tensor of shape (N, T, C)
        """
        logits = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x,
            },
        )[0]

        return logits


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel
    samples = np.ascontiguousarray(data)
    return samples, sample_rate


def get_features(test_wav_filename):
    samples, sample_rate = load_audio(test_wav_filename)

    if sample_rate != 16000:
        import librosa

        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    # Scale to the int16 range expected by Kaldi-style feature extraction
    samples *= 32768

    opts = knf.MfccOptions()
    # See https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/mfcc_hires.conf
    opts.frame_opts.dither = 0

    opts.num_ceps = 40
    opts.use_energy = False

    opts.mel_opts.num_bins = 40
    opts.mel_opts.low_freq = 40
    opts.mel_opts.high_freq = -200

    mfcc = knf.OnlineMfcc(opts)
    mfcc.accept_waveform(16000, samples)
    frames = []
    for i in range(mfcc.num_frames_ready):
        frames.append(mfcc.get_frame(i))

    frames = np.stack(frames, axis=0)
    return frames


def cmvn(features):
    # See https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/wenet_representation/conf/train_d2v2_ark_conformer.yaml#L70
    # https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/wenet_representation/wenet/dataset/dataset.py#L184
    # https://github.com/Tele-AI/TeleSpeech-ASR/blob/master/wenet_representation/wenet/dataset/processor.py#L278
    mean = features.mean(axis=0, keepdims=True)
    std = features.std(axis=0, keepdims=True)
    return (features - mean) / (std + 1e-5)


def main():
    # Please download the test data from
    # https://hf-mirror.com/csukuangfj/sherpa-onnx-paraformer-zh-small-2024-03-09/tree/main/test_wavs
    # Pick one of the test wavs; later assignments override earlier ones.
    test_wav_filename = "./3-sichuan.wav"
    test_wav_filename = "./4-tianjin.wav"
    test_wav_filename = "./5-henan.wav"

    features = get_features(test_wav_filename)

    features = cmvn(features)

    features = np.expand_dims(features, axis=0)  # (T, C) -> (N, T, C)

    model_filename = "./model.int8.onnx"
    model = OnnxModel(model_filename)
    logits = model(features)
    logits = logits.squeeze(axis=1)  # remove the batch axis -> (T, C)
    ids = logits.argmax(axis=-1)

    id2token = dict()
    with open("./tokens.txt", encoding="utf-8") as f:
        for line in f:
            t, idx = line.split()
            id2token[int(idx)] = t

    tokens = []

    # Greedy CTC decoding: drop blanks and collapse repeated ids.
    blank = 0
    prev = -1

    for k in ids:
        if k != blank and k != prev:
            tokens.append(k)
        prev = k

    tokens = [id2token[i] for i in tokens]
    text = "".join(tokens)
    print(text)


if __name__ == "__main__":
    main()