Fangjun Kuang
Committed by GitHub

Export gtcrn models to sherpa-onnx (#1975)

.github/workflows/export-gtcrn-to-onnx.yaml

name: export-gtcrn-to-onnx

on:
  push:
    branches:
      - export-gtcrn

  workflow_dispatch:

concurrency:
  group: export-gtcrn-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-gtcrn-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export gtcrn
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          pip install "numpy<=1.26.4" onnx==1.16.0 onnxruntime==1.17.1 librosa soundfile torch==2.6.0+cpu -f https://download.pytorch.org/whl/torch "kaldi-native-fbank>=1.21.1"

      - name: Run
        shell: bash
        run: |
          cd scripts/gtcrn
          ./run.sh
          ./test.py
          ls -lh

      - name: Collect results
        shell: bash
        run: |
          src=scripts/gtcrn
          cp -v $src/*.onnx ./
          ls -lh *.onnx

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/speech-enhancement-models huggingface
            cd huggingface
            git fetch
            git pull

            cp -v ../gtcrn_simple.onnx ./

            git lfs track "*.onnx"
            git add .

            ls -lh

            git status

            git commit -m "add models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/speech-enhancement-models main || true

      - name: Release
        if: github.repository_owner == 'csukuangfj'
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.onnx
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: speech-enhancement-models

      - name: Release
        if: github.repository_owner == 'k2-fsa'
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.onnx
          overwrite: true
          tag: speech-enhancement-models
scripts/gtcrn/README.md

# Introduction

This folder contains scripts for adding metadata to the streaming GTCRN model from
https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx
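
Downstream code can read the metadata back through onnxruntime. A minimal
sketch, assuming the annotated gtcrn_simple.onnx is in the current directory:

import onnxruntime as ort

sess = ort.InferenceSession(
    "./gtcrn_simple.onnx", providers=["CPUExecutionProvider"]
)
# custom_metadata_map is a dict mapping str keys to str values
meta = sess.get_modelmeta().custom_metadata_map
print(meta["sample_rate"], meta["n_fft"], meta["conv_cache_shape"])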
scripts/gtcrn/add_meta_data.py

#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

"""
Input and output tensors of gtcrn_simple.onnx:

NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16])
-----
NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16])
"""

import onnx
import onnxruntime as ort


def show(filename):
    """Print the input and output tensors of the given ONNX model."""
    session_opts = ort.SessionOptions()
    session_opts.log_severity_level = 3
    sess = ort.InferenceSession(filename, session_opts)
    for i in sess.get_inputs():
        print(i)

    print("-----")

    for i in sess.get_outputs():
        print(i)


def main():
    filename = "./gtcrn_simple.onnx"
    show(filename)
    model = onnx.load(filename)

    meta_data = {
        "model_type": "gtcrn",
        "comment": "gtcrn_simple",
        "version": 1,
        "sample_rate": 16000,
        "model_url": "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx",
        "maintainer": "k2-fsa",
        "comment2": "Please see also https://github.com/Xiaobin-Rong/gtcrn",
        "conv_cache_shape": "2,1,16,16,33",
        "tra_cache_shape": "2,3,1,1,16",
        "inter_cache_shape": "2,1,33,16",
        "n_fft": 512,
        "hop_length": 256,
        "window_length": 512,
        "window_type": "hann_sqrt",
    }

    print(model.metadata_props)

    # Remove any existing metadata entries before adding our own
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    print("--------------------")
    print(model.metadata_props)

    onnx.save(model, filename)


if __name__ == "__main__":
    main()
scripts/gtcrn/run.sh

#!/usr/bin/env bash

# Download the pre-trained streaming GTCRN model
if [ ! -f gtcrn_simple.onnx ]; then
  wget https://github.com/Xiaobin-Rong/gtcrn/raw/refs/heads/main/stream/onnx_models/gtcrn_simple.onnx
fi

# Download a noisy 16 kHz test wave
if [ ! -f ./inp_16k.wav ]; then
  wget https://github.com/yuyun2000/SpeechDenoiser/raw/refs/heads/main/16k/inp_16k.wav
fi

python3 ./add_meta_data.py
scripts/gtcrn/test.py

#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

from typing import Tuple

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel
    samples = np.ascontiguousarray(data)
    return samples, sample_rate


class OnnxModel:
    def __init__(self):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            "./gtcrn_simple.onnx",
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        self.sample_rate = int(meta["sample_rate"])
        self.n_fft = int(meta["n_fft"])
        self.hop_length = int(meta["hop_length"])
        self.window_length = int(meta["window_length"])
        assert meta["window_type"] == "hann_sqrt", meta["window_type"]

        # gtcrn applies the square root of the Hann window in both
        # the analysis (STFT) and synthesis (iSTFT) stages
        self.window = torch.hann_window(self.window_length).pow(0.5)

    def get_init_states(self):
        meta = self.model.get_modelmeta().custom_metadata_map
        conv_cache_shape = list(map(int, meta["conv_cache_shape"].split(",")))
        tra_cache_shape = list(map(int, meta["tra_cache_shape"].split(",")))
        inter_cache_shape = list(map(int, meta["inter_cache_shape"].split(",")))

        conv_cache = np.zeros(conv_cache_shape, dtype=np.float32)
        tra_cache = np.zeros(tra_cache_shape, dtype=np.float32)
        inter_cache = np.zeros(inter_cache_shape, dtype=np.float32)

        return conv_cache, tra_cache, inter_cache

    def __call__(self, x, states):
        """
        Args:
          x: (1, n_fft/2+1, 1, 2), the spectrum of the current frame, with
             the real and imaginary parts stacked along the last axis
          states: a 3-tuple (conv_cache, tra_cache, inter_cache)
        Returns:
          o: (1, n_fft/2+1, 1, 2), the enhanced spectrum of the current frame
          next_states: a 3-tuple of updated caches for the next frame
        """
        out, next_conv_cache, next_tra_cache, next_inter_cache = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
                self.model.get_outputs()[2].name,
                self.model.get_outputs()[3].name,
            ],
            {
                self.model.get_inputs()[0].name: x,
                self.model.get_inputs()[1].name: states[0],
                self.model.get_inputs()[2].name: states[1],
                self.model.get_inputs()[3].name: states[2],
            },
        )

        return out, (next_conv_cache, next_tra_cache, next_inter_cache)


def main():
    model = OnnxModel()

    filename = "./inp_16k.wav"
    wave, sample_rate = load_audio(filename)
    if sample_rate != model.sample_rate:
        import librosa

        wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=model.sample_rate)
        sample_rate = model.sample_rate

    stft_config = knf.StftConfig(
        n_fft=model.n_fft,
        hop_length=model.hop_length,
        win_length=model.window_length,
        window=model.window.tolist(),
    )
    stft = knf.Stft(stft_config)
    stft_result = stft(wave)
    num_frames = stft_result.num_frames
    real = np.array(stft_result.real, dtype=np.float32).reshape(num_frames, -1)
    imag = np.array(stft_result.imag, dtype=np.float32).reshape(num_frames, -1)

    # Process the spectrum frame by frame, threading the caches through
    states = model.get_init_states()
    outputs = []
    for i in range(num_frames):
        x_real = real[i : i + 1]
        x_imag = imag[i : i + 1]
        # (1, n_fft/2+1) x 2 -> (1, n_fft/2+1, 1, 2)
        x = np.vstack([x_real, x_imag]).transpose()
        x = np.expand_dims(x, axis=0)
        x = np.expand_dims(x, axis=2)

        o, states = model(x, states)
        outputs.append(o)

    # (1, n_fft/2+1, num_frames, 2) -> (num_frames, n_fft/2+1, 2)
    outputs = np.concatenate(outputs, axis=2)
    outputs = outputs.squeeze(0).transpose(1, 0, 2)

    enhanced_real = outputs[:, :, 0]
    enhanced_imag = outputs[:, :, 1]
    enhanced_stft_result = knf.StftResult(
        real=enhanced_real.reshape(-1).tolist(),
        imag=enhanced_imag.reshape(-1).tolist(),
        num_frames=enhanced_real.shape[0],
    )

    istft = knf.IStft(stft_config)
    enhanced = istft(enhanced_stft_result)

    sf.write("./enhanced_16k.wav", enhanced, model.sample_rate)


if __name__ == "__main__":
    main()
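
For a quick sanity check of the result (not part of the scripts above), one can
compare the durations and RMS levels of the noisy input and the enhanced wave;
a minimal sketch using only soundfile and numpy:

import numpy as np
import soundfile as sf

noisy, sr = sf.read("./inp_16k.wav", dtype="float32", always_2d=True)
enhanced, sr2 = sf.read("./enhanced_16k.wav", dtype="float32", always_2d=True)
noisy = noisy[:, 0]
enhanced = enhanced[:, 0]
print(f"noisy: {len(noisy) / sr:.2f} s, RMS {np.sqrt(np.mean(noisy ** 2)):.4f}")
print(f"enhanced: {len(enhanced) / sr2:.2f} s, RMS {np.sqrt(np.mean(enhanced ** 2)):.4f}")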