正在显示
5 个修改的文件
包含
495 行增加
和
0 行删除
.github/workflows/export-vocos.yaml
0 → 100644
# Workflow: add metadata to the vocos ONNX vocoder, run the smoke test in
# scripts/vocos, then publish the resulting model to Hugging Face and to the
# `vocoder-models` GitHub release.
name: export-vocos-to-onnx

on:
  push:
    branches:
      - export-vocos

  workflow_dispatch:

# A newer run on the same ref cancels the one still in progress.
concurrency:
  group: export-vocos-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-vocos-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    # The matrix defines only `os` and `python-version`; there is no
    # `matrix.version`, so use the key that actually exists.
    name: export vocos ${{ matrix.python-version }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          pip install "numpy<=1.26.4" onnx==1.16.0 onnxruntime==1.17.1 soundfile piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html kaldi_native_fbank

      - name: Run
        shell: bash
        run: |
          cd scripts/vocos
          ./run.sh
          ls -lh

      # Copy the artifacts to the workspace root, where the steps below
      # look for ./*.wav and ./*.onnx.
      - name: Collect results
        shell: bash
        run: |
          cp -v scripts/vocos/vocos-22khz-univ.onnx .
          cp -v scripts/vocos/*.wav .

      - uses: actions/upload-artifact@v4
        with:
          name: generated-waves
          path: ./*.wav

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            # Start every attempt from a fresh clone.
            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            git clone https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-models huggingface
            cd huggingface
            git fetch
            git pull

            d=vocoder-models
            mkdir -p $d

            cp -a ../vocos-22khz-univ.onnx $d/

            git lfs track "*.onnx"
            git add .

            ls -lh

            git status

            git commit -m "add models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-models main || true

      # The two "Release" steps are mutually exclusive: exactly one runs,
      # depending on which owner the workflow executes under.
      - name: Release
        if: github.repository_owner == 'csukuangfj'
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.onnx
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: vocoder-models

      - name: Release
        if: github.repository_owner == 'k2-fsa'
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.onnx
          overwrite: true
          tag: vocoder-models
scripts/vocos/README.md
0 → 100644
scripts/vocos/add_meta_data.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | + | ||
| 5 | +import argparse | ||
| 6 | + | ||
| 7 | +import onnx | ||
| 8 | + | ||
| 9 | + | ||
def get_args():
    """Parse the command-line options of this script.

    Returns:
      An ``argparse.Namespace`` with two string attributes, ``in_model``
      (from ``--in-model``) and ``out_model`` (from ``--out-model``).
      Both options are required.
    """
    parser = argparse.ArgumentParser()

    options = (
        ("--in-model", "input onnx model"),
        ("--out-model", "output onnx model"),
    )
    for flag, description in options:
        parser.add_argument(flag, type=str, required=True, help=description)

    return parser.parse_args()
| 19 | + | ||
| 20 | + | ||
def main():
    """Replace the metadata of an ONNX model and save the result.

    Loads the model given by ``--in-model``, discards every existing
    ``metadata_props`` entry, writes the key/value pairs defined below
    (all values are stored as strings) and saves the model to the path
    given by ``--out-model``. The metadata is printed before and after
    the replacement.
    """
    args = get_args()
    print(args.in_model, args.out_model)

    model = onnx.load(args.in_model)
    props = model.metadata_props

    print(props)

    # Drop whatever metadata the input model shipped with.
    while props:
        props.pop()

    meta_data = {
        "model_type": "vocos",
        "model_filename": "mel_spec_22khz_univ.onnx",
        "sample_rate": 22050,
        "version": 1,
        "model_author": "BSC-LT",
        "maintainer": "k2-fsa",
        "n_fft": 1024,
        "hop_length": 256,
        "win_length": 1024,
        "window_type": "hann",
        "center": 1,
        "pad_mode": "reflect",
        "normalized": 0,
        "url1": "https://huggingface.co/BSC-LT/vocos-mel-22khz",
        "url2": "https://github.com/gemelo-ai/vocos",
    }

    for name, content in meta_data.items():
        entry = props.add()
        entry.key = name
        # Metadata values are strings, so numbers are stringified here.
        entry.value = str(content)
    print("--------------------")

    print(props)

    onnx.save(model, args.out_model)

    print(f"Saved to {args.out_model}")


if __name__ == "__main__":
    main()
scripts/vocos/run.sh
0 → 100755
#!/usr/bin/env bash
#
# Download the vocos vocoder, attach metadata to it with add_meta_data.py,
# fetch the models needed by test.py and run the test.
#
# Each download is skipped when its target already exists, so the script
# can be re-run without fetching everything again.
#
# curl is invoked with -f (--fail) so that an HTTP error aborts the script
# (via `set -e`) instead of saving the server's error page under the
# model's file name, which the existence checks below would then mistake
# for a finished download.

set -ex

if [ ! -f mel_spec_22khz_univ.onnx ]; then
  curl -fSL -O https://huggingface.co/BSC-LT/vocos-mel-22khz/resolve/main/mel_spec_22khz_univ.onnx
fi

if [ ! -f ./vocos-22khz-univ.onnx ]; then
  python3 ./add_meta_data.py --in-model ./mel_spec_22khz_univ.onnx --out-model ./vocos-22khz-univ.onnx
fi

# The following is for testing
if [ ! -f ./matcha-icefall-en_US-ljspeech/tokens.txt ]; then
  curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-en_US-ljspeech.tar.bz2
  tar xf matcha-icefall-en_US-ljspeech.tar.bz2
  rm matcha-icefall-en_US-ljspeech.tar.bz2
fi

if [ ! -f ./hifigan_v2.onnx ]; then
  curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx
fi

python3 ./test.py
ls -lh
scripts/vocos/test.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +import datetime as dt | ||
| 5 | + | ||
| 6 | +import kaldi_native_fbank as knf | ||
| 7 | +import numpy as np | ||
| 8 | +import onnxruntime as ort | ||
| 9 | +import soundfile as sf | ||
| 10 | + | ||
| 11 | +try: | ||
| 12 | + from piper_phonemize import phonemize_espeak | ||
| 13 | +except Exception as ex: | ||
| 14 | + raise RuntimeError( | ||
| 15 | + f"{ex}\nPlease run\n" | ||
| 16 | + "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" | ||
| 17 | + ) | ||
| 18 | + | ||
| 19 | + | ||
class OnnxVocosModel:
    """Runs a vocos ONNX model through an onnxruntime CPU session."""

    def __init__(
        self,
        filename: str,
    ):
        """Create the inference session and print the model's inputs/outputs.

        Args:
          filename: Path to the vocos ONNX model file.
        """
        opts = ort.SessionOptions()
        # Run on a single thread.
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        self.session_opts = opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )

        print("----------vocos----------")
        for node in self.model.get_inputs():
            print(node)

        print("-----")

        for node in self.model.get_outputs():
            print(node)
        print()

    def __call__(self, x: np.ndarray):
        """Run the model on a single utterance.

        Args:
          x: (N, feat_dim, num_frames). Only N == 1 is accepted.
        Returns:
          A tuple ``(mag, x, y)`` holding the first three model outputs:
            mag: (N, n_fft/2+1, num_frames)
            x: (N, n_fft/2+1, num_frames)
            y: (N, n_fft/2+1, num_frames)

          The complex spectrum is mag * (x + j*y)
        """
        assert x.ndim == 3, x.shape
        assert x.shape[0] == 1, x.shape

        outputs = self.model.get_outputs()
        output_names = [outputs[k].name for k in range(3)]
        feeds = {self.model.get_inputs()[0].name: x}

        mag, x, y = self.model.run(output_names, feeds)

        return mag, x, y
| 72 | + | ||
| 73 | + | ||
class OnnxHifiGANModel:
    """Runs a HiFi-GAN ONNX vocoder through an onnxruntime CPU session."""

    def __init__(
        self,
        filename: str,
    ):
        """Create the inference session and print the model's inputs/outputs.

        Args:
          filename: Path to the HiFi-GAN ONNX model file.
        """
        opts = ort.SessionOptions()
        # Run on a single thread.
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        self.session_opts = opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )

        print("----------hifigan----------")
        for node in self.model.get_inputs():
            print(node)

        print("-----")

        for node in self.model.get_outputs():
            print(node)
        print()

    def __call__(self, x: np.ndarray):
        """Convert features of a single utterance to audio samples.

        Args:
          x: (N, feat_dim, num_frames). Only N == 1 is accepted.
        Returns:
          audio: (N, num_samples), the first output of the model.
        """
        assert x.ndim == 3, x.shape
        assert x.shape[0] == 1, x.shape

        output_name = self.model.get_outputs()[0].name
        input_name = self.model.get_inputs()[0].name

        results = self.model.run([output_name], {input_name: x})
        # audio: (batch_size, num_samples)
        audio = results[0]

        return audio
| 119 | + | ||
| 120 | + | ||
def load_tokens(filename):
    """Read a token table and return a token -> id mapping.

    Each non-blank line holds ``<token> <id>`` separated by whitespace.
    A line with a single field is taken to be the entry of the space
    token: the token itself is whitespace and disappears when the line is
    split, leaving only the id. Blank lines are skipped.

    Args:
      filename: Path to the UTF-8 encoded token file.
    Returns:
      A dict mapping each token (str) to its integer id.
    Raises:
      ValueError: If a line has more than two fields or its id is not an
        integer.
    """
    token2id = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line (e.g. a trailing newline): nothing to parse.
                continue

            if len(fields) == 1:
                token, idx = " ", fields[0]
            else:
                token, idx = fields

            token2id[token] = int(idx)
    return token2id
| 133 | + | ||
| 134 | + | ||
class OnnxModel:
    """Runs the acoustic ONNX model (token ids -> mel) on the CPU.

    Attributes set by ``__init__``:
      token2id: token -> id mapping loaded from the tokens file.
      sample_rate: int read from the ``sample_rate`` entry of the model's
        custom metadata.
    """

    def __init__(
        self,
        filename: str,
        tokens: str,
    ):
        """Load the token table, create the session and print model info.

        Args:
          filename: Path to the acoustic ONNX model file.
          tokens: Path to the token table read by ``load_tokens``.
        Raises:
          KeyError: If the model metadata has no ``sample_rate`` entry.
        """
        self.token2id = load_tokens(tokens)
        session_opts = ort.SessionOptions()
        # Run on a single thread.
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        # Fetch the metadata once and reuse it for printing and parsing.
        metadata = self.model.get_modelmeta().custom_metadata_map
        print(f"{metadata}")
        self.sample_rate = int(metadata["sample_rate"])

        print("----------matcha----------")
        for i in self.model.get_inputs():
            print(i)

        print("-----")

        for i in self.model.get_outputs():
            print(i)
        print()

    def __call__(self, x: np.ndarray):
        """Generate features for a single sequence of token ids.

        The model is fed, in input order: ``x``, its length as an int64
        array of one element, and ``noise_scale`` and ``length_scale``,
        both fixed to the float32 value 1.0.

        Args:
          x: (N, num_tokens) array of token ids. Only N == 1 is accepted.
        Returns:
          mel: (batch_size, feat_dim, num_frames), the first model output.
        """
        assert x.ndim == 2, x.shape
        assert x.shape[0] == 1, x.shape

        x_lengths = np.array([x.shape[1]], dtype=np.int64)

        noise_scale = np.array([1.0], dtype=np.float32)
        length_scale = np.array([1.0], dtype=np.float32)

        inputs = self.model.get_inputs()
        mel = self.model.run(
            [self.model.get_outputs()[0].name],
            {
                inputs[0].name: x,
                inputs[1].name: x_lengths,
                inputs[2].name: noise_scale,
                inputs[3].name: length_scale,
            },
        )[0]
        # mel: (batch_size, feat_dim, num_frames)

        return mel
| 191 | + | ||
| 192 | + | ||
def main():
    """Synthesize one sentence and vocode it with both vocos and HiFi-GAN.

    Steps:
      1. Phonemize a fixed English sentence and map the phonemes to token
         ids, skipping tokens missing from the acoustic model's table and
         interleaving the ``_`` token between (and around) the ids.
      2. Run the acoustic model to obtain the mel features.
      3. Run vocos and invert its output with an inverse STFT, then write
         ``vocos.wav``.
      4. Run HiFi-GAN on the same features and write ``hifigan-v2.wav``.
      5. Print audio durations and real-time factors of each stage.
    """
    am = OnnxModel(
        filename="./matcha-icefall-en_US-ljspeech/model-steps-3.onnx",
        tokens="./matcha-icefall-en_US-ljspeech/tokens.txt",
    )
    vocoder = OnnxHifiGANModel("./hifigan_v2.onnx")
    vocos = OnnxVocosModel("./mel_spec_22khz_univ.onnx")
    sample_rate = am.sample_rate

    text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."
    tokens_list = phonemize_espeak(text, "en-us")
    print(tokens_list)
    # Flatten the per-sentence phoneme lists into one list.
    tokens = [phoneme for sentence in tokens_list for phoneme in sentence]

    known_ids = []
    for phoneme in tokens:
        if phoneme in am.token2id:
            known_ids.append(am.token2id[phoneme])
        else:
            print(f"Skip OOV '{phoneme}'")

    # Put the "_" token before, between and after the real token ids.
    blank_id = am.token2id["_"]
    interleaved = [blank_id] * (2 * len(known_ids) + 1)
    interleaved[1::2] = known_ids
    token_array = np.array([interleaved], dtype=np.int64)

    mel_start_t = dt.datetime.now()
    mel = am(token_array)
    mel_end_t = dt.datetime.now()

    # Shape is (batch_size, feat_dim, num_frames), e.g. (1, 80, 78).
    print("mel", mel.shape)

    # The vocos timing covers both the network and the inverse STFT.
    vocos_start_t = dt.datetime.now()
    mag, unit_real, unit_imag = vocos(mel)
    spectrum_real = (mag * unit_real)[0].transpose().reshape(-1).tolist()
    spectrum_imag = (mag * unit_imag)[0].transpose().reshape(-1).tolist()
    stft_result = knf.StftResult(
        real=spectrum_real,
        imag=spectrum_imag,
        num_frames=mag.shape[2],
    )
    istft = knf.IStft(
        knf.StftConfig(
            n_fft=1024,
            hop_length=256,
            win_length=1024,
            window_type="hann",
            center=True,
            pad_mode="reflect",
            normalized=False,
        )
    )
    audio_vocos = istft(stft_result)
    vocos_end_t = dt.datetime.now()

    audio_vocos = np.array(audio_vocos)
    print("vocos max/min", np.max(audio_vocos), np.min(audio_vocos))

    sf.write("vocos.wav", audio_vocos, am.sample_rate, "PCM_16")

    hifigan_start_t = dt.datetime.now()
    audio_hifigan = vocoder(mel)
    hifigan_end_t = dt.datetime.now()
    audio_hifigan = audio_hifigan.squeeze()

    print("hifigan max/min", np.max(audio_hifigan), np.min(audio_hifigan))

    sf.write("hifigan-v2.wav", audio_hifigan, sample_rate, "PCM_16")

    num_vocos_samples = audio_vocos.shape[-1]
    num_hifigan_samples = audio_hifigan.shape[-1]

    am_t = (mel_end_t - mel_start_t).total_seconds()
    vocos_t = (vocos_end_t - vocos_start_t).total_seconds()
    hifigan_t = (hifigan_end_t - hifigan_start_t).total_seconds()

    # The acoustic model is shared by both vocoders, so its RTF is taken
    # against the mean of the two audio durations.
    mean_audio_duration = (num_vocos_samples + num_hifigan_samples) / 2 / sample_rate
    rtf_am = am_t / mean_audio_duration

    rtf_vocos = vocos_t * sample_rate / num_vocos_samples
    rtf_hifigan = hifigan_t * sample_rate / num_hifigan_samples

    print("Audio duration for vocos {:.3f} s".format(num_vocos_samples / sample_rate))
    print(
        "Audio duration for hifigan {:.3f} s".format(
            num_hifigan_samples / sample_rate
        )
    )
    print("Mean audio duration: {:.3f} s".format(mean_audio_duration))
    print("RTF for acoustic model {:.3f}".format(rtf_am))
    print("RTF for vocos {:.3f}".format(rtf_vocos))
    print("RTF for hifigan {:.3f}".format(rtf_hifigan))


if __name__ == "__main__":
    main()
-
请 注册 或 登录 后发表评论