Committed by
GitHub
Add UVR models for source separation. (#2266)
正在显示
7 个修改的文件
包含
575 行增加
和
3 行删除
.github/workflows/export-uvr-to-onnx.yaml
0 → 100644
| 1 | +name: export-uvr-to-onnx | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + push: | ||
| 5 | + branches: | ||
| 6 | + - uvr | ||
| 7 | + workflow_dispatch: | ||
| 8 | + | ||
| 9 | +concurrency: | ||
| 10 | + group: export-uvr-to-onnx-${{ github.ref }} | ||
| 11 | + cancel-in-progress: true | ||
| 12 | + | ||
| 13 | +jobs: | ||
| 14 | + export-uvr-to-onnx: | ||
| 15 | + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
| 16 | + name: export UVR to ONNX | ||
| 17 | + runs-on: ${{ matrix.os }} | ||
| 18 | + strategy: | ||
| 19 | + fail-fast: false | ||
| 20 | + matrix: | ||
| 21 | + os: [macos-latest] | ||
| 22 | + python-version: ["3.10"] | ||
| 23 | + | ||
| 24 | + steps: | ||
| 25 | + - uses: actions/checkout@v4 | ||
| 26 | + | ||
| 27 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 28 | + uses: actions/setup-python@v5 | ||
| 29 | + with: | ||
| 30 | + python-version: ${{ matrix.python-version }} | ||
| 31 | + | ||
| 32 | + - name: Install dependencies | ||
| 33 | + shell: bash | ||
| 34 | + run: | | ||
| 35 | + pip install "numpy<2" onnx==1.17.0 onnxruntime==1.17.1 onnxmltools kaldi-native-fbank librosa soundfile | ||
| 36 | + | ||
| 37 | + - name: Run | ||
| 38 | + shell: bash | ||
| 39 | + run: | | ||
| 40 | + cd scripts/uvr_mdx | ||
| 41 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav | ||
| 42 | + ls -lh audio_example.wav | ||
| 43 | + ./run.sh | ||
| 44 | + | ||
| 45 | + - name: Collect mp3 files | ||
| 46 | + shell: bash | ||
| 47 | + run: | | ||
| 48 | + mv -v scripts/uvr_mdx/*.mp3 ./ | ||
| 49 | + ls -lh *.mp3 | ||
| 50 | + | ||
| 51 | + - uses: actions/upload-artifact@v4 | ||
| 52 | + with: | ||
| 53 | + name: generated-mp3 | ||
| 54 | + path: ./*.mp3 | ||
| 55 | + | ||
| 56 | + - name: Collect models | ||
| 57 | + shell: bash | ||
| 58 | + run: | | ||
| 59 | + mv -v scripts/uvr_mdx/models/*.onnx ./ | ||
| 60 | + ls -lh *.onnx | ||
| 61 | + | ||
| 62 | + - name: Release | ||
| 63 | + uses: svenstaro/upload-release-action@v2 | ||
| 64 | + with: | ||
| 65 | + file_glob: true | ||
| 66 | + file: ./*.onnx | ||
| 67 | + overwrite: true | ||
| 68 | + repo_name: k2-fsa/sherpa-onnx | ||
| 69 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 70 | + tag: source-separation-models | ||
| 71 | + | ||
| 72 | + - name: Publish to huggingface | ||
| 73 | + env: | ||
| 74 | + HF_TOKEN: ${{ secrets.HF_TOKEN }} | ||
| 75 | + uses: nick-fields/retry@v3 | ||
| 76 | + with: | ||
| 77 | + max_attempts: 20 | ||
| 78 | + timeout_seconds: 200 | ||
| 79 | + shell: bash | ||
| 80 | + command: | | ||
| 81 | + git config --global user.email "csukuangfj@gmail.com" | ||
| 82 | + git config --global user.name "Fangjun Kuang" | ||
| 83 | + | ||
| 84 | + export GIT_LFS_SKIP_SMUDGE=1 | ||
| 85 | + export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 86 | + | ||
| 87 | + rm -rf huggingface | ||
| 88 | + git clone https://huggingface.co/k2-fsa/sherpa-onnx-models huggingface | ||
| 89 | + cd huggingface | ||
| 90 | + mkdir -p source-separation-models | ||
| 91 | + cp -av ../*.onnx ./source-separation-models | ||
| 92 | + git lfs track "*.onnx" | ||
| 93 | + git status | ||
| 94 | + git add . | ||
| 95 | + ls -lh | ||
| 96 | + git status | ||
| 97 | + git commit -m "add source separation models" | ||
| 98 | + git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-models main |
scripts/uvr_mdx/READEME.md
0 → 100644
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +import argparse | ||
| 5 | +from pathlib import Path | ||
| 6 | + | ||
| 7 | +import onnx | ||
| 8 | +import onnxmltools | ||
| 9 | +import onnxruntime | ||
| 10 | +from onnxmltools.utils.float16_converter import convert_float_to_float16 | ||
| 11 | +from onnxruntime.quantization import QuantType, quantize_dynamic | ||
| 12 | + | ||
| 13 | + | ||
| 14 | +def get_args(): | ||
| 15 | + parser = argparse.ArgumentParser( | ||
| 16 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 17 | + ) | ||
| 18 | + | ||
| 19 | + parser.add_argument( | ||
| 20 | + "--filename", | ||
| 21 | + type=str, | ||
| 22 | + required=True, | ||
| 23 | + help="Path to onnx model", | ||
| 24 | + ) | ||
| 25 | + | ||
| 26 | + return parser.parse_args() | ||
| 27 | + | ||
| 28 | + | ||
| 29 | +def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): | ||
| 30 | + onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) | ||
| 31 | + onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True) | ||
| 32 | + onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) | ||
| 33 | + | ||
| 34 | + | ||
| 35 | +def validate(model: onnxruntime.InferenceSession): | ||
| 36 | + for i in model.get_inputs(): | ||
| 37 | + print(i) | ||
| 38 | + | ||
| 39 | + print("-----") | ||
| 40 | + | ||
| 41 | + for i in model.get_outputs(): | ||
| 42 | + print(i) | ||
| 43 | + | ||
| 44 | + assert len(model.get_inputs()) == 1, len(model.get_inputs()) | ||
| 45 | + assert len(model.get_outputs()) == 1, len(model.get_outputs()) | ||
| 46 | + | ||
| 47 | + inp = model.get_inputs()[0] | ||
| 48 | + outp = model.get_outputs()[0] | ||
| 49 | + | ||
| 50 | + assert len(inp.shape) == 4, inp.shape | ||
| 51 | + assert len(outp.shape) == 4, outp.shape | ||
| 52 | + | ||
| 53 | + assert inp.shape[1:] == outp.shape[1:], (inp.shape, outp.shape) | ||
| 54 | + | ||
| 55 | + | ||
| 56 | +def add_meta_data(filename, meta_data): | ||
| 57 | + model = onnx.load(filename) | ||
| 58 | + | ||
| 59 | + print(model.metadata_props) | ||
| 60 | + | ||
| 61 | + while len(model.metadata_props): | ||
| 62 | + model.metadata_props.pop() | ||
| 63 | + | ||
| 64 | + for key, value in meta_data.items(): | ||
| 65 | + meta = model.metadata_props.add() | ||
| 66 | + meta.key = key | ||
| 67 | + meta.value = str(value) | ||
| 68 | + print("--------------------") | ||
| 69 | + | ||
| 70 | + print(model.metadata_props) | ||
| 71 | + | ||
| 72 | + onnx.save(model, filename) | ||
| 73 | + | ||
| 74 | + | ||
| 75 | +def main(): | ||
| 76 | + args = get_args() | ||
| 77 | + filename = Path(args.filename) | ||
| 78 | + if not filename.is_file(): | ||
| 79 | + raise ValueError(f"{filename} does not exist") | ||
| 80 | + | ||
| 81 | + name = filename.stem | ||
| 82 | + print("name", name) | ||
| 83 | + | ||
| 84 | + model = onnx.load(str(filename)) | ||
| 85 | + | ||
| 86 | + session_opts = onnxruntime.SessionOptions() | ||
| 87 | + session_opts.log_severity_level = 3 | ||
| 88 | + sess = onnxruntime.InferenceSession( | ||
| 89 | + str(filename), session_opts, providers=["CPUExecutionProvider"] | ||
| 90 | + ) | ||
| 91 | + validate(sess) | ||
| 92 | + | ||
| 93 | + inp = sess.get_inputs()[0] | ||
| 94 | + outp = sess.get_outputs()[0] | ||
| 95 | + | ||
| 96 | + meta_data = { | ||
| 97 | + "model_type": "UVR", | ||
| 98 | + "model_name": name, | ||
| 99 | + "sample_rate": 44100, | ||
| 100 | + "comment": "This model is downloaded from https://github.com/TRvlvr/model_repo/releases", | ||
| 101 | + "n_fft": inp.shape[2] * 2, | ||
| 102 | + "center": 1, | ||
| 103 | + "window_type": "hann", | ||
| 104 | + "win_length": inp.shape[2] * 2, | ||
| 105 | + "hop_length": 1024, | ||
| 106 | + "dim_t": inp.shape[3], | ||
| 107 | + "dim_f": inp.shape[2], | ||
| 108 | + "dim_c": inp.shape[1], | ||
| 109 | + "stems": 2, | ||
| 110 | + } | ||
| 111 | + add_meta_data(str(filename), meta_data) | ||
| 112 | + | ||
| 113 | + filename_fp16 = f"./{name}.fp16.onnx" | ||
| 114 | + export_onnx_fp16(filename, filename_fp16) | ||
| 115 | + | ||
| 116 | + | ||
| 117 | +if __name__ == "__main__": | ||
| 118 | + main() |
scripts/uvr_mdx/run.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +set -ex | ||
| 5 | + | ||
| 6 | + | ||
| 7 | +# Please see https://github.com/TRvlvr/model_repo/releases/tag/all_public_uvr_models | ||
| 8 | +models=( | ||
| 9 | +UVR-MDX-NET-Inst_1.onnx | ||
| 10 | +UVR-MDX-NET-Inst_2.onnx | ||
| 11 | +UVR-MDX-NET-Inst_3.onnx | ||
| 12 | +UVR-MDX-NET-Inst_HQ_1.onnx | ||
| 13 | +UVR-MDX-NET-Inst_HQ_2.onnx | ||
| 14 | +UVR-MDX-NET-Inst_HQ_3.onnx | ||
| 15 | +UVR-MDX-NET-Inst_HQ_4.onnx | ||
| 16 | +UVR-MDX-NET-Inst_HQ_5.onnx | ||
| 17 | +UVR-MDX-NET-Inst_Main.onnx | ||
| 18 | +UVR-MDX-NET-Voc_FT.onnx | ||
| 19 | +UVR-MDX-NET_Crowd_HQ_1.onnx | ||
| 20 | +UVR_MDXNET_1_9703.onnx | ||
| 21 | +UVR_MDXNET_2_9682.onnx | ||
| 22 | +UVR_MDXNET_3_9662.onnx | ||
| 23 | +UVR_MDXNET_9482.onnx | ||
| 24 | +UVR_MDXNET_KARA.onnx | ||
| 25 | +UVR_MDXNET_KARA_2.onnx | ||
| 26 | +UVR_MDXNET_Main.onnx | ||
| 27 | +) | ||
| 28 | + | ||
| 29 | +mkdir -p models | ||
| 30 | +for m in ${models[@]}; do | ||
| 31 | + if [ ! -f models/$m ]; then | ||
| 32 | + curl -SL --output models/$m https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/$m | ||
| 33 | + fi | ||
| 34 | +done | ||
| 35 | + | ||
| 36 | +ls -lh models | ||
| 37 | + | ||
| 38 | +for m in ${models[@]}; do | ||
| 39 | + echo "----------$m----------" | ||
| 40 | + python3 ./add_meta_data_and_quantize.py --filename models/$m | ||
| 41 | + | ||
| 42 | + ls -lh models/ | ||
| 43 | +done | ||
| 44 | + | ||
| 45 | +if [ -f ./audio_example.wav ]; then | ||
| 46 | + for m in ${models[@]}; do | ||
| 47 | + ./test.py --model-filename ./models/$m --audio-filename ./audio_example.wav | ||
| 48 | + name=$(basename -s .onnx $m) | ||
| 49 | + mv -v vocals.mp3 ${name}_vocals.mp3 | ||
| 50 | + mv -v non_vocals.mp3 ${name}_non_vocals.mp3 | ||
| 51 | + done | ||
| 52 | + | ||
| 53 | + ls -lh *.mp3 | ||
| 54 | +fi |
scripts/uvr_mdx/show.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +import onnxruntime | ||
| 5 | +import onnx | ||
| 6 | + | ||
| 7 | +""" | ||
| 8 | +[] | ||
| 9 | +NodeArg(name='input', type='tensor(float)', shape=['batch_size', 4, 3072, 256]) | ||
| 10 | +----- | ||
| 11 | +NodeArg(name='output', type='tensor(float)', shape=['batch_size', 4, 3072, 256]) | ||
| 12 | +""" | ||
| 13 | + | ||
| 14 | + | ||
| 15 | +def show(filename): | ||
| 16 | + model = onnx.load(filename) | ||
| 17 | + print(model.metadata_props) | ||
| 18 | + | ||
| 19 | + session_opts = onnxruntime.SessionOptions() | ||
| 20 | + session_opts.log_severity_level = 3 | ||
| 21 | + sess = onnxruntime.InferenceSession( | ||
| 22 | + filename, session_opts, providers=["CPUExecutionProvider"] | ||
| 23 | + ) | ||
| 24 | + for i in sess.get_inputs(): | ||
| 25 | + print(i) | ||
| 26 | + | ||
| 27 | + print("-----") | ||
| 28 | + | ||
| 29 | + for i in sess.get_outputs(): | ||
| 30 | + print(i) | ||
| 31 | + | ||
| 32 | + | ||
| 33 | +def main(): | ||
| 34 | + # show("./UVR-MDX-NET-Voc_FT.onnx") | ||
| 35 | + show("./UVR_MDXNET_1_9703.onnx") | ||
| 36 | + | ||
| 37 | + | ||
| 38 | +if __name__ == "__main__": | ||
| 39 | + main() |
scripts/uvr_mdx/test.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +import time | ||
| 5 | + | ||
| 6 | +import argparse | ||
| 7 | +import kaldi_native_fbank as knf | ||
| 8 | +import librosa | ||
| 9 | +import numpy as np | ||
| 10 | +import onnxruntime as ort | ||
| 11 | +import soundfile as sf | ||
| 12 | + | ||
| 13 | + | ||
| 14 | +def get_args(): | ||
| 15 | + parser = argparse.ArgumentParser( | ||
| 16 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 17 | + ) | ||
| 18 | + | ||
| 19 | + parser.add_argument( | ||
| 20 | + "--model-filename", | ||
| 21 | + type=str, | ||
| 22 | + required=True, | ||
| 23 | + help="Path to onnx model", | ||
| 24 | + ) | ||
| 25 | + | ||
| 26 | + parser.add_argument( | ||
| 27 | + "--audio-filename", | ||
| 28 | + type=str, | ||
| 29 | + required=True, | ||
| 30 | + help="Path to input audio file", | ||
| 31 | + ) | ||
| 32 | + | ||
| 33 | + return parser.parse_args() | ||
| 34 | + | ||
| 35 | + | ||
| 36 | +class OnnxModel: | ||
| 37 | + def __init__(self, filename): | ||
| 38 | + session_opts = ort.SessionOptions() | ||
| 39 | + session_opts.inter_op_num_threads = 4 | ||
| 40 | + session_opts.intra_op_num_threads = 4 | ||
| 41 | + | ||
| 42 | + self.session_opts = session_opts | ||
| 43 | + self.model = ort.InferenceSession( | ||
| 44 | + filename, | ||
| 45 | + sess_options=self.session_opts, | ||
| 46 | + providers=["CPUExecutionProvider"], | ||
| 47 | + ) | ||
| 48 | + | ||
| 49 | + self.dim_t = self.model.get_outputs()[0].shape[3] | ||
| 50 | + | ||
| 51 | + self.dim_f = self.model.get_outputs()[0].shape[2] | ||
| 52 | + | ||
| 53 | + self.n_fft = self.dim_f * 2 | ||
| 54 | + | ||
| 55 | + self.dim_c = self.model.get_outputs()[0].shape[1] | ||
| 56 | + assert self.dim_c == 4, self.dim_c | ||
| 57 | + | ||
| 58 | + self.hop = 1024 | ||
| 59 | + self.n_bins = self.n_fft // 2 + 1 | ||
| 60 | + self.chunk_size = self.hop * (self.dim_t - 1) | ||
| 61 | + | ||
| 62 | + self.freq_pad = np.zeros([1, self.dim_c, self.n_bins - self.dim_f, self.dim_t]) | ||
| 63 | + | ||
| 64 | + print(f"----------inputs for {filename}----------") | ||
| 65 | + for i in self.model.get_inputs(): | ||
| 66 | + print(i) | ||
| 67 | + | ||
| 68 | + print(f"----------outputs for {filename}----------") | ||
| 69 | + | ||
| 70 | + for i in self.model.get_outputs(): | ||
| 71 | + print(i) | ||
| 72 | + print(i.shape) | ||
| 73 | + print("--------------------") | ||
| 74 | + | ||
| 75 | + def __call__(self, x): | ||
| 76 | + """ | ||
| 77 | + Args: | ||
| 78 | + x: (batch_size, 4, self.dim_f, self.dim_t) | ||
| 79 | + Returns: | ||
| 80 | + spec: (batch_size, 4, self.dim_f, self.dim_t) | ||
| 81 | + """ | ||
| 82 | + spec = self.model.run( | ||
| 83 | + [ | ||
| 84 | + self.model.get_outputs()[0].name, | ||
| 85 | + ], | ||
| 86 | + { | ||
| 87 | + self.model.get_inputs()[0].name: x, | ||
| 88 | + }, | ||
| 89 | + )[0] | ||
| 90 | + | ||
| 91 | + return spec | ||
| 92 | + | ||
| 93 | + | ||
| 94 | +def main(): | ||
| 95 | + args = get_args() | ||
| 96 | + m = OnnxModel(args.model_filename) | ||
| 97 | + | ||
| 98 | + stft_config = knf.StftConfig( | ||
| 99 | + n_fft=m.n_fft, | ||
| 100 | + hop_length=m.hop, | ||
| 101 | + win_length=m.n_fft, | ||
| 102 | + center=True, | ||
| 103 | + window_type="hann", | ||
| 104 | + ) | ||
| 105 | + knf_stft = knf.Stft(stft_config) | ||
| 106 | + knf_istft = knf.IStft(stft_config) | ||
| 107 | + | ||
| 108 | + sample_rate = 44100 | ||
| 109 | + | ||
| 110 | + samples, rate = librosa.load(args.audio_filename, mono=False, sr=sample_rate) | ||
| 111 | + | ||
| 112 | + start_time = time.time() | ||
| 113 | + | ||
| 114 | + assert rate == sample_rate, (rate, sample_rate) | ||
| 115 | + | ||
| 116 | + # samples: (2, 479832) , (num_channels, num_samples), 44100, 10.88 | ||
| 117 | + print("samples", samples.shape, rate, samples.shape[1] / rate) | ||
| 118 | + | ||
| 119 | + assert samples.ndim == 2, samples.shape | ||
| 120 | + assert samples.shape[0] == 2, samples.shape | ||
| 121 | + | ||
| 122 | + margin = sample_rate | ||
| 123 | + | ||
| 124 | + num_chunks = 15 | ||
| 125 | + chunk_size = num_chunks * sample_rate | ||
| 126 | + | ||
| 127 | + # if they are too few samples, reset chunk_size | ||
| 128 | + if samples.shape[1] < chunk_size: | ||
| 129 | + chunk_size = samples.shape[1] | ||
| 130 | + | ||
| 131 | + if margin > chunk_size: | ||
| 132 | + margin = chunk_size | ||
| 133 | + | ||
| 134 | + segments = [] | ||
| 135 | + for skip in range(0, samples.shape[1], chunk_size): | ||
| 136 | + start = max(0, skip - margin) | ||
| 137 | + end = min(skip + chunk_size + margin, samples.shape[1]) | ||
| 138 | + segments.append(samples[:, start:end]) | ||
| 139 | + if end == samples.shape[1]: | ||
| 140 | + break | ||
| 141 | + | ||
| 142 | + sources = [] | ||
| 143 | + for kk, s in enumerate(segments): | ||
| 144 | + num_samples = s.shape[1] | ||
| 145 | + trim = m.n_fft // 2 | ||
| 146 | + gen_size = m.chunk_size - 2 * trim | ||
| 147 | + pad = gen_size - s.shape[1] % gen_size | ||
| 148 | + mix_p = np.concatenate( | ||
| 149 | + ( | ||
| 150 | + np.zeros((2, trim)), | ||
| 151 | + s, | ||
| 152 | + np.zeros((2, pad)), | ||
| 153 | + np.zeros((2, trim)), | ||
| 154 | + ), | ||
| 155 | + axis=1, | ||
| 156 | + ) | ||
| 157 | + | ||
| 158 | + chunk_list = [] | ||
| 159 | + i = 0 | ||
| 160 | + while i < s.shape[1] + pad: | ||
| 161 | + chunk_list.append(mix_p[:, i : i + m.chunk_size]) | ||
| 162 | + i += gen_size | ||
| 163 | + | ||
| 164 | + mix_waves = np.array(chunk_list) | ||
| 165 | + | ||
| 166 | + mix_waves_reshaped = mix_waves.reshape(-1, m.chunk_size) | ||
| 167 | + stft_results = [] | ||
| 168 | + for w in mix_waves_reshaped: | ||
| 169 | + stft = knf_stft(w) | ||
| 170 | + stft_results.append(stft) | ||
| 171 | + real = np.array( | ||
| 172 | + [np.array(s.real).reshape(s.num_frames, -1) for s in stft_results], | ||
| 173 | + dtype=np.float32, | ||
| 174 | + )[:, :, :-1] | ||
| 175 | + # real: (6, 256, 3072) | ||
| 176 | + | ||
| 177 | + real = real.transpose(0, 2, 1) | ||
| 178 | + # real: (6, 3072, 256) | ||
| 179 | + | ||
| 180 | + imag = np.array( | ||
| 181 | + [np.array(s.imag).reshape(s.num_frames, -1) for s in stft_results], | ||
| 182 | + dtype=np.float32, | ||
| 183 | + )[:, :, :-1] | ||
| 184 | + imag = imag.transpose(0, 2, 1) | ||
| 185 | + # imag: (6, 3072, 256) | ||
| 186 | + | ||
| 187 | + x = np.stack([real, imag], axis=1) | ||
| 188 | + # x: (6, 2, 3072, 256) -> (batch_size, real_imag, 3072, 256) | ||
| 189 | + x = x.reshape(-1, m.dim_c, m.dim_f, m.dim_t) | ||
| 190 | + # x: (3, 4, 3072, 256) | ||
| 191 | + spec = m(x) | ||
| 192 | + | ||
| 193 | + freq_pad = np.repeat(m.freq_pad, spec.shape[0], axis=0) | ||
| 194 | + | ||
| 195 | + x = np.concatenate([spec, freq_pad], axis=2) | ||
| 196 | + # x: (3, 4, 3073, 256) | ||
| 197 | + x = x.reshape(-1, 2, m.n_bins, m.dim_t) | ||
| 198 | + # x: (6, 2, 3073, 256) | ||
| 199 | + x = x.transpose(0, 1, 3, 2) | ||
| 200 | + # x: (6, 2, 256, 3073) | ||
| 201 | + num_frames = x.shape[2] | ||
| 202 | + | ||
| 203 | + x = x.reshape(x.shape[0], x.shape[1], -1) | ||
| 204 | + wav_list = [] | ||
| 205 | + for k in range(x.shape[0]): | ||
| 206 | + istft_result = knf.StftResult( | ||
| 207 | + real=x[k, 0].reshape(-1).tolist(), | ||
| 208 | + imag=x[k, 1].reshape(-1).tolist(), | ||
| 209 | + num_frames=num_frames, | ||
| 210 | + ) | ||
| 211 | + wav = knf_istft(istft_result) | ||
| 212 | + wav_list.append(wav) | ||
| 213 | + wav = np.array(wav_list, dtype=np.float32) | ||
| 214 | + # wav: (6, 261120) | ||
| 215 | + | ||
| 216 | + wav = wav.reshape(-1, 2, wav.shape[-1]) | ||
| 217 | + # wav: (3, 2, 261120) | ||
| 218 | + | ||
| 219 | + wav = wav[:, :, trim:-trim] | ||
| 220 | + # wav: (3, 2, 254976) | ||
| 221 | + | ||
| 222 | + wav = wav.transpose(1, 0, 2) | ||
| 223 | + # wav: (2, 3, 254976) | ||
| 224 | + | ||
| 225 | + wav = wav.reshape(2, -1) | ||
| 226 | + # wav: (2, 764928) | ||
| 227 | + | ||
| 228 | + wav = wav[:, :-pad] | ||
| 229 | + # wav: 2, 705600) | ||
| 230 | + if kk == 0: | ||
| 231 | + start = 0 | ||
| 232 | + else: | ||
| 233 | + start = margin | ||
| 234 | + | ||
| 235 | + if kk == len(segments) - 1: | ||
| 236 | + end = None | ||
| 237 | + else: | ||
| 238 | + end = -margin | ||
| 239 | + | ||
| 240 | + sources.append(wav[:, start:end]) | ||
| 241 | + | ||
| 242 | + sources = np.concatenate(sources, axis=-1) | ||
| 243 | + | ||
| 244 | + vocals = sources | ||
| 245 | + non_vocals = samples - vocals | ||
| 246 | + end_time = time.time() | ||
| 247 | + elapsed_seconds = end_time - start_time | ||
| 248 | + print(f"Elapsed seconds: {elapsed_seconds:.3f}") | ||
| 249 | + | ||
| 250 | + audio_duration = samples.shape[1] / sample_rate | ||
| 251 | + real_time_factor = elapsed_seconds / audio_duration | ||
| 252 | + print(f"Elapsed seconds: {elapsed_seconds:.3f}") | ||
| 253 | + print(f"Audio duration in seconds: {audio_duration:.3f}") | ||
| 254 | + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}") | ||
| 255 | + | ||
| 256 | + sf.write(f"./vocals.mp3", np.transpose(vocals), sample_rate) | ||
| 257 | + sf.write(f"./non_vocals.mp3", np.transpose(non_vocals), sample_rate) | ||
| 258 | + | ||
| 259 | + | ||
| 260 | +if __name__ == "__main__": | ||
| 261 | + main() |
-
请 注册 或 登录 后发表评论