Fangjun Kuang
Committed by GitHub

export sense-voice to onnx (#1144)

@@ -40,7 +40,7 @@ jobs: @@ -40,7 +40,7 @@ jobs:
40 name: test.wav 40 name: test.wav
41 path: scripts/melo-tts/test.wav 41 path: scripts/melo-tts/test.wav
42 42
43 - - name: Publish to huggingface (aishell) 43 + - name: Publish to huggingface
44 env: 44 env:
45 HF_TOKEN: ${{ secrets.HF_TOKEN }} 45 HF_TOKEN: ${{ secrets.HF_TOKEN }}
46 uses: nick-fields/retry@v3 46 uses: nick-fields/retry@v3
# Manually triggered workflow that exports SenseVoice to ONNX, publishes the
# result to a Hugging Face model repository and attaches a tarball of the
# same files to the `asr-models` release of k2-fsa/sherpa-onnx.
name: export-sense-voice-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-sense-voice-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-sense-voice-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export sense-voice
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Download test_wavs
        shell: bash
        run: |
          # Refresh the package index first so the install does not hit a
          # stale mirror listing.
          sudo apt-get update -qq
          sudo apt-get install -y -qq sox libsox-fmt-mp3
          curl -SL -O https://huggingface.co/FunAudioLLM/SenseVoiceSmall/resolve/main/example/zh.mp3
          curl -SL -O https://huggingface.co/FunAudioLLM/SenseVoiceSmall/resolve/main/example/en.mp3
          curl -SL -O https://huggingface.co/FunAudioLLM/SenseVoiceSmall/resolve/main/example/ja.mp3
          curl -SL -O https://huggingface.co/FunAudioLLM/SenseVoiceSmall/resolve/main/example/ko.mp3
          curl -SL -O https://huggingface.co/FunAudioLLM/SenseVoiceSmall/resolve/main/example/yue.mp3

          soxi *.mp3

          # Resample every example to 16 kHz wav files.
          sox zh.mp3 -r 16k zh.wav
          sox en.mp3 -r 16k en.wav
          sox ja.mp3 -r 16k ja.wav
          sox ko.mp3 -r 16k ko.wav
          sox yue.mp3 -r 16k yue.wav

      - name: Run
        shell: bash
        run: |
          cd scripts/sense-voice
          ./run.sh

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            # Start every attempt from a fresh clone.
            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            git clone https://huggingface.co/csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17 huggingface
            cd huggingface
            git fetch
            git pull
            echo "pwd: $PWD"
            ls -lh ../scripts/sense-voice

            # Clear the previously published files. `rm -rf ./` is refused by
            # rm ("refusing to remove '.' or '..'"), so glob the contents
            # instead; dot entries such as .git are not matched and survive.
            rm -rf ./*

            cp -v ../scripts/sense-voice/*.onnx .
            cp -v ../scripts/sense-voice/tokens.txt .
            cp -v ../scripts/sense-voice/README.md .
            cp -v ../scripts/sense-voice/export-onnx.py .

            mkdir test_wavs
            cp -v ../*.wav ./test_wavs/

            curl -SL -O https://raw.githubusercontent.com/FunAudioLLM/SenseVoice/main/LICENSE

            git lfs track "*.onnx"
            git add .

            ls -lh

            git status

            git commit -m "add models"
            # A failed push is tolerated (`|| true`) so the tarball below is
            # still produced.
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17 main || true

            cd ..

            # Strip git metadata (.git, .gitattributes) before packaging.
            rm -rf huggingface/.git*
            dst=sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17

            mv huggingface $dst

            tar cjvf $dst.tar.bz2 $dst
            rm -rf $dst

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
@@ -2,8 +2,6 @@ @@ -2,8 +2,6 @@
2 2
3 set -ex 3 set -ex
4 4
5 -  
6 -  
7 function install() { 5 function install() {
8 pip install torch==2.3.1+cpu torchaudio==2.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html 6 pip install torch==2.3.1+cpu torchaudio==2.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
9 7
  1 +# Introduction
  2 +
  3 +This directory contains models converted from
  4 +https://github.com/FunAudioLLM/SenseVoice
  1 +#!/usr/bin/env python3
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +"""
  5 +We use
  6 +https://hf-mirror.com/yuekai/model_repo_sense_voice_small/blob/main/export_onnx.py
  7 +as a reference while writing this file.
  8 +
  9 +Thanks to https://github.com/yuekaizhang for making the file public.
  10 +"""
  11 +
  12 +import os
  13 +from typing import Any, Dict, Tuple
  14 +
  15 +import onnx
  16 +import torch
  17 +from model import SenseVoiceSmall
  18 +from onnxruntime.quantization import QuantType, quantize_dynamic
  19 +
  20 +
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Replace the metadata of an ONNX model file with ``meta_data``.

    The file is loaded, all of its existing metadata entries are dropped,
    the given pairs are stored (values converted with ``str``), and the
    model is written back to the same path.

    Args:
      filename:
        Path of the ONNX model to rewrite in-place.
      meta_data:
        Key-value pairs to store as the model's metadata.
    """
    onnx_model = onnx.load(filename)
    props = onnx_model.metadata_props

    # Drop whatever metadata the file already carries.
    for _ in range(len(props)):
        props.pop()

    for name in meta_data:
        entry = props.add()
        entry.key = name
        entry.value = str(meta_data[name])

    onnx.save(onnx_model, filename)
  40 +
  41 +
def modified_forward(
    self,
    x: torch.Tensor,
    x_length: torch.Tensor,
    language: torch.Tensor,
    text_norm: torch.Tensor,
):
    """Replacement ``forward`` that returns the CTC logits only.

    Four query embeddings (language, two event/emotion ids, text-norm) are
    prepended to ``x`` along the time axis before the encoder is run.

    Args:
      x:
        A 3-D tensor of shape (N, T, C) with dtype torch.float32
      x_length:
        A 1-D tensor of shape (N,) with dtype torch.int32.
        It is not modified by this function.
      language:
        A 1-D tensor of shape (N,) with dtype torch.int32
        See also https://github.com/FunAudioLLM/SenseVoice/blob/a80e676461b24419cf1130a33d4dd2f04053e5cc/model.py#L640
      text_norm:
        A 1-D tensor of shape (N,) with dtype torch.int32
        See also https://github.com/FunAudioLLM/SenseVoice/blob/a80e676461b24419cf1130a33d4dd2f04053e5cc/model.py#L642
    Returns:
      The output of ``self.ctc.ctc_lo`` applied to the encoder output.
    """
    language_query = self.embed(language).unsqueeze(1)
    text_norm_query = self.embed(text_norm).unsqueeze(1)

    # Create the fixed ids on the same device as the features so the
    # function also works when the model does not live on the CPU.
    event_emo_ids = torch.tensor([[1, 2]], dtype=torch.int64, device=x.device)
    event_emo_query = self.embed(event_emo_ids).repeat(x.size(0), 1, 1)

    x = torch.cat((language_query, event_emo_query, text_norm_query, x), dim=1)

    # Four frames were prepended above. Use an out-of-place add so that the
    # tensor passed in by the caller is left untouched.
    x_length = x_length + 4

    encoder_out, encoder_out_lens = self.encoder(x, x_length)
    if isinstance(encoder_out, tuple):
        encoder_out = encoder_out[0]

    ctc_logits = self.ctc.ctc_lo(encoder_out)

    return ctc_logits
  77 +
  78 +
def load_cmvn(filename) -> Tuple[str, str]:
    """Read the two statistics rows from a CMVN file.

    Only lines starting with ``<LearnRateCoef>`` are considered. For each of
    them the first three whitespace-separated fields and the last one are
    dropped, and the remaining fields are joined with commas.

    Args:
      filename:
        Path of the CMVN file.
    Returns:
      A tuple ``(neg_mean, inv_stddev)``: the values of the first matching
      line and of the last matching line, each as a comma-separated string.
    Raises:
      ValueError: If the file has fewer than two ``<LearnRateCoef>`` lines.
    """
    rows = []

    with open(filename, encoding="utf-8") as f:
        for line in f:
            if not line.startswith("<LearnRateCoef>"):
                continue
            rows.append(",".join(line.split()[3:-1]))

    # Without this check a malformed file would yield None values that later
    # end up as the literal string "None" in the model metadata.
    if len(rows) < 2:
        raise ValueError(
            f"Expected at least 2 <LearnRateCoef> lines in {filename}, "
            f"found {len(rows)}"
        )

    neg_mean = rows[0]
    inv_stddev = rows[-1]

    return neg_mean, inv_stddev
  95 +
  96 +
def generate_tokens(params):
    """Write ``tokens.txt`` in the current directory and preview it.

    Each line has the form ``<piece> <id>`` for every id of the tokenizer
    found at ``params["tokenizer"].sp``.
    """
    processor = params["tokenizer"].sp

    with open("tokens.txt", "w", encoding="utf-8") as out:
        out.writelines(
            f"{processor.id_to_piece(idx)} {idx}\n"
            for idx in range(processor.vocab_size())
        )

    # Show the beginning and the end of the generated file in the log.
    os.system("head tokens.txt; tail -n200 tokens.txt")
  104 +
  105 +
def display_params(params):
    """Print ``params`` and dump the files it points to.

    Prints the whole ``params`` object, its ``frontend_conf`` entry and its
    ``config`` entry, and runs ``cat`` on the CMVN file and on the config
    file so that their contents show up in the log.
    """

    def dump(path):
        # Let the shell print the file.
        os.system(f"cat {path}")

    print("----------params----------")
    print(params)

    print("----------frontend_conf----------")
    frontend_conf = params["frontend_conf"]
    print(frontend_conf)
    dump(frontend_conf["cmvn_file"])

    print("----------config----------")
    config = params["config"]
    print(config)
    dump(config)
  119 +
  120 +
def main():
    """Export iic/SenseVoiceSmall to ``model.onnx`` and ``model.int8.onnx``.

    Steps: load the pretrained model, write ``tokens.txt``, swap in
    ``modified_forward``, export to ONNX with dynamic batch/time axes,
    attach metadata to the exported file, and finally create a dynamically
    quantized int8 copy.
    """
    model, params = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall")
    display_params(params)

    generate_tokens(params)

    # The export traces forward(), so replace it with the variant that
    # takes plain tensors and returns only the CTC logits.
    model.__class__.forward = modified_forward

    # Dummy batch used for the export, in the order of input_names below.
    dummy_inputs = (
        torch.randn(2, 100, 560, dtype=torch.float32),
        torch.tensor([80, 100], dtype=torch.int32),
        torch.tensor([0, 3], dtype=torch.int32),
        torch.tensor([14, 15], dtype=torch.int32),
    )
    input_names = ["x", "x_length", "language", "text_norm"]

    # Every input has a dynamic batch axis; x and logits also have a
    # dynamic time axis.
    dynamic_axes = {name: {0: "N"} for name in input_names}
    dynamic_axes["x"][1] = "T"
    dynamic_axes["logits"] = {0: "N", 1: "T"}

    filename = "model.onnx"
    torch.onnx.export(
        model,
        dummy_inputs,
        filename,
        opset_version=13,
        input_names=input_names,
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
    )

    frontend_conf = params["frontend_conf"]
    lfr_window_size = frontend_conf["lfr_m"]
    lfr_window_shift = frontend_conf["lfr_n"]
    neg_mean, inv_stddev = load_cmvn(frontend_conf["cmvn_file"])
    vocab_size = params["tokenizer"].sp.vocab_size()

    meta_data = {
        "lfr_window_size": lfr_window_size,
        "lfr_window_shift": lfr_window_shift,
        "neg_mean": neg_mean,
        "inv_stddev": inv_stddev,
        "model_type": "sense_voice_ctc",
        "version": "1",
        "model_author": "iic",
        "maintainer": "k2-fsa",
        "vocab_size": vocab_size,
        "comment": "iic/SenseVoiceSmall",
    }
    # Language ids known to the model ("yue" is Cantonese).
    for lang in ("auto", "zh", "en", "yue", "ja", "ko", "nospeech"):
        meta_data[f"lang_{lang}"] = model.lid_dict[lang]
    meta_data["with_itn"] = model.textnorm_dict["withitn"]
    meta_data["without_itn"] = model.textnorm_dict["woitn"]
    meta_data["url"] = "https://huggingface.co/FunAudioLLM/SenseVoiceSmall"

    add_meta_data(filename=filename, meta_data=meta_data)

    # Quantize only the MatMul weights to int8.
    quantize_dynamic(
        model_input=filename,
        model_output="model.int8.onnx",
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )
  189 +
  190 +
if __name__ == "__main__":
    # Seed the RNG before main() so that the torch.randn dummy input used
    # for the export is the same on every run.
    torch.manual_seed(20240717)
    main()
  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +
# Install the Python dependencies of the export script and check out
# FunASR and SenseVoice under /tmp.
function install() {
  pip install torch==2.3.1+cpu torchaudio==2.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html

  pushd /tmp

  # Skip the clone when the checkout already exists; otherwise a second
  # run of this script would abort here because of `set -e`.
  if [ ! -d FunASR ]; then
    git clone https://github.com/alibaba/FunASR.git
  fi
  cd FunASR
  pip3 install -qq -e ./
  cd ..

  if [ ! -d SenseVoice ]; then
    git clone https://github.com/FunAudioLLM/SenseVoice
  fi
  cd SenseVoice
  pip install -qq -r ./requirements.txt
  cd ..

  # soundfile was listed twice in the original command.
  pip install soundfile onnx onnxruntime kaldi-native-fbank librosa

  popd
}
  25 +
install

# Put both checkouts in front of any existing PYTHONPATH, with SenseVoice
# searched first and FunASR second.
export PYTHONPATH=/tmp/SenseVoice:/tmp/FunASR:$PYTHONPATH

echo "pwd: $PWD"

# Export the model, then print the inputs/outputs/metadata of the result.
./export-onnx.py
./show-info.py

ls -lh
  1 +#!/usr/bin/env python3
  2 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +import onnxruntime
  5 +
  6 +
def show(filename):
    """Print the inputs, outputs and custom metadata of an ONNX model.

    Args:
      filename:
        Path of the ONNX model to inspect.
    """
    options = onnxruntime.SessionOptions()
    # Level 3 is "error" per the onnxruntime docs, so only errors and above
    # are logged.
    options.log_severity_level = 3
    session = onnxruntime.InferenceSession(filename, options)

    inputs = session.get_inputs()
    for node in inputs:
        print(node)

    print("-----")

    outputs = session.get_outputs()
    for node in outputs:
        print(node)

    custom_meta = session.get_modelmeta().custom_metadata_map
    print("*****************************************")
    print("meta\n", custom_meta)
  23 +
def main():
    """Print a banner followed by the description of ``./model.onnx``."""
    model_path = "./model.onnx"
    print("=========model==========")
    show(model_path)
  28 +
  29 +if __name__ == "__main__":
  30 + main()
  31 +"""
  32 +=========model==========
  33 +NodeArg(name='x', type='tensor(float)', shape=['N', 'T', 560])
  34 +NodeArg(name='x_length', type='tensor(int32)', shape=['N'])
  35 +NodeArg(name='language', type='tensor(int32)', shape=['N'])
  36 +NodeArg(name='text_norm', type='tensor(int32)', shape=['N'])
  37 +-----
  38 +NodeArg(name='logits', type='tensor(float)', shape=['N', 'T', 25055])
  39 +*****************************************
  40 +"""