Export spleeter model to onnx for source separation (#2237)

Showing 10 changed files with 1,100 additions and 0 deletions.
name: export-spleeter-to-onnx

on:
  push:
    branches:
      - spleeter-2
  workflow_dispatch:

concurrency:
  group: export-spleeter-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-spleeter-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export spleeter to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        shell: bash
        run: |
          pip install tensorflow torch "numpy<2" onnx==1.17.0 onnxruntime==1.17.1 onnxmltools

      - name: Run
        shell: bash
        run: |
          cd scripts/spleeter
          ./run.sh

          echo "---"
          ls -lh 2stems
          echo "---"
          ls -lh 2stems/*.onnx
          echo "---"

          mv -v 2stems/*.onnx ../..

      - name: Collect models
        shell: bash
        run: |
          mkdir sherpa-onnx-spleeter-2stems
          mkdir sherpa-onnx-spleeter-2stems-int8
          mkdir sherpa-onnx-spleeter-2stems-fp16

          mv -v vocals.onnx sherpa-onnx-spleeter-2stems/
          mv -v accompaniment.onnx sherpa-onnx-spleeter-2stems/

          mv -v vocals.int8.onnx sherpa-onnx-spleeter-2stems-int8/
          mv -v accompaniment.int8.onnx sherpa-onnx-spleeter-2stems-int8/

          mv -v vocals.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/
          mv -v accompaniment.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/

          tar cjvf sherpa-onnx-spleeter-2stems.tar.bz2 sherpa-onnx-spleeter-2stems
          tar cjvf sherpa-onnx-spleeter-2stems-int8.tar.bz2 sherpa-onnx-spleeter-2stems-int8
          tar cjvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2 sherpa-onnx-spleeter-2stems-fp16

          ls -lh *.tar.bz2

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: source-separation-models

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            names=(
              sherpa-onnx-spleeter-2stems
              sherpa-onnx-spleeter-2stems-int8
              sherpa-onnx-spleeter-2stems-fp16
            )
            for d in ${names[@]}; do
              rm -rf huggingface
              git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
              cp -v $d/*onnx huggingface

              cd huggingface
              git lfs track "*.onnx"
              git status
              git add .
              ls -lh
              git status
              git commit -m "add models"
              git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
              cd ..
            done
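For a quick local check of the packaging step, the three tarballs can be inspected before release. A sketch in Python (not part of this commit; it assumes the "Collect models" step above has already produced the archives in the current directory):

    import tarfile

    # List the contents of each packaged archive produced by the
    # "Collect models" step.
    for name in [
        "sherpa-onnx-spleeter-2stems",
        "sherpa-onnx-spleeter-2stems-int8",
        "sherpa-onnx-spleeter-2stems-fp16",
    ]:
        with tarfile.open(f"{name}.tar.bz2", "r:bz2") as archive:
            print(name, archive.getnames())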
scripts/spleeter/.gitignore (new file, mode 100644)
scripts/spleeter/__init__.py (new file, mode 100644)
scripts/spleeter/convert_to_pb.py (new file, mode 100755)
#!/usr/bin/env python3

# Code in this file is modified from
# https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
#
# Please see ./run.sh for usage
import argparse

import tensorflow as tf


def freeze_graph(model_dir, output_node_names, output_filename):
    """Extract the subgraph defined by the output nodes and convert all its
    variables into constants.

    Args:
      model_dir:
        The root folder containing the checkpoint state file.
      output_node_names:
        A string containing all the output node names, comma separated.
      output_filename:
        Filename to save the frozen graph to.
    """
    if not tf.compat.v1.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exist. Please specify an export "
            "directory: %s" % model_dir
        )

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # We retrieve the checkpoint full path
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # The full filename of the frozen graph
    output_graph = output_filename

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        # We import the meta graph into the current default Graph
        saver = tf.compat.v1.train.import_meta_graph(
            input_checkpoint + ".meta", clear_devices=clear_devices
        )

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.compat.v1.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
            output_node_names.split(
                ","
            ),  # The output node names are used to select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.compat.v1.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir", type=str, default="", help="Model folder to export"
    )
    parser.add_argument(
        "--output-node-names",
        type=str,
        default="vocals_spectrogram/mul,accompaniment_spectrogram/mul",
        help="The names of the output nodes, comma separated.",
    )

    parser.add_argument(
        "--output-filename",
        type=str,
    )
    args = parser.parse_args()

    freeze_graph(args.model_dir, args.output_node_names, args.output_filename)
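Choosing --output-node-names requires knowing which nodes exist in the checkpoint graph. A minimal sketch for listing them (not part of this commit; it reuses the same tf.compat.v1 calls as the script above):

    import tensorflow as tf

    # List candidate node names in the 2stems checkpoint so that suitable
    # values for --output-node-names can be picked (hypothetical helper).
    checkpoint = tf.train.get_checkpoint_state("./2stems")
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        tf.compat.v1.train.import_meta_graph(
            checkpoint.model_checkpoint_path + ".meta", clear_devices=True
        )
        for op in sess.graph.get_operations():
            print(op.name)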
scripts/spleeter/convert_to_torch.py (new file, mode 100755)
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

# Please see ./run.sh for usage

import argparse

import numpy as np
import tensorflow as tf
import torch

from unet import UNet


def load_graph(frozen_graph_filename):
    # This function is modified from
    # https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

    # We load the protobuf file from the disk and parse it to retrieve the
    # unserialized graph_def
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then we import the graph_def into a new Graph and return it
    with tf.Graph().as_default() as graph:
        # The name var would prefix every op/node in the graph.
        # Since we load everything into a new graph, this is not needed.
        # tf.import_graph_def(graph_def, name="prefix")
        tf.import_graph_def(graph_def, name="")
    return graph


def generate_waveform():
    np.random.seed(20230821)
    waveform = np.random.rand(60 * 44100).astype(np.float32)

    # (num_samples, num_channels)
    waveform = waveform.reshape(-1, 2)
    return waveform


def get_param(graph, name):
    with tf.compat.v1.Session(graph=graph) as sess:
        constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
        for constant_op in constant_ops:
            if constant_op.name != name:
                continue

            value = sess.run(constant_op.outputs[0])
            return torch.from_numpy(value)


@torch.no_grad()
def main(name):
    graph = load_graph(f"./2stems/frozen_{name}_model.pb")
    # for op in graph.get_operations():
    #     print(op.name)
    x = graph.get_tensor_by_name("waveform:0")
    # y = graph.get_tensor_by_name("Reshape:0")
    y0 = graph.get_tensor_by_name("strided_slice_3:0")
    # y1 = graph.get_tensor_by_name("leaky_re_lu_5/LeakyRelu:0")
    # y1 = graph.get_tensor_by_name("conv2d_5/BiasAdd:0")
    # y1 = graph.get_tensor_by_name("conv2d_transpose/BiasAdd:0")
    # y1 = graph.get_tensor_by_name("re_lu/Relu:0")
    # y1 = graph.get_tensor_by_name("batch_normalization_6/cond/FusedBatchNorm_1:0")
    # y1 = graph.get_tensor_by_name("concatenate/concat:0")
    # y1 = graph.get_tensor_by_name("concatenate_1/concat:0")
    # y1 = graph.get_tensor_by_name("concatenate_4/concat:0")
    # y1 = graph.get_tensor_by_name("batch_normalization_11/cond/FusedBatchNorm_1:0")
    # y1 = graph.get_tensor_by_name("conv2d_6/Sigmoid:0")
    y1 = graph.get_tensor_by_name(f"{name}_spectrogram/mul:0")

    unet = UNet()
    unet.eval()

    # For conv2d in tensorflow, the weight shape is (kernel_h, kernel_w, in_channel, out_channel)
    # and the default input layout is NHWC.

    # For conv2d in torch, the weight shape is (out_channel, in_channel, kernel_h, kernel_w)
    # and the default input layout is NCHW.
    state_dict = unet.state_dict()
    # print(list(state_dict.keys()))

    if name == "vocals":
        state_dict["conv.weight"] = get_param(graph, "conv2d/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["conv.bias"] = get_param(graph, "conv2d/bias")

        state_dict["bn.weight"] = get_param(graph, "batch_normalization/gamma")
        state_dict["bn.bias"] = get_param(graph, "batch_normalization/beta")
        state_dict["bn.running_mean"] = get_param(
            graph, "batch_normalization/moving_mean"
        )
        state_dict["bn.running_var"] = get_param(
            graph, "batch_normalization/moving_variance"
        )

        conv_offset = 0
        bn_offset = 0
    else:
        state_dict["conv.weight"] = get_param(graph, "conv2d_7/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["conv.bias"] = get_param(graph, "conv2d_7/bias")

        state_dict["bn.weight"] = get_param(graph, "batch_normalization_12/gamma")
        state_dict["bn.bias"] = get_param(graph, "batch_normalization_12/beta")
        state_dict["bn.running_mean"] = get_param(
            graph, "batch_normalization_12/moving_mean"
        )
        state_dict["bn.running_var"] = get_param(
            graph, "batch_normalization_12/moving_variance"
        )
        conv_offset = 7
        bn_offset = 12

    for i in range(1, 6):
        state_dict[f"conv{i}.weight"] = get_param(
            graph, f"conv2d_{i+conv_offset}/kernel"
        ).permute(3, 2, 0, 1)
        state_dict[f"conv{i}.bias"] = get_param(graph, f"conv2d_{i+conv_offset}/bias")
        if i >= 5:
            continue
        state_dict[f"bn{i}.weight"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/gamma"
        )
        state_dict[f"bn{i}.bias"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/beta"
        )
        state_dict[f"bn{i}.running_mean"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/moving_mean"
        )
        state_dict[f"bn{i}.running_var"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/moving_variance"
        )

    if name == "vocals":
        state_dict["up1.weight"] = get_param(graph, "conv2d_transpose/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up1.bias"] = get_param(graph, "conv2d_transpose/bias")

        state_dict["bn5.weight"] = get_param(graph, "batch_normalization_6/gamma")
        state_dict["bn5.bias"] = get_param(graph, "batch_normalization_6/beta")
        state_dict["bn5.running_mean"] = get_param(
            graph, "batch_normalization_6/moving_mean"
        )
        state_dict["bn5.running_var"] = get_param(
            graph, "batch_normalization_6/moving_variance"
        )
        conv_offset = 0
        bn_offset = 0
    else:
        state_dict["up1.weight"] = get_param(
            graph, "conv2d_transpose_6/kernel"
        ).permute(3, 2, 0, 1)
        state_dict["up1.bias"] = get_param(graph, "conv2d_transpose_6/bias")

        state_dict["bn5.weight"] = get_param(graph, "batch_normalization_18/gamma")
        state_dict["bn5.bias"] = get_param(graph, "batch_normalization_18/beta")
        state_dict["bn5.running_mean"] = get_param(
            graph, "batch_normalization_18/moving_mean"
        )
        state_dict["bn5.running_var"] = get_param(
            graph, "batch_normalization_18/moving_variance"
        )
        conv_offset = 6
        bn_offset = 12

    for i in range(1, 6):
        state_dict[f"up{i+1}.weight"] = get_param(
            graph, f"conv2d_transpose_{i+conv_offset}/kernel"
        ).permute(3, 2, 0, 1)

        state_dict[f"up{i+1}.bias"] = get_param(
            graph, f"conv2d_transpose_{i+conv_offset}/bias"
        )

        state_dict[f"bn{5+i}.weight"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/gamma"
        )
        state_dict[f"bn{5+i}.bias"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/beta"
        )
        state_dict[f"bn{5+i}.running_mean"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/moving_mean"
        )
        state_dict[f"bn{5+i}.running_var"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/moving_variance"
        )

    if name == "vocals":
        state_dict["up7.weight"] = get_param(graph, "conv2d_6/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up7.bias"] = get_param(graph, "conv2d_6/bias")
    else:
        state_dict["up7.weight"] = get_param(graph, "conv2d_13/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up7.bias"] = get_param(graph, "conv2d_13/bias")

    unet.load_state_dict(state_dict)

    with tf.compat.v1.Session(graph=graph) as sess:
        y0_out, y1_out = sess.run([y0, y1], feed_dict={x: generate_waveform()})
        # y0_out = sess.run(y0, feed_dict={x: generate_waveform()})
        # y1_out = sess.run(y1, feed_dict={x: generate_waveform()})
        # print(y0_out.shape)
        # print(y1_out.shape)

    # For batch normalization in tensorflow,
    # the default input layout is NHWC.

    # For batch normalization in torch,
    # the default input layout is NCHW.

    # NHWC to NCHW
    torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))

    # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
    assert torch.allclose(
        torch_y1_out, torch.from_numpy(y1_out).permute(0, 3, 1, 2), atol=1e-1
    ), ((torch_y1_out - torch.from_numpy(y1_out).permute(0, 3, 1, 2)).abs().max())
    torch.save(unet.state_dict(), f"2stems/{name}.pt")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--name",
        type=str,
        required=True,
        choices=["vocals", "accompaniment"],
    )
    args = parser.parse_args()
    print(vars(args))
    main(args.name)
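A quick sanity check of the converted checkpoints (a sketch, not part of the commit; it mirrors how export_onnx.py loads them):

    import torch

    from unet import UNet

    # Load the converted torch checkpoint and run a dummy input; the UNet
    # returns a mask multiplied by its input, so output shape == input shape.
    model = UNet()
    model.load_state_dict(torch.load("./2stems/vocals.pt", map_location="cpu"))
    model.eval()
    with torch.no_grad():
        y = model(torch.rand(1, 2, 512, 1024))
    print(y.shape)  # expected: torch.Size([1, 2, 512, 1024])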
scripts/spleeter/export_onnx.py (new file, mode 100755)
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import onnx
import onnxmltools
import torch
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic

from unet import UNet


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)


def add_meta_data(filename, prefix):
    meta_data = {
        "model_type": "spleeter",
        "sample_rate": 44100,  # spleeter operates on 44.1 kHz audio
        "version": 1,
        "model_url": "https://github.com/deezer/spleeter",
        "stems": 2,
        "comment": prefix,
        "model_name": "2stems.tar.gz",
    }
    model = onnx.load(filename)

    print(model.metadata_props)

    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    print("--------------------")

    print(model.metadata_props)

    onnx.save(model, filename)


def export(model, prefix):
    num_splits = 1
    x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)

    filename = f"./2stems/{prefix}.onnx"
    torch.onnx.export(
        model,
        x,
        filename,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "num_splits"},
        },
        opset_version=13,
    )

    add_meta_data(filename, prefix)

    filename_int8 = f"./2stems/{prefix}.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    filename_fp16 = f"./2stems/{prefix}.fp16.onnx"
    export_onnx_fp16(filename, filename_fp16)


@torch.no_grad()
def main():
    vocals = UNet()
    state_dict = torch.load("./2stems/vocals.pt", map_location="cpu")
    vocals.load_state_dict(state_dict)
    vocals.eval()

    accompaniment = UNet()
    state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu")
    accompaniment.load_state_dict(state_dict)
    accompaniment.eval()

    export(vocals, "vocals")
    export(accompaniment, "accompaniment")


if __name__ == "__main__":
    main()
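The metadata written by add_meta_data() can be read back with onnxruntime, which is a convenient way to verify an export (a sketch, not part of this commit):

    import onnxruntime as ort

    # Read back the custom metadata that add_meta_data() attached to the model.
    sess = ort.InferenceSession(
        "./2stems/vocals.onnx", providers=["CPUExecutionProvider"]
    )
    print(sess.get_modelmeta().custom_metadata_map)
    # e.g. {'model_type': 'spleeter', 'sample_rate': '44100', 'stems': '2', ...}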
scripts/spleeter/run.sh (new file, mode 100755)
#!/usr/bin/env bash


if [ ! -f 2stems.tar.gz ]; then
  curl -SL -O https://github.com/deezer/spleeter/releases/download/v1.4.0/2stems.tar.gz
fi

if [ ! -d ./2stems ]; then
  mkdir -p 2stems
  cd 2stems
  tar xvf ../2stems.tar.gz
  cd ..
fi

ls -lh

ls -lh 2stems

if [ ! -f 2stems/frozen_vocals_model.pb ]; then
  python3 ./convert_to_pb.py \
    --model-dir ./2stems \
    --output-node-names vocals_spectrogram/mul \
    --output-filename ./2stems/frozen_vocals_model.pb
fi

ls -lh 2stems

if [ ! -f 2stems/frozen_accompaniment_model.pb ]; then
  python3 ./convert_to_pb.py \
    --model-dir ./2stems \
    --output-node-names accompaniment_spectrogram/mul \
    --output-filename ./2stems/frozen_accompaniment_model.pb
fi

ls -lh 2stems

python3 ./convert_to_torch.py --name vocals
python3 ./convert_to_torch.py --name accompaniment
python3 ./export_onnx.py

ls -lh 2stems
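If the script completes successfully, each stem should have a torch checkpoint plus fp32, int8, and fp16 ONNX exports under 2stems/. A small post-run check (a sketch; the file names come from convert_to_torch.py and export_onnx.py):

    from pathlib import Path

    # Verify that ./run.sh produced every expected artifact.
    for stem in ["vocals", "accompaniment"]:
        for suffix in [".pt", ".onnx", ".int8.onnx", ".fp16.onnx"]:
            p = Path("2stems") / f"{stem}{suffix}"
            print(p, "OK" if p.is_file() else "MISSING")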
scripts/spleeter/separate.py (new file, mode 100755)
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

# Please see ./run.sh for usage

from typing import Optional

import ffmpeg
import numpy as np
import soundfile as sf
import torch
from pydub import AudioSegment

from unet import UNet


def load_audio(filename, sample_rate: Optional[int] = 44100):
    probe = ffmpeg.probe(filename)
    if "streams" not in probe or len(probe["streams"]) == 0:
        raise ValueError("No stream was found with ffprobe")

    metadata = next(
        stream for stream in probe["streams"] if stream["codec_type"] == "audio"
    )
    n_channels = metadata["channels"]

    if sample_rate is None:
        sample_rate = metadata["sample_rate"]

    process = (
        ffmpeg.input(filename)
        .output("pipe:", format="f32le", ar=sample_rate)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, _ = process.communicate()
    waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)

    waveform = torch.from_numpy(np.copy(waveform)).to(torch.float32)
    if n_channels == 1:
        waveform = waveform.tile(1, 2)

    if n_channels > 2:
        waveform = waveform[:, :2]

    return waveform, sample_rate


@torch.no_grad()
def main():
    vocals = UNet()
    vocals.eval()
    state_dict = torch.load("./2stems/vocals.pt", map_location="cpu")
    vocals.load_state_dict(state_dict)

    accompaniment = UNet()
    accompaniment.eval()
    state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu")
    accompaniment.load_state_dict(state_dict)

    # waveform, sample_rate = load_audio("./audio_example.mp3")

    # You can download the following two mp3 files from
    # https://huggingface.co/spaces/csukuangfj/music-source-separation/tree/main/examples
    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    # waveform, sample_rate = load_audio("./Yesterday_Once_More-Carpenters.mp3")
    assert waveform.shape[1] == 2, waveform.shape

    waveform = torch.nn.functional.pad(waveform, (0, 0, 0, 4096))

    # torch.stft requires a 2-D input of shape (N, T), so we transpose waveform
    stft = torch.stft(
        waveform.t(),
        n_fft=4096,
        hop_length=1024,
        window=torch.hann_window(4096, periodic=True),
        center=False,
        onesided=True,
        return_complex=True,
    )
    print("stft", stft.shape)

    # stft: (2, 2049, 465)
    # stft is a complex tensor
    y = stft.permute(2, 1, 0)
    print("y0", y.shape)
    # (465, 2049, 2)

    y = y[:, :1024, :]
    # (465, 1024, 2)

    tensor_size = y.shape[0] - int(y.shape[0] / 512) * 512
    pad_size = 512 - tensor_size
    y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size))
    # (512, 1024, 2)
    print("y1", y.shape, y.dtype)

    num_splits = int(y.shape[0] / 512)
    y = y.reshape([num_splits, 512] + list(y.shape[1:]))
    # y: (1, 512, 1024, 2)
    print("y2", y.shape, y.dtype)

    y = y.abs()
    y = y.permute(0, 3, 1, 2)
    # (1, 2, 512, 1024)
    print("y3", y.shape, y.dtype)

    vocals_spec = vocals(y)
    accompaniment_spec = accompaniment(y)

    sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
    print(
        "vocals_spec",
        vocals_spec.shape,
        accompaniment_spec.shape,
        sum_spec.shape,
        vocals_spec.dtype,
    )

    vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec
    # (1, 2, 512, 1024)

    accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec
    # (1, 2, 512, 1024)

    for name, spec in zip(
        ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
    ):
        spec = torch.nn.functional.pad(spec, (0, 2049 - 1024, 0, 0, 0, 0, 0, 0))
        # (1, 2, 512, 2049)

        spec = spec.permute(0, 2, 3, 1)
        # (1, 512, 2049, 2)
        print("here00", spec.shape)

        spec = spec.reshape(-1, spec.shape[2], spec.shape[3])
        # (512, 2049, 2)

        print("here2", spec.shape)
        # (512, 2049, 2)

        spec = spec[: stft.shape[2], :, :]
        # (465, 2049, 2)
        print("here 3", spec.shape, stft.shape)

        spec = spec.permute(2, 1, 0)
        # (2, 2049, 465)

        masked_stft = spec * stft

        wave = torch.istft(
            masked_stft,
            4096,
            1024,
            window=torch.hann_window(4096, periodic=True),
            onesided=True,
        ) * (2 / 3)

        print(wave.shape, wave.dtype)
        sf.write(f"{name}.wav", wave.t(), 44100)

        wave = (wave.t() * 32768).to(torch.int16)
        sound = AudioSegment(
            data=wave.numpy().tobytes(), sample_width=2, frame_rate=44100, channels=2
        )
        sound.export(f"{name}.mp3", format="mp3", bitrate="128k")


if __name__ == "__main__":
    main()
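The masking step above is a Wiener-like ratio mask: each stem's estimated power spectrogram is divided by the summed power of both stems, and the resulting soft mask scales the complex mixture STFT before torch.istft. In equation form (with epsilon matching the 1e-10 in the code):

    m_v = \frac{\hat{V}^2 + \varepsilon/2}{\hat{V}^2 + \hat{A}^2 + \varepsilon},
    \qquad
    m_a = \frac{\hat{A}^2 + \varepsilon/2}{\hat{V}^2 + \hat{A}^2 + \varepsilon},
    \qquad
    \hat{S}_{\text{stem}} = m_{\text{stem}} \cdot \mathrm{STFT}(x)

Note that m_v + m_a = 1, so the two separated stems together reconstruct the mixture spectrogram exactly.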
scripts/spleeter/separate_onnx.py (new file, mode 100755)
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import time

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch

from separate import load_audio

"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])

----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])

"""


class OnnxModel:
    def __init__(self, filename):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        print(f"----------inputs for {filename}----------")
        for i in self.model.get_inputs():
            print(i)

        print(f"----------outputs for {filename}----------")

        for i in self.model.get_outputs():
            print(i)
        print("--------------------")

    def __call__(self, x):
        """
        Args:
          x: (num_splits, 2, 512, 1024)
        """
        spec = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
            },
        )[0]

        return torch.from_numpy(spec)


def main():
    vocals = OnnxModel("./2stems/vocals.onnx")
    accompaniment = OnnxModel("./2stems/accompaniment.onnx")

    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    waveform = waveform[: 44100 * 10, :]

    stft_config = knf.StftConfig(
        n_fft=4096,
        hop_length=1024,
        win_length=4096,
        center=False,
        window_type="hann",
    )
    knf_stft = knf.Stft(stft_config)
    knf_istft = knf.IStft(stft_config)

    start = time.time()

    stft_result_c0 = knf_stft(waveform[:, 0].tolist())
    stft_result_c1 = knf_stft(waveform[:, 1].tolist())
    print("c0 stft", stft_result_c0.num_frames)

    orig_real0 = np.array(stft_result_c0.real, dtype=np.float32).reshape(
        stft_result_c0.num_frames, -1
    )
    orig_imag0 = np.array(stft_result_c0.imag, dtype=np.float32).reshape(
        stft_result_c0.num_frames, -1
    )

    orig_real1 = np.array(stft_result_c1.real, dtype=np.float32).reshape(
        stft_result_c1.num_frames, -1
    )
    orig_imag1 = np.array(stft_result_c1.imag, dtype=np.float32).reshape(
        stft_result_c1.num_frames, -1
    )

    real0 = torch.from_numpy(orig_real0)
    imag0 = torch.from_numpy(orig_imag0)
    real1 = torch.from_numpy(orig_real1)
    imag1 = torch.from_numpy(orig_imag1)
    # (num_frames, n_fft/2 + 1)
    print("real0", real0.shape)

    # keep only the first 1024 bins
    real0 = real0[:, :1024]
    imag0 = imag0[:, :1024]
    real1 = real1[:, :1024]
    imag1 = imag1[:, :1024]

    stft0 = (real0.square() + imag0.square()).sqrt()
    stft1 = (real1.square() + imag1.square()).sqrt()

    # pad to a multiple of 512 frames
    padding = 512 - real0.shape[0] % 512
    print("padding", padding)
    if padding > 0:
        stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
        stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
    stft0 = stft0.reshape(-1, 1, 512, 1024)
    stft1 = stft1.reshape(-1, 1, 512, 1024)

    stft_01 = torch.cat([stft0, stft1], axis=1)

    print("stft_01", stft_01.shape, stft_01.dtype)

    vocals_spec = vocals(stft_01)
    accompaniment_spec = accompaniment(stft_01)
    # (num_splits, num_channels, 512, 1024)

    sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10

    vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec
    accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec

    for name, spec in zip(
        ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
    ):
        spec_c0 = spec[:, 0, :, :]
        spec_c1 = spec[:, 1, :, :]

        spec_c0 = spec_c0.reshape(-1, 1024)
        spec_c1 = spec_c1.reshape(-1, 1024)

        spec_c0 = spec_c0[: stft_result_c0.num_frames, :]
        spec_c1 = spec_c1[: stft_result_c0.num_frames, :]

        spec_c0 = torch.nn.functional.pad(spec_c0, (0, 2049 - 1024, 0, 0))
        spec_c1 = torch.nn.functional.pad(spec_c1, (0, 2049 - 1024, 0, 0))

        spec_c0_real = spec_c0 * orig_real0
        spec_c0_imag = spec_c0 * orig_imag0

        spec_c1_real = spec_c1 * orig_real1
        spec_c1_imag = spec_c1 * orig_imag1

        result0 = knf.StftResult(
            real=spec_c0_real.reshape(-1).tolist(),
            imag=spec_c0_imag.reshape(-1).tolist(),
            num_frames=orig_real0.shape[0],
        )

        result1 = knf.StftResult(
            real=spec_c1_real.reshape(-1).tolist(),
            imag=spec_c1_imag.reshape(-1).tolist(),
            num_frames=orig_real1.shape[0],
        )

        wav0 = knf_istft(result0)
        wav1 = knf_istft(result1)

        wav = np.array([wav0, wav1], dtype=np.float32)
        wav = np.transpose(wav)
        # now wav is (num_samples, num_channels)

        sf.write(f"./onnx-{name}.wav", wav, 44100)

        print(f"Saved to ./onnx-{name}.wav")

    end = time.time()
    elapsed_seconds = end - start
    audio_duration = waveform.shape[0] / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
    print(f"Audio duration in seconds: {audio_duration:.3f}")
    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


if __name__ == "__main__":
    main()
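The demo loads the fp32 models, but the int8 and fp16 exports share the same 'x' input and 'y' output (see the docstring at the top of the file), so comparing their speed is only a filename swap. A rough timing sketch (hypothetical, not part of this commit):

    import time

    import numpy as np
    import onnxruntime as ort

    # Time one forward pass of the fp32 and int8 vocals models on the same
    # random (num_splits, 2, 512, 1024) input.
    x = np.random.rand(1, 2, 512, 1024).astype(np.float32)
    for filename in ["./2stems/vocals.onnx", "./2stems/vocals.int8.onnx"]:
        sess = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
        start = time.time()
        sess.run(["y"], {"x": x})
        print(filename, f"{time.time() - start:.3f} s")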
scripts/spleeter/unet.py (new file, mode 100644)
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

import torch


class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 16, kernel_size=5, stride=(2, 2), padding=0)
        self.bn = torch.nn.BatchNorm2d(
            16, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=5, stride=(2, 2), padding=0)
        self.bn1 = torch.nn.BatchNorm2d(
            32, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5, stride=(2, 2), padding=0)
        self.bn2 = torch.nn.BatchNorm2d(
            64, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=5, stride=(2, 2), padding=0)
        self.bn3 = torch.nn.BatchNorm2d(
            128, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=5, stride=(2, 2), padding=0)
        self.bn4 = torch.nn.BatchNorm2d(
            256, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv5 = torch.nn.Conv2d(256, 512, kernel_size=5, stride=(2, 2), padding=0)

        self.up1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2)
        self.bn5 = torch.nn.BatchNorm2d(
            256, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up2 = torch.nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2)
        self.bn6 = torch.nn.BatchNorm2d(
            128, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up3 = torch.nn.ConvTranspose2d(256, 64, kernel_size=5, stride=2)
        self.bn7 = torch.nn.BatchNorm2d(
            64, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up4 = torch.nn.ConvTranspose2d(128, 32, kernel_size=5, stride=2)
        self.bn8 = torch.nn.BatchNorm2d(
            32, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up5 = torch.nn.ConvTranspose2d(64, 16, kernel_size=5, stride=2)
        self.bn9 = torch.nn.BatchNorm2d(
            16, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up6 = torch.nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2)
        self.bn10 = torch.nn.BatchNorm2d(
            1, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        # output logit is False, so we need self.up7
        self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)

    def forward(self, x):
        in_x = x
        # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
        x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
        conv1 = self.conv(x)
        batch1 = self.bn(conv1)
        rel1 = torch.nn.functional.leaky_relu(batch1, negative_slope=0.2)

        x = torch.nn.functional.pad(rel1, (1, 2, 1, 2), "constant", 0)
        conv2 = self.conv1(x)  # (3, 32, 128, 256)
        batch2 = self.bn1(conv2)
        rel2 = torch.nn.functional.leaky_relu(
            batch2, negative_slope=0.2
        )  # (3, 32, 128, 256)

        x = torch.nn.functional.pad(rel2, (1, 2, 1, 2), "constant", 0)
        conv3 = self.conv2(x)  # (3, 64, 64, 128)
        batch3 = self.bn2(conv3)
        rel3 = torch.nn.functional.leaky_relu(
            batch3, negative_slope=0.2
        )  # (3, 64, 64, 128)

        x = torch.nn.functional.pad(rel3, (1, 2, 1, 2), "constant", 0)
        conv4 = self.conv3(x)  # (3, 128, 32, 64)
        batch4 = self.bn3(conv4)
        rel4 = torch.nn.functional.leaky_relu(
            batch4, negative_slope=0.2
        )  # (3, 128, 32, 64)

        x = torch.nn.functional.pad(rel4, (1, 2, 1, 2), "constant", 0)
        conv5 = self.conv4(x)  # (3, 256, 16, 32)
        batch5 = self.bn4(conv5)
        rel6 = torch.nn.functional.leaky_relu(
            batch5, negative_slope=0.2
        )  # (3, 256, 16, 32)

        x = torch.nn.functional.pad(rel6, (1, 2, 1, 2), "constant", 0)
        conv6 = self.conv5(x)  # (3, 512, 8, 16)

        up1 = self.up1(conv6)
        up1 = up1[:, :, 1:-2, 1:-2]  # (3, 256, 16, 32)
        up1 = torch.nn.functional.relu(up1)
        batch7 = self.bn5(up1)
        merge1 = torch.cat([conv5, batch7], axis=1)  # (3, 512, 16, 32)

        up2 = self.up2(merge1)
        up2 = up2[:, :, 1:-2, 1:-2]
        up2 = torch.nn.functional.relu(up2)
        batch8 = self.bn6(up2)

        merge2 = torch.cat([conv4, batch8], axis=1)  # (3, 256, 32, 64)

        up3 = self.up3(merge2)
        up3 = up3[:, :, 1:-2, 1:-2]
        up3 = torch.nn.functional.relu(up3)
        batch9 = self.bn7(up3)

        merge3 = torch.cat([conv3, batch9], axis=1)  # (3, 128, 64, 128)

        up4 = self.up4(merge3)
        up4 = up4[:, :, 1:-2, 1:-2]
        up4 = torch.nn.functional.relu(up4)
        batch10 = self.bn8(up4)

        merge4 = torch.cat([conv2, batch10], axis=1)  # (3, 64, 128, 256)

        up5 = self.up5(merge4)
        up5 = up5[:, :, 1:-2, 1:-2]
        up5 = torch.nn.functional.relu(up5)
        batch11 = self.bn9(up5)

        merge5 = torch.cat([conv1, batch11], axis=1)  # (3, 32, 256, 512)

        up6 = self.up6(merge5)
        up6 = up6[:, :, 1:-2, 1:-2]
        up6 = torch.nn.functional.relu(up6)
        batch12 = self.bn10(up6)  # (3, 1, 512, 1024) = (T, 1, 512, 1024)

        up7 = self.up7(batch12)
        up7 = torch.sigmoid(up7)  # (3, 2, 512, 1024)

        return up7 * in_x
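A note on the padding arithmetic: the asymmetric torch.nn.functional.pad(x, (1, 2, 1, 2)) before each kernel-5, stride-2 convolution reproduces TensorFlow's 'same' padding, which for an even input size pads 1 element before and 2 after each spatial dimension, and the [:, :, 1:-2, 1:-2] crops after each transposed convolution remove the matching overhang. A minimal shape check using the 512x1024 spectrogram chunks from the scripts above:

    import torch

    # With TF-style 'same' padding, a kernel-5, stride-2 conv halves each
    # even spatial dimension: 512 -> 256, 1024 -> 512.
    x = torch.rand(1, 2, 512, 1024)
    conv = torch.nn.Conv2d(2, 16, kernel_size=5, stride=2, padding=0)
    y = conv(torch.nn.functional.pad(x, (1, 2, 1, 2)))
    print(y.shape)  # torch.Size([1, 16, 256, 512])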