Fangjun Kuang
Committed by GitHub

Export spleeter model to onnx for source separation (#2237)
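For orientation, a minimal sketch (not part of this commit) of loading one of the exported models with onnxruntime. The path ./2stems/vocals.onnx, the input name x, the output name y, and the (num_splits, 2, 512, 1024) shape all come from the export and test scripts below:

import numpy as np
import onnxruntime as ort

# Load the exported vocals model (path as produced by export_onnx.py below).
sess = ort.InferenceSession("./2stems/vocals.onnx", providers=["CPUExecutionProvider"])

# The model takes a magnitude spectrogram of shape (num_splits, 2, 512, 1024)
# and returns the estimated source magnitude spectrogram of the same shape.
x = np.random.rand(1, 2, 512, 1024).astype(np.float32)
(y,) = sess.run(["y"], {"x": x})
print(y.shape)  # (1, 2, 512, 1024)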

GitHub Actions workflow (export-spleeter-to-onnx):

name: export-spleeter-to-onnx

on:
  push:
    branches:
      - spleeter-2
  workflow_dispatch:

concurrency:
  group: export-spleeter-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-spleeter-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export spleeter to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        shell: bash
        run: |
          pip install tensorflow torch "numpy<2" onnx==1.17.0 onnxruntime==1.17.1 onnxmltools

      - name: Run
        shell: bash
        run: |
          cd scripts/spleeter
          ./run.sh

          echo "---"
          ls -lh 2stems
          echo "---"
          ls -lh 2stems/*.onnx
          echo "---"

          mv -v 2stems/*.onnx ../..

      - name: Collect models
        shell: bash
        run: |
          mkdir sherpa-onnx-spleeter-2stems
          mkdir sherpa-onnx-spleeter-2stems-int8
          mkdir sherpa-onnx-spleeter-2stems-fp16

          mv -v vocals.onnx sherpa-onnx-spleeter-2stems/
          mv -v accompaniment.onnx sherpa-onnx-spleeter-2stems/

          mv -v vocals.int8.onnx sherpa-onnx-spleeter-2stems-int8/
          mv -v accompaniment.int8.onnx sherpa-onnx-spleeter-2stems-int8/

          mv -v vocals.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/
          mv -v accompaniment.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/

          tar cjvf sherpa-onnx-spleeter-2stems.tar.bz2 sherpa-onnx-spleeter-2stems
          tar cjvf sherpa-onnx-spleeter-2stems-int8.tar.bz2 sherpa-onnx-spleeter-2stems-int8
          tar cjvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2 sherpa-onnx-spleeter-2stems-fp16

          ls -lh *.tar.bz2

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: source-separation-models

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            names=(
              sherpa-onnx-spleeter-2stems
              sherpa-onnx-spleeter-2stems-int8
              sherpa-onnx-spleeter-2stems-fp16
            )
            for d in ${names[@]}; do
              rm -rf huggingface
              git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
              cp -v $d/*onnx huggingface

              cd huggingface
              git lfs track "*.onnx"
              git status
              git add .
              ls -lh
              git status
              git commit -m "add models"
              git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
              cd ..
            done

scripts/spleeter/.gitignore:

2stems.tar.gz
2stems

scripts/spleeter/convert_to_pb.py:

#!/usr/bin/env python3

# Code in this file is modified from
# https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
#
# Please see ./run.sh for usage
import argparse

import tensorflow as tf


def freeze_graph(model_dir, output_node_names, output_filename):
    """Extract the sub graph defined by the output nodes and convert all its
    variables into constants.

    Args:
      model_dir:
        The root folder containing the checkpoint state file.
      output_node_names:
        A string containing all the output node names, comma separated.
      output_filename:
        Filename to save the graph.
    """
    if not tf.compat.v1.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exist. Please specify an export "
            "directory: %s" % model_dir
        )

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # We specify the full filename of our frozen graph
    output_graph = output_filename

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.compat.v1.train.import_meta_graph(
            input_checkpoint + ".meta", clear_devices=clear_devices
        )

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.compat.v1.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
            output_node_names.split(
                ","
            ),  # The output node names are used to select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.compat.v1.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir", type=str, default="", help="Model folder to export"
    )
    parser.add_argument(
        "--output-node-names",
        type=str,
        default="vocals_spectrogram/mul,accompaniment_spectrogram/mul",
        help="The name of the output nodes, comma separated.",
    )

    parser.add_argument(
        "--output-filename",
        type=str,
    )
    args = parser.parse_args()

    freeze_graph(args.model_dir, args.output_node_names, args.output_filename)

scripts/spleeter/convert_to_torch.py:

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

# Please see ./run.sh for usage

import argparse

import numpy as np
import tensorflow as tf
import torch

from unet import UNet


def load_graph(frozen_graph_filename):
    # This function is modified from
    # https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc

    # We load the protobuf file from the disk and parse it to retrieve the
    # unserialized graph_def
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we import the graph_def into a new Graph and return it
    with tf.Graph().as_default() as graph:
        # The name var will prefix every op/node in your graph.
        # Since we load everything in a new graph, this is not needed
        # tf.import_graph_def(graph_def, name="prefix")
        tf.import_graph_def(graph_def, name="")
    return graph


def generate_waveform():
    np.random.seed(20230821)
    waveform = np.random.rand(60 * 44100).astype(np.float32)

    # (num_samples, num_channels)
    waveform = waveform.reshape(-1, 2)
    return waveform


def get_param(graph, name):
    with tf.compat.v1.Session(graph=graph) as sess:
        constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
        for constant_op in constant_ops:
            if constant_op.name != name:
                continue

            value = sess.run(constant_op.outputs[0])
            return torch.from_numpy(value)


@torch.no_grad()
def main(name):
    graph = load_graph(f"./2stems/frozen_{name}_model.pb")
    # for op in graph.get_operations():
    #     print(op.name)
    x = graph.get_tensor_by_name("waveform:0")
    # y = graph.get_tensor_by_name("Reshape:0")
    y0 = graph.get_tensor_by_name("strided_slice_3:0")
    # y1 = graph.get_tensor_by_name("leaky_re_lu_5/LeakyRelu:0")
    # y1 = graph.get_tensor_by_name("conv2d_5/BiasAdd:0")
    # y1 = graph.get_tensor_by_name("conv2d_transpose/BiasAdd:0")
    # y1 = graph.get_tensor_by_name("re_lu/Relu:0")
    # y1 = graph.get_tensor_by_name("batch_normalization_6/cond/FusedBatchNorm_1:0")
    # y1 = graph.get_tensor_by_name("concatenate/concat:0")
    # y1 = graph.get_tensor_by_name("concatenate_1/concat:0")
    # y1 = graph.get_tensor_by_name("concatenate_4/concat:0")
    # y1 = graph.get_tensor_by_name("batch_normalization_11/cond/FusedBatchNorm_1:0")
    # y1 = graph.get_tensor_by_name("conv2d_6/Sigmoid:0")
    y1 = graph.get_tensor_by_name(f"{name}_spectrogram/mul:0")

    unet = UNet()
    unet.eval()

    # For conv2d in tensorflow, the weight shape is (kernel_h, kernel_w, in_channel, out_channel)
    # and the default input shape is NHWC.
    #
    # For conv2d in torch, the weight shape is (out_channel, in_channel, kernel_h, kernel_w)
    # and the default input shape is NCHW.
    state_dict = unet.state_dict()
    # print(list(state_dict.keys()))

    if name == "vocals":
        state_dict["conv.weight"] = get_param(graph, "conv2d/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["conv.bias"] = get_param(graph, "conv2d/bias")

        state_dict["bn.weight"] = get_param(graph, "batch_normalization/gamma")
        state_dict["bn.bias"] = get_param(graph, "batch_normalization/beta")
        state_dict["bn.running_mean"] = get_param(
            graph, "batch_normalization/moving_mean"
        )
        state_dict["bn.running_var"] = get_param(
            graph, "batch_normalization/moving_variance"
        )

        conv_offset = 0
        bn_offset = 0
    else:
        state_dict["conv.weight"] = get_param(graph, "conv2d_7/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["conv.bias"] = get_param(graph, "conv2d_7/bias")

        state_dict["bn.weight"] = get_param(graph, "batch_normalization_12/gamma")
        state_dict["bn.bias"] = get_param(graph, "batch_normalization_12/beta")
        state_dict["bn.running_mean"] = get_param(
            graph, "batch_normalization_12/moving_mean"
        )
        state_dict["bn.running_var"] = get_param(
            graph, "batch_normalization_12/moving_variance"
        )
        conv_offset = 7
        bn_offset = 12

    for i in range(1, 6):
        state_dict[f"conv{i}.weight"] = get_param(
            graph, f"conv2d_{i+conv_offset}/kernel"
        ).permute(3, 2, 0, 1)
        state_dict[f"conv{i}.bias"] = get_param(graph, f"conv2d_{i+conv_offset}/bias")
        if i >= 5:
            continue
        state_dict[f"bn{i}.weight"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/gamma"
        )
        state_dict[f"bn{i}.bias"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/beta"
        )
        state_dict[f"bn{i}.running_mean"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/moving_mean"
        )
        state_dict[f"bn{i}.running_var"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/moving_variance"
        )

    if name == "vocals":
        state_dict["up1.weight"] = get_param(graph, "conv2d_transpose/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up1.bias"] = get_param(graph, "conv2d_transpose/bias")

        state_dict["bn5.weight"] = get_param(graph, "batch_normalization_6/gamma")
        state_dict["bn5.bias"] = get_param(graph, "batch_normalization_6/beta")
        state_dict["bn5.running_mean"] = get_param(
            graph, "batch_normalization_6/moving_mean"
        )
        state_dict["bn5.running_var"] = get_param(
            graph, "batch_normalization_6/moving_variance"
        )
        conv_offset = 0
        bn_offset = 0
    else:
        state_dict["up1.weight"] = get_param(
            graph, "conv2d_transpose_6/kernel"
        ).permute(3, 2, 0, 1)
        state_dict["up1.bias"] = get_param(graph, "conv2d_transpose_6/bias")

        state_dict["bn5.weight"] = get_param(graph, "batch_normalization_18/gamma")
        state_dict["bn5.bias"] = get_param(graph, "batch_normalization_18/beta")
        state_dict["bn5.running_mean"] = get_param(
            graph, "batch_normalization_18/moving_mean"
        )
        state_dict["bn5.running_var"] = get_param(
            graph, "batch_normalization_18/moving_variance"
        )
        conv_offset = 6
        bn_offset = 12

    for i in range(1, 6):
        state_dict[f"up{i+1}.weight"] = get_param(
            graph, f"conv2d_transpose_{i+conv_offset}/kernel"
        ).permute(3, 2, 0, 1)

        state_dict[f"up{i+1}.bias"] = get_param(
            graph, f"conv2d_transpose_{i+conv_offset}/bias"
        )

        state_dict[f"bn{5+i}.weight"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/gamma"
        )
        state_dict[f"bn{5+i}.bias"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/beta"
        )
        state_dict[f"bn{5+i}.running_mean"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/moving_mean"
        )
        state_dict[f"bn{5+i}.running_var"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/moving_variance"
        )

    if name == "vocals":
        state_dict["up7.weight"] = get_param(graph, "conv2d_6/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up7.bias"] = get_param(graph, "conv2d_6/bias")
    else:
        state_dict["up7.weight"] = get_param(graph, "conv2d_13/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up7.bias"] = get_param(graph, "conv2d_13/bias")

    unet.load_state_dict(state_dict)

    with tf.compat.v1.Session(graph=graph) as sess:
        y0_out, y1_out = sess.run([y0, y1], feed_dict={x: generate_waveform()})
        # y0_out = sess.run(y0, feed_dict={x: generate_waveform()})
        # y1_out = sess.run(y1, feed_dict={x: generate_waveform()})
        # print(y0_out.shape)
        # print(y1_out.shape)

    # For batch normalization in tensorflow, the default input shape is NHWC.
    # For batch normalization in torch, the default input shape is NCHW.

    # NHWC to NCHW
    torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))

    # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
    assert torch.allclose(
        torch_y1_out, torch.from_numpy(y1_out).permute(0, 3, 1, 2), atol=1e-1
    ), ((torch_y1_out - torch.from_numpy(y1_out).permute(0, 3, 1, 2)).abs().max())
    torch.save(unet.state_dict(), f"2stems/{name}.pt")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--name",
        type=str,
        required=True,
        choices=["vocals", "accompaniment"],
    )
    args = parser.parse_args()
    print(vars(args))
    main(args.name)
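
As a side note on the weight conversion above: a tiny standalone sketch (not part of the commit) of why permute(3, 2, 0, 1) maps a TensorFlow conv kernel to the PyTorch layout:

import torch

# TensorFlow Conv2D kernels are stored as (kernel_h, kernel_w, in_channel, out_channel);
# PyTorch expects (out_channel, in_channel, kernel_h, kernel_w).
tf_kernel = torch.zeros(5, 5, 2, 16)
torch_kernel = tf_kernel.permute(3, 2, 0, 1)
print(torch_kernel.shape)  # torch.Size([16, 2, 5, 5])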

scripts/spleeter/export_onnx.py:

#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import onnx
import onnxmltools
import torch
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic

from unet import UNet


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)


def add_meta_data(filename, prefix):
    meta_data = {
        "model_type": "spleeter",
        "sample_rate": 44100,  # the spleeter 2stems models operate on 44.1 kHz audio
        "version": 1,
        "model_url": "https://github.com/deezer/spleeter",
        "stems": 2,
        "comment": prefix,
        "model_name": "2stems.tar.gz",
    }
    model = onnx.load(filename)

    print(model.metadata_props)

    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    print("--------------------")

    print(model.metadata_props)

    onnx.save(model, filename)


def export(model, prefix):
    num_splits = 1
    x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)

    filename = f"./2stems/{prefix}.onnx"
    torch.onnx.export(
        model,
        x,
        filename,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "num_splits"},
        },
        opset_version=13,
    )

    add_meta_data(filename, prefix)

    filename_int8 = f"./2stems/{prefix}.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    filename_fp16 = f"./2stems/{prefix}.fp16.onnx"
    export_onnx_fp16(filename, filename_fp16)


@torch.no_grad()
def main():
    vocals = UNet()
    state_dict = torch.load("./2stems/vocals.pt", map_location="cpu")
    vocals.load_state_dict(state_dict)
    vocals.eval()

    accompaniment = UNet()
    state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu")
    accompaniment.load_state_dict(state_dict)
    accompaniment.eval()

    export(vocals, "vocals")
    export(accompaniment, "accompaniment")


if __name__ == "__main__":
    main()
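
The metadata written by add_meta_data() can be read back with onnx alone; a small check (not part of the commit), assuming the output path used by the script above:

import onnx

model = onnx.load("./2stems/vocals.onnx")
for prop in model.metadata_props:
    print(prop.key, "=", prop.value)  # e.g. model_type = spleeter, sample_rate = 44100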

scripts/spleeter/run.sh:

#!/usr/bin/env bash


if [ ! -f 2stems.tar.gz ]; then
  curl -SL -O https://github.com/deezer/spleeter/releases/download/v1.4.0/2stems.tar.gz
fi

if [ ! -d ./2stems ]; then
  mkdir -p 2stems
  cd 2stems
  tar xvf ../2stems.tar.gz
  cd ..
fi

ls -lh

ls -lh 2stems

if [ ! -f 2stems/frozen_vocals_model.pb ]; then
  python3 ./convert_to_pb.py \
    --model-dir ./2stems \
    --output-node-names vocals_spectrogram/mul \
    --output-filename ./2stems/frozen_vocals_model.pb
fi

ls -lh 2stems

if [ ! -f 2stems/frozen_accompaniment_model.pb ]; then
  python3 ./convert_to_pb.py \
    --model-dir ./2stems \
    --output-node-names accompaniment_spectrogram/mul \
    --output-filename ./2stems/frozen_accompaniment_model.pb
fi

ls -lh 2stems

python3 ./convert_to_torch.py --name vocals
python3 ./convert_to_torch.py --name accompaniment
python3 ./export_onnx.py

ls -lh 2stems

scripts/spleeter/separate.py:

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

# Please see ./run.sh for usage

from typing import Optional

import ffmpeg
import numpy as np
import soundfile as sf
import torch
from pydub import AudioSegment

from unet import UNet


def load_audio(filename, sample_rate: Optional[int] = 44100):
    probe = ffmpeg.probe(filename)
    if "streams" not in probe or len(probe["streams"]) == 0:
        raise ValueError("No stream was found with ffprobe")

    metadata = next(
        stream for stream in probe["streams"] if stream["codec_type"] == "audio"
    )
    n_channels = metadata["channels"]

    if sample_rate is None:
        sample_rate = metadata["sample_rate"]

    process = (
        ffmpeg.input(filename)
        .output("pipe:", format="f32le", ar=sample_rate)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, _ = process.communicate()
    waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)

    waveform = torch.from_numpy(np.copy(waveform)).to(torch.float32)
    if n_channels == 1:
        waveform = waveform.tile(1, 2)

    if n_channels > 2:
        waveform = waveform[:, :2]

    return waveform, sample_rate


@torch.no_grad()
def main():
    vocals = UNet()
    vocals.eval()
    state_dict = torch.load("./2stems/vocals.pt", map_location="cpu")
    vocals.load_state_dict(state_dict)

    accompaniment = UNet()
    accompaniment.eval()
    state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu")
    accompaniment.load_state_dict(state_dict)

    # waveform, sample_rate = load_audio("./audio_example.mp3")

    # You can download the following two mp3 files from
    # https://huggingface.co/spaces/csukuangfj/music-source-separation/tree/main/examples
    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    # waveform, sample_rate = load_audio("./Yesterday_Once_More-Carpenters.mp3")
    assert waveform.shape[1] == 2, waveform.shape

    waveform = torch.nn.functional.pad(waveform, (0, 0, 0, 4096))

    # torch.stft requires a 2-D input of shape (N, T), so we transpose waveform
    stft = torch.stft(
        waveform.t(),
        n_fft=4096,
        hop_length=1024,
        window=torch.hann_window(4096, periodic=True),
        center=False,
        onesided=True,
        return_complex=True,
    )
    print("stft", stft.shape)

    # stft: (2, 2049, 465)
    # stft is a complex tensor
    y = stft.permute(2, 1, 0)
    print("y0", y.shape)
    # (465, 2049, 2)

    y = y[:, :1024, :]
    # (465, 1024, 2)

    tensor_size = y.shape[0] - int(y.shape[0] / 512) * 512
    pad_size = 512 - tensor_size
    y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size))
    # (512, 1024, 2)
    print("y1", y.shape, y.dtype)

    num_splits = int(y.shape[0] / 512)
    y = y.reshape([num_splits, 512] + list(y.shape[1:]))
    # y: (1, 512, 1024, 2)
    print("y2", y.shape, y.dtype)

    y = y.abs()
    y = y.permute(0, 3, 1, 2)
    # (1, 2, 512, 1024)
    print("y3", y.shape, y.dtype)

    vocals_spec = vocals(y)
    accompaniment_spec = accompaniment(y)

    sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
    print(
        "vocals_spec",
        vocals_spec.shape,
        accompaniment_spec.shape,
        sum_spec.shape,
        vocals_spec.dtype,
    )

    vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec
    # (1, 2, 512, 1024)

    accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec
    # (1, 2, 512, 1024)

    for name, spec in zip(
        ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
    ):
        spec = torch.nn.functional.pad(spec, (0, 2049 - 1024, 0, 0, 0, 0, 0, 0))
        # (1, 2, 512, 2049)

        spec = spec.permute(0, 2, 3, 1)
        # (1, 512, 2049, 2)
        print("here00", spec.shape)

        spec = spec.reshape(-1, spec.shape[2], spec.shape[3])
        # (512, 2049, 2)

        print("here2", spec.shape)
        # (512, 2049, 2)

        spec = spec[: stft.shape[2], :, :]
        # (465, 2049, 2)
        print("here 3", spec.shape, stft.shape)

        spec = spec.permute(2, 1, 0)
        # (2, 2049, 465)

        masked_stft = spec * stft

        wave = torch.istft(
            masked_stft,
            4096,
            1024,
            window=torch.hann_window(4096, periodic=True),
            onesided=True,
        ) * (2 / 3)

        print(wave.shape, wave.dtype)
        sf.write(f"{name}.wav", wave.t(), 44100)

        wave = (wave.t() * 32768).to(torch.int16)
        sound = AudioSegment(
            data=wave.numpy().tobytes(), sample_width=2, frame_rate=44100, channels=2
        )
        sound.export(f"{name}.mp3", format="mp3", bitrate="128k")


if __name__ == "__main__":
    main()
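
The masks computed above are Wiener-like ratio masks: each source's squared magnitude is normalized by the sum over both sources, so the two masks add up to one. A standalone sketch (not part of the commit) with random stand-in spectrograms:

import torch

v = torch.rand(1, 2, 512, 1024)  # stand-in for vocals_spec
a = torch.rand(1, 2, 512, 1024)  # stand-in for accompaniment_spec

eps = 1e-10
total = v**2 + a**2 + eps
mask_v = (v**2 + eps / 2) / total
mask_a = (a**2 + eps / 2) / total
print(torch.allclose(mask_v + mask_a, torch.ones_like(mask_v)))  # True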

scripts/spleeter/separate_onnx.py:

#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import time

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch

from separate import load_audio

"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])

----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])

"""


class OnnxModel:
    def __init__(self, filename):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        print(f"----------inputs for {filename}----------")
        for i in self.model.get_inputs():
            print(i)

        print(f"----------outputs for {filename}----------")

        for i in self.model.get_outputs():
            print(i)
        print("--------------------")

    def __call__(self, x):
        """
        Args:
          x: (num_splits, 2, 512, 1024)
        """
        spec = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
            },
        )[0]

        return torch.from_numpy(spec)


def main():
    vocals = OnnxModel("./2stems/vocals.onnx")
    accompaniment = OnnxModel("./2stems/accompaniment.onnx")

    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    waveform = waveform[: 44100 * 10, :]

    stft_config = knf.StftConfig(
        n_fft=4096,
        hop_length=1024,
        win_length=4096,
        center=False,
        window_type="hann",
    )
    knf_stft = knf.Stft(stft_config)
    knf_istft = knf.IStft(stft_config)

    start = time.time()

    stft_result_c0 = knf_stft(waveform[:, 0].tolist())
    stft_result_c1 = knf_stft(waveform[:, 1].tolist())
    print("c0 stft", stft_result_c0.num_frames)

    orig_real0 = np.array(stft_result_c0.real, dtype=np.float32).reshape(
        stft_result_c0.num_frames, -1
    )
    orig_imag0 = np.array(stft_result_c0.imag, dtype=np.float32).reshape(
        stft_result_c0.num_frames, -1
    )

    orig_real1 = np.array(stft_result_c1.real, dtype=np.float32).reshape(
        stft_result_c1.num_frames, -1
    )
    orig_imag1 = np.array(stft_result_c1.imag, dtype=np.float32).reshape(
        stft_result_c1.num_frames, -1
    )

    real0 = torch.from_numpy(orig_real0)
    imag0 = torch.from_numpy(orig_imag0)
    real1 = torch.from_numpy(orig_real1)
    imag1 = torch.from_numpy(orig_imag1)
    # (num_frames, n_fft/2 + 1)
    print("real0", real0.shape)

    # keep only the first 1024 bins
    real0 = real0[:, :1024]
    imag0 = imag0[:, :1024]
    real1 = real1[:, :1024]
    imag1 = imag1[:, :1024]

    stft0 = (real0.square() + imag0.square()).sqrt()
    stft1 = (real1.square() + imag1.square()).sqrt()

    # pad it to a multiple of 512
    padding = 512 - real0.shape[0] % 512
    print("padding", padding)
    if padding > 0:
        stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
        stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
    stft0 = stft0.reshape(-1, 1, 512, 1024)
    stft1 = stft1.reshape(-1, 1, 512, 1024)

    stft_01 = torch.cat([stft0, stft1], axis=1)

    print("stft_01", stft_01.shape, stft_01.dtype)

    vocals_spec = vocals(stft_01)
    accompaniment_spec = accompaniment(stft_01)
    # (num_splits, num_channels, 512, 1024)

    sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10

    vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec
    accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec

    for name, spec in zip(
        ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
    ):
        spec_c0 = spec[:, 0, :, :]
        spec_c1 = spec[:, 1, :, :]

        spec_c0 = spec_c0.reshape(-1, 1024)
        spec_c1 = spec_c1.reshape(-1, 1024)

        spec_c0 = spec_c0[: stft_result_c0.num_frames, :]
        spec_c1 = spec_c1[: stft_result_c0.num_frames, :]

        spec_c0 = torch.nn.functional.pad(spec_c0, (0, 2049 - 1024, 0, 0))
        spec_c1 = torch.nn.functional.pad(spec_c1, (0, 2049 - 1024, 0, 0))

        spec_c0_real = spec_c0 * orig_real0
        spec_c0_imag = spec_c0 * orig_imag0

        spec_c1_real = spec_c1 * orig_real1
        spec_c1_imag = spec_c1 * orig_imag1

        result0 = knf.StftResult(
            real=spec_c0_real.reshape(-1).tolist(),
            imag=spec_c0_imag.reshape(-1).tolist(),
            num_frames=orig_real0.shape[0],
        )

        result1 = knf.StftResult(
            real=spec_c1_real.reshape(-1).tolist(),
            imag=spec_c1_imag.reshape(-1).tolist(),
            num_frames=orig_real1.shape[0],
        )

        wav0 = knf_istft(result0)
        wav1 = knf_istft(result1)

        wav = np.array([wav0, wav1], dtype=np.float32)
        wav = np.transpose(wav)
        # now wav is (num_samples, num_channels)

        sf.write(f"./onnx-{name}.wav", wav, 44100)

        print(f"Saved to ./onnx-{name}.wav")

    end = time.time()
    elapsed_seconds = end - start
    audio_duration = waveform.shape[0] / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
    print(f"Audio duration in seconds: {audio_duration:.3f}")
    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


if __name__ == "__main__":
    main()

scripts/spleeter/unet.py:

# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

import torch


class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 16, kernel_size=5, stride=(2, 2), padding=0)
        self.bn = torch.nn.BatchNorm2d(
            16, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=5, stride=(2, 2), padding=0)
        self.bn1 = torch.nn.BatchNorm2d(
            32, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5, stride=(2, 2), padding=0)
        self.bn2 = torch.nn.BatchNorm2d(
            64, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=5, stride=(2, 2), padding=0)
        self.bn3 = torch.nn.BatchNorm2d(
            128, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=5, stride=(2, 2), padding=0)
        self.bn4 = torch.nn.BatchNorm2d(
            256, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv5 = torch.nn.Conv2d(256, 512, kernel_size=5, stride=(2, 2), padding=0)

        self.up1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2)
        self.bn5 = torch.nn.BatchNorm2d(
            256, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up2 = torch.nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2)
        self.bn6 = torch.nn.BatchNorm2d(
            128, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up3 = torch.nn.ConvTranspose2d(256, 64, kernel_size=5, stride=2)
        self.bn7 = torch.nn.BatchNorm2d(
            64, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up4 = torch.nn.ConvTranspose2d(128, 32, kernel_size=5, stride=2)
        self.bn8 = torch.nn.BatchNorm2d(
            32, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up5 = torch.nn.ConvTranspose2d(64, 16, kernel_size=5, stride=2)
        self.bn9 = torch.nn.BatchNorm2d(
            16, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up6 = torch.nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2)
        self.bn10 = torch.nn.BatchNorm2d(
            1, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        # output logit is False, so we need self.up7
        self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)

    def forward(self, x):
        in_x = x
        # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
        x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
        conv1 = self.conv(x)
        batch1 = self.bn(conv1)
        rel1 = torch.nn.functional.leaky_relu(batch1, negative_slope=0.2)

        x = torch.nn.functional.pad(rel1, (1, 2, 1, 2), "constant", 0)
        conv2 = self.conv1(x)  # (3, 32, 128, 256)
        batch2 = self.bn1(conv2)
        rel2 = torch.nn.functional.leaky_relu(
            batch2, negative_slope=0.2
        )  # (3, 32, 128, 256)

        x = torch.nn.functional.pad(rel2, (1, 2, 1, 2), "constant", 0)
        conv3 = self.conv2(x)  # (3, 64, 64, 128)
        batch3 = self.bn2(conv3)
        rel3 = torch.nn.functional.leaky_relu(
            batch3, negative_slope=0.2
        )  # (3, 64, 64, 128)

        x = torch.nn.functional.pad(rel3, (1, 2, 1, 2), "constant", 0)
        conv4 = self.conv3(x)  # (3, 128, 32, 64)
        batch4 = self.bn3(conv4)
        rel4 = torch.nn.functional.leaky_relu(
            batch4, negative_slope=0.2
        )  # (3, 128, 32, 64)

        x = torch.nn.functional.pad(rel4, (1, 2, 1, 2), "constant", 0)
        conv5 = self.conv4(x)  # (3, 256, 16, 32)
        batch5 = self.bn4(conv5)
        rel6 = torch.nn.functional.leaky_relu(
            batch5, negative_slope=0.2
        )  # (3, 256, 16, 32)

        x = torch.nn.functional.pad(rel6, (1, 2, 1, 2), "constant", 0)
        conv6 = self.conv5(x)  # (3, 512, 8, 16)

        up1 = self.up1(conv6)
        up1 = up1[:, :, 1:-2, 1:-2]  # (3, 256, 16, 32)
        up1 = torch.nn.functional.relu(up1)
        batch7 = self.bn5(up1)
        merge1 = torch.cat([conv5, batch7], axis=1)  # (3, 512, 16, 32)

        up2 = self.up2(merge1)
        up2 = up2[:, :, 1:-2, 1:-2]
        up2 = torch.nn.functional.relu(up2)
        batch8 = self.bn6(up2)

        merge2 = torch.cat([conv4, batch8], axis=1)  # (3, 256, 32, 64)

        up3 = self.up3(merge2)
        up3 = up3[:, :, 1:-2, 1:-2]
        up3 = torch.nn.functional.relu(up3)
        batch9 = self.bn7(up3)

        merge3 = torch.cat([conv3, batch9], axis=1)  # (3, 128, 64, 128)

        up4 = self.up4(merge3)
        up4 = up4[:, :, 1:-2, 1:-2]
        up4 = torch.nn.functional.relu(up4)
        batch10 = self.bn8(up4)

        merge4 = torch.cat([conv2, batch10], axis=1)  # (3, 64, 128, 256)

        up5 = self.up5(merge4)
        up5 = up5[:, :, 1:-2, 1:-2]
        up5 = torch.nn.functional.relu(up5)
        batch11 = self.bn9(up5)

        merge5 = torch.cat([conv1, batch11], axis=1)  # (3, 32, 256, 512)

        up6 = self.up6(merge5)
        up6 = up6[:, :, 1:-2, 1:-2]
        up6 = torch.nn.functional.relu(up6)
        batch12 = self.bn10(up6)  # (3, 1, 512, 1024) = (T, 1, 512, 1024)

        up7 = self.up7(batch12)
        up7 = torch.sigmoid(up7)  # (3, 2, 512, 1024)

        return up7 * in_x
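
Since the decoder upsamples back to the input resolution and the final sigmoid mask is multiplied with the input, the UNet preserves the (N, 2, 512, 1024) shape it is fed. A quick smoke test (not part of the commit, assuming this file is saved as unet.py):

import torch

from unet import UNet

net = UNet()
net.eval()
with torch.no_grad():
    out = net(torch.rand(1, 2, 512, 1024))
print(out.shape)  # torch.Size([1, 2, 512, 1024])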