name: export-moonshine-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-moonshine-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-moonshine-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export moonshine models to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          pip install -q onnx onnxruntime librosa tokenizers soundfile

      - name: Run
        shell: bash
        run: |
          pushd scripts/moonshine
          ./run.sh
          popd

          mv -v scripts/moonshine/*.tar.bz2 .
          mv -v scripts/moonshine/sherpa-onnx-* ./

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models

      - name: Publish to huggingface (tiny)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            d=sherpa-onnx-moonshine-tiny-en-int8
            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false
            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
            mv -v $d/* ./huggingface
            cd huggingface
            git lfs track "*.onnx"
            git lfs track "*.wav"
            git status
            git add .
            git status
            git commit -m "add models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
            cd ..
            rm -rf huggingface

      - name: Publish to huggingface (base)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            d=sherpa-onnx-moonshine-base-en-int8
            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false
            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
            mv -v $d/* ./huggingface
            cd huggingface
            git lfs track "*.onnx"
            git lfs track "*.wav"
            git status
            git add .
            git status
            git commit -m "add models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
            cd ..
            rm -rf huggingface
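After a successful run, the archives are attached to the `asr-models` release of k2-fsa/sherpa-onnx. A sketch of fetching one (the URL follows GitHub's standard release-asset layout; the archive name comes from run.sh below):

```bash
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
ls sherpa-onnx-moonshine-tiny-en-int8
```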
# Introduction

This directory contains scripts for downloading, quantizing, and packaging
models from
https://github.com/usefulsensors/moonshine

See the upstream license at
https://github.com/usefulsensors/moonshine/blob/main/LICENSE
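To reproduce the export locally, something along these lines should work
(the dependency list is taken from the CI workflow above):

```bash
# Dependencies used by the CI workflow
pip install onnx onnxruntime librosa tokenizers soundfile

# Download, quantize, test, and package the tiny and base models
./run.sh

# The resulting archives are left in the current directory
ls -lh sherpa-onnx-moonshine-*-en-int8.tar.bz2
```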
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

from pathlib import Path

import tokenizers
from onnxruntime.quantization import QuantType, quantize_dynamic


def generate_tokens():
    """Generate tokens.txt from tokenizer.json.

    Each line contains a token and its ID, separated by a tab.
    """
    if Path("./tokens.txt").is_file():
        return
    print("Generating tokens.txt")
    tokenizer = tokenizers.Tokenizer.from_file("./tokenizer.json")
    vocab_size = tokenizer.get_vocab_size()
    with open("tokens.txt", "w", encoding="utf-8") as f:
        for i in range(vocab_size):
            s = tokenizer.id_to_token(i).strip()
            f.write(f"{s}\t{i}\n")


def main():
    generate_tokens()

    # Note(fangjun): Don't use int8 for the preprocessor, since quantizing
    # it has a larger impact on the accuracy.
    for f in ["uncached_decode", "cached_decode", "encode"]:
        if Path(f"{f}.int8.onnx").is_file():
            continue

        print("processing", f)
        quantize_dynamic(
            model_input=f"{f}.onnx",
            model_output=f"{f}.int8.onnx",
            weight_type=QuantType.QInt8,
        )


if __name__ == "__main__":
    main()
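export-onnx.py can also be invoked on its own. A minimal sketch, assuming the fp32 models and tokenizer.json (downloaded by run.sh below) are already in the current directory:

```bash
# Inputs read from the current directory:
#   tokenizer.json  encode.onnx  uncached_decode.onnx  cached_decode.onnx
./export-onnx.py

# Outputs: tokens.txt plus the int8 models; preprocess.onnx is deliberately
# kept as float32 (see the note in main() above)
ls tokens.txt encode.int8.onnx uncached_decode.int8.onnx cached_decode.int8.onnx
```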
#!/usr/bin/env bash
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
set -ex

cat >LICENSE <<EOF
MIT License

Copyright (c) 2024 Useful Sensors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
EOF

# Download the fp32 ONNX models, the test waves, and the tokenizer.
function download_files() {
  for d in tiny base; do
    mkdir $d

    pushd $d
    curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/preprocess.onnx
    curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/encode.onnx
    curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/uncached_decode.onnx
    curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/cached_decode.onnx
    popd
  done

  curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/0.wav
  curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/1.wav
  curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/8k.wav
  curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/trans.txt

  curl -SL -O https://raw.githubusercontent.com/usefulsensors/moonshine/refs/heads/main/moonshine/assets/tokenizer.json
}

# Quantize the encoder and the decoders to int8, generate tokens.txt,
# and run test.py as a sanity check.
function quantize() {
  for d in tiny base; do
    echo "==========$d=========="
    ls -lh
    mv $d/*.onnx .
    ./export-onnx.py
    rm cached_decode.onnx
    rm uncached_decode.onnx
    rm encode.onnx
    ls -lh

    ./test.py

    mv *.onnx $d
    mv tokens.txt $d
    ls -lh $d
  done
}

# Package each model directory into a .tar.bz2 archive together with
# the test waves, the LICENSE, and the README.
function zip() {
  for d in tiny base; do
    s=sherpa-onnx-moonshine-$d-en-int8
    mv $d $s

    mkdir $s/test_wavs

    cp -v *.wav $s/test_wavs
    cp trans.txt $s/test_wavs
    cp LICENSE $s/
    cp ./README.md $s

    ls -lh $s
    tar cjfv $s.tar.bz2 $s
  done
}

download_files
quantize
zip

ls -lh
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import datetime as dt

import librosa
import numpy as np
import onnxruntime as ort
import soundfile as sf


def display(sess, name):
    print(f"=========={name} Input==========")
    for i in sess.get_inputs():
        print(i)
    print(f"=========={name} Output==========")
    for i in sess.get_outputs():
        print(i)


class OnnxModel:
    def __init__(
        self,
        preprocess: str,
        encode: str,
        uncached_decode: str,
        cached_decode: str,
    ):
        self.init_preprocess(preprocess)
        display(self.preprocess, "preprocess")

        self.init_encode(encode)
        display(self.encode, "encode")

        self.init_uncached_decode(uncached_decode)
        display(self.uncached_decode, "uncached_decode")

        self.init_cached_decode(cached_decode)
        display(self.cached_decode, "cached_decode")

    def init_preprocess(self, preprocess):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.preprocess = ort.InferenceSession(
            preprocess,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_encode(self, encode):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.encode = ort.InferenceSession(
            encode,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_uncached_decode(self, uncached_decode):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.uncached_decode = ort.InferenceSession(
            uncached_decode,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_cached_decode(self, cached_decode):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.cached_decode = ort.InferenceSession(
            cached_decode,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def run_preprocess(self, audio):
        """
        Args:
          audio: (batch_size, num_samples), float32
        Returns:
          A tensor of shape (batch_size, T, dim), float32
        """
        return self.preprocess.run(
            [
                self.preprocess.get_outputs()[0].name,
            ],
            {
                self.preprocess.get_inputs()[0].name: audio,
            },
        )[0]

    def run_encode(self, features):
        """
        Args:
          features: (batch_size, T, dim)
        Returns:
          A tensor of shape (batch_size, T, dim)
        """
        features_len = np.array([features.shape[1]], dtype=np.int32)

        return self.encode.run(
            [
                self.encode.get_outputs()[0].name,
            ],
            {
                self.encode.get_inputs()[0].name: features,
                self.encode.get_inputs()[1].name: features_len,
            },
        )[0]

    def run_uncached_decode(self, token: int, token_len: int, encoder_out: np.ndarray):
        """
        Args:
          token: The current token
          token_len: Number of predicted tokens so far
          encoder_out: A tensor of shape (batch_size, T, dim)
        Returns:
          A tuple:
            - a tensor of shape (batch_size, 1, vocab_size)
            - a list of states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        num_outs = len(self.uncached_decode.get_outputs())
        out_names = [
            self.uncached_decode.get_outputs()[i].name for i in range(num_outs)
        ]

        out = self.uncached_decode.run(
            out_names,
            {
                self.uncached_decode.get_inputs()[0].name: token_tensor,
                self.uncached_decode.get_inputs()[1].name: encoder_out,
                self.uncached_decode.get_inputs()[2].name: token_len_tensor,
            },
        )

        logits = out[0]
        states = out[1:]

        return logits, states

    def run_cached_decode(
        self, token: int, token_len: int, encoder_out: np.ndarray, states
    ):
        """
        Args:
          token: The current token
          token_len: Number of predicted tokens so far
          encoder_out: A tensor of shape (batch_size, T, dim)
          states: previous states
        Returns:
          A tuple:
            - a tensor of shape (batch_size, 1, vocab_size)
            - a list of states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        num_outs = len(self.cached_decode.get_outputs())
        out_names = [self.cached_decode.get_outputs()[i].name for i in range(num_outs)]

        # Inputs 0-2 are the token, the encoder output, and the token length;
        # the remaining inputs are the decoder states from the previous step.
        states_inputs = {}
        for i in range(3, len(self.cached_decode.get_inputs())):
            name = self.cached_decode.get_inputs()[i].name
            states_inputs[name] = states[i - 3]

        out = self.cached_decode.run(
            out_names,
            {
                self.cached_decode.get_inputs()[0].name: token_tensor,
                self.cached_decode.get_inputs()[1].name: encoder_out,
                self.cached_decode.get_inputs()[2].name: token_len_tensor,
                **states_inputs,
            },
        )

        logits = out[0]
        states = out[1:]

        return logits, states


def main():
    wave = "./1.wav"
    id2token = dict()
    token2id = dict()
    with open("./tokens.txt", encoding="utf-8") as f:
        for line in f:
            t, idx = line.split("\t")
            id2token[int(idx)] = t
            token2id[t] = int(idx)

    model = OnnxModel(
        preprocess="./preprocess.onnx",
        encode="./encode.int8.onnx",
        uncached_decode="./uncached_decode.int8.onnx",
        cached_decode="./cached_decode.int8.onnx",
    )

    audio, sample_rate = sf.read(wave, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000
    audio = audio[None]  # (1, num_samples)
    print("audio.shape", audio.shape)  # (1, 159414)

    start_t = dt.datetime.now()

    features = model.run_preprocess(audio)  # (1, 413, 288)
    print("features", features.shape)

    sos = token2id["<s>"]
    eos = token2id["</s>"]

    tokens = [sos]

    encoder_out = model.run_encode(features)
    print("encoder_out.shape", encoder_out.shape)  # (1, 413, 288)

    # The first step uses the uncached decoder; later steps feed the
    # returned states back into the cached decoder.
    logits, states = model.run_uncached_decode(
        token=tokens[-1],
        token_len=len(tokens),
        encoder_out=encoder_out,
    )

    print("logits.shape", logits.shape)  # (1, 1, 32768)
    print("len(states)", len(states))  # 24

    # Cap decoding at 6 tokens per second of audio
    max_len = int((audio.shape[-1] / 16000) * 6)

    for i in range(max_len):
        token = logits.squeeze().argmax()
        if token == eos:
            break
        tokens.append(token)

        logits, states = model.run_cached_decode(
            token=tokens[-1],
            token_len=len(tokens),
            encoder_out=encoder_out,
            states=states,
        )

    tokens = tokens[1:]  # remove sos
    words = [id2token[i] for i in tokens]
    underline = "▁"
    # underline = b"\xe2\x96\x81".decode()
    text = "".join(words).replace(underline, " ").strip()

    end_t = dt.datetime.now()
    t = (end_t - start_t).total_seconds()
    rtf = t * 16000 / audio.shape[-1]

    print(text)
    print("RTF:", rtf)


if __name__ == "__main__":
    main()