Fangjun Kuang

Export spleeter model to onnx for source separation (#2237)

name: export-spleeter-to-onnx

on:
  push:
    branches:
      - spleeter-2
  workflow_dispatch:

concurrency:
  group: export-spleeter-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-spleeter-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export spleeter to ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        shell: bash
        run: |
          pip install tensorflow torch "numpy<2" onnx==1.17.0 onnxruntime==1.17.1 onnxmltools

      - name: Run
        shell: bash
        run: |
          cd scripts/spleeter
          ./run.sh
          echo "---"
          ls -lh 2stems
          echo "---"
          ls -lh 2stems/*.onnx
          echo "---"
          mv -v 2stems/*.onnx ../..

      - name: Collect models
        shell: bash
        run: |
          mkdir sherpa-onnx-spleeter-2stems
          mkdir sherpa-onnx-spleeter-2stems-int8
          mkdir sherpa-onnx-spleeter-2stems-fp16

          mv -v vocals.onnx sherpa-onnx-spleeter-2stems/
          mv -v accompaniment.onnx sherpa-onnx-spleeter-2stems/

          mv -v vocals.int8.onnx sherpa-onnx-spleeter-2stems-int8/
          mv -v accompaniment.int8.onnx sherpa-onnx-spleeter-2stems-int8/

          mv -v vocals.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/
          mv -v accompaniment.fp16.onnx sherpa-onnx-spleeter-2stems-fp16/

          tar cjvf sherpa-onnx-spleeter-2stems.tar.bz2 sherpa-onnx-spleeter-2stems
          tar cjvf sherpa-onnx-spleeter-2stems-int8.tar.bz2 sherpa-onnx-spleeter-2stems-int8
          tar cjvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2 sherpa-onnx-spleeter-2stems-fp16

          ls -lh *.tar.bz2

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: source-separation-models

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            names=(
              sherpa-onnx-spleeter-2stems
              sherpa-onnx-spleeter-2stems-int8
              sherpa-onnx-spleeter-2stems-fp16
            )

            for d in ${names[@]}; do
              rm -rf huggingface
              git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
              cp -v $d/*onnx huggingface

              cd huggingface
              git lfs track "*.onnx"
              git status
              git add .
              ls -lh
              git status
              git commit -m "add models"
              git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
              cd ..
            done
... ...
2stems.tar.gz
2stems
... ...
#!/usr/bin/env python3
# Code in this file is modified from
# https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
#
# Please see ./run.sh for usage

import argparse

import tensorflow as tf


def freeze_graph(model_dir, output_node_names, output_filename):
    """Extract the subgraph defined by the output nodes and convert all its
    variables into constants.

    Args:
      model_dir:
        The root folder containing the checkpoint state file.
      output_node_names:
        A string containing all the output node names, comma separated.
      output_filename:
        Filename to save the graph.
    """
    if not tf.compat.v1.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exist. Please specify an export "
            "directory: %s" % model_dir
        )

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # We retrieve the full path of our checkpoint
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # We specify the full filename of our frozen graph
    output_graph = output_filename

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.compat.v1.train.import_meta_graph(
            input_checkpoint + ".meta", clear_devices=clear_devices
        )

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.compat.v1.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
            output_node_names.split(
                ","
            ),  # The output node names are used to select the useful nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.compat.v1.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir", type=str, default="", help="Model folder to export"
    )
    parser.add_argument(
        "--output-node-names",
        type=str,
        default="vocals_spectrogram/mul,accompaniment_spectrogram/mul",
        help="The names of the output nodes, comma separated.",
    )
    parser.add_argument(
        "--output-filename",
        type=str,
    )
    args = parser.parse_args()

    freeze_graph(args.model_dir, args.output_node_names, args.output_filename)
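
The names passed to --output-node-names must match nodes that actually exist in
the checkpoint's graph. A minimal sketch (not part of the original commit) to
enumerate candidate node names in a frozen graph produced by this script:

import tensorflow as tf

graph_def = tf.compat.v1.GraphDef()
with tf.compat.v1.gfile.GFile("./2stems/frozen_vocals_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    # e.g., "vocals_spectrogram/mul" shows up among the printed names
    print(node.name, node.op)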
... ...
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Please see ./run.sh for usage

import argparse

import numpy as np
import tensorflow as tf
import torch
from unet import UNet


def load_graph(frozen_graph_filename):
    # This function is modified from
    # https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
    #
    # We load the protobuf file from the disk and parse it to retrieve the
    # unserialized graph_def
    with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we import the graph_def into a new Graph and return it
    with tf.Graph().as_default() as graph:
        # The name argument would prefix every op/node in the graph.
        # Since we load everything into a new graph, it is not needed here.
        # tf.import_graph_def(graph_def, name="prefix")
        tf.import_graph_def(graph_def, name="")
    return graph


def generate_waveform():
    np.random.seed(20230821)
    waveform = np.random.rand(60 * 44100).astype(np.float32)

    # (num_samples, num_channels)
    waveform = waveform.reshape(-1, 2)
    return waveform


def get_param(graph, name):
    with tf.compat.v1.Session(graph=graph) as sess:
        constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
        for constant_op in constant_ops:
            if constant_op.name != name:
                continue

            value = sess.run(constant_op.outputs[0])
            return torch.from_numpy(value)


@torch.no_grad()
def main(name):
    graph = load_graph(f"./2stems/frozen_{name}_model.pb")
    # for op in graph.get_operations():
    #     print(op.name)

    x = graph.get_tensor_by_name("waveform:0")
    # y = graph.get_tensor_by_name("Reshape:0")
    y0 = graph.get_tensor_by_name("strided_slice_3:0")

    # y1 = graph.get_tensor_by_name("leaky_re_lu_5/LeakyRelu:0")
    # y1 = graph.get_tensor_by_name("conv2d_5/BiasAdd:0")
    # y1 = graph.get_tensor_by_name("conv2d_transpose/BiasAdd:0")
    # y1 = graph.get_tensor_by_name("re_lu/Relu:0")
    # y1 = graph.get_tensor_by_name("batch_normalization_6/cond/FusedBatchNorm_1:0")
    # y1 = graph.get_tensor_by_name("concatenate/concat:0")
    # y1 = graph.get_tensor_by_name("concatenate_1/concat:0")
    # y1 = graph.get_tensor_by_name("concatenate_4/concat:0")
    # y1 = graph.get_tensor_by_name("batch_normalization_11/cond/FusedBatchNorm_1:0")
    # y1 = graph.get_tensor_by_name("conv2d_6/Sigmoid:0")
    y1 = graph.get_tensor_by_name(f"{name}_spectrogram/mul:0")

    unet = UNet()
    unet.eval()

    # For conv2d in TensorFlow, the weight shape is
    # (kernel_h, kernel_w, in_channel, out_channel) and the default input
    # shape is NHWC.
    # For conv2d in torch, the weight shape is
    # (out_channel, in_channel, kernel_h, kernel_w) and the default input
    # shape is NCHW.

    state_dict = unet.state_dict()
    # print(list(state_dict.keys()))

    if name == "vocals":
        state_dict["conv.weight"] = get_param(graph, "conv2d/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["conv.bias"] = get_param(graph, "conv2d/bias")

        state_dict["bn.weight"] = get_param(graph, "batch_normalization/gamma")
        state_dict["bn.bias"] = get_param(graph, "batch_normalization/beta")
        state_dict["bn.running_mean"] = get_param(
            graph, "batch_normalization/moving_mean"
        )
        state_dict["bn.running_var"] = get_param(
            graph, "batch_normalization/moving_variance"
        )

        conv_offset = 0
        bn_offset = 0
    else:
        state_dict["conv.weight"] = get_param(graph, "conv2d_7/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["conv.bias"] = get_param(graph, "conv2d_7/bias")

        state_dict["bn.weight"] = get_param(graph, "batch_normalization_12/gamma")
        state_dict["bn.bias"] = get_param(graph, "batch_normalization_12/beta")
        state_dict["bn.running_mean"] = get_param(
            graph, "batch_normalization_12/moving_mean"
        )
        state_dict["bn.running_var"] = get_param(
            graph, "batch_normalization_12/moving_variance"
        )

        conv_offset = 7
        bn_offset = 12

    for i in range(1, 6):
        state_dict[f"conv{i}.weight"] = get_param(
            graph, f"conv2d_{i+conv_offset}/kernel"
        ).permute(3, 2, 0, 1)
        state_dict[f"conv{i}.bias"] = get_param(graph, f"conv2d_{i+conv_offset}/bias")

        if i >= 5:
            continue

        state_dict[f"bn{i}.weight"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/gamma"
        )
        state_dict[f"bn{i}.bias"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/beta"
        )
        state_dict[f"bn{i}.running_mean"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/moving_mean"
        )
        state_dict[f"bn{i}.running_var"] = get_param(
            graph, f"batch_normalization_{i+bn_offset}/moving_variance"
        )

    if name == "vocals":
        state_dict["up1.weight"] = get_param(graph, "conv2d_transpose/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up1.bias"] = get_param(graph, "conv2d_transpose/bias")

        state_dict["bn5.weight"] = get_param(graph, "batch_normalization_6/gamma")
        state_dict["bn5.bias"] = get_param(graph, "batch_normalization_6/beta")
        state_dict["bn5.running_mean"] = get_param(
            graph, "batch_normalization_6/moving_mean"
        )
        state_dict["bn5.running_var"] = get_param(
            graph, "batch_normalization_6/moving_variance"
        )

        conv_offset = 0
        bn_offset = 0
    else:
        state_dict["up1.weight"] = get_param(
            graph, "conv2d_transpose_6/kernel"
        ).permute(3, 2, 0, 1)
        state_dict["up1.bias"] = get_param(graph, "conv2d_transpose_6/bias")

        state_dict["bn5.weight"] = get_param(graph, "batch_normalization_18/gamma")
        state_dict["bn5.bias"] = get_param(graph, "batch_normalization_18/beta")
        state_dict["bn5.running_mean"] = get_param(
            graph, "batch_normalization_18/moving_mean"
        )
        state_dict["bn5.running_var"] = get_param(
            graph, "batch_normalization_18/moving_variance"
        )

        conv_offset = 6
        bn_offset = 12

    for i in range(1, 6):
        state_dict[f"up{i+1}.weight"] = get_param(
            graph, f"conv2d_transpose_{i+conv_offset}/kernel"
        ).permute(3, 2, 0, 1)
        state_dict[f"up{i+1}.bias"] = get_param(
            graph, f"conv2d_transpose_{i+conv_offset}/bias"
        )

        state_dict[f"bn{5+i}.weight"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/gamma"
        )
        state_dict[f"bn{5+i}.bias"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/beta"
        )
        state_dict[f"bn{5+i}.running_mean"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/moving_mean"
        )
        state_dict[f"bn{5+i}.running_var"] = get_param(
            graph, f"batch_normalization_{6+i+bn_offset}/moving_variance"
        )

    if name == "vocals":
        state_dict["up7.weight"] = get_param(graph, "conv2d_6/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up7.bias"] = get_param(graph, "conv2d_6/bias")
    else:
        state_dict["up7.weight"] = get_param(graph, "conv2d_13/kernel").permute(
            3, 2, 0, 1
        )
        state_dict["up7.bias"] = get_param(graph, "conv2d_13/bias")

    unet.load_state_dict(state_dict)

    with tf.compat.v1.Session(graph=graph) as sess:
        y0_out, y1_out = sess.run([y0, y1], feed_dict={x: generate_waveform()})
        # y0_out = sess.run(y0, feed_dict={x: generate_waveform()})
        # y1_out = sess.run(y1, feed_dict={x: generate_waveform()})

    # print(y0_out.shape)
    # print(y1_out.shape)

    # For batch normalization in TensorFlow, the default input shape is NHWC.
    # For batch normalization in torch, the default input shape is NCHW.

    # NHWC to NCHW
    torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))
    # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)

    assert torch.allclose(
        torch_y1_out, torch.from_numpy(y1_out).permute(0, 3, 1, 2), atol=1e-1
    ), ((torch_y1_out - torch.from_numpy(y1_out).permute(0, 3, 1, 2)).abs().max())

    torch.save(unet.state_dict(), f"2stems/{name}.pt")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--name",
        type=str,
        required=True,
        choices=["vocals", "accompaniment"],
    )
    args = parser.parse_args()
    print(vars(args))
    main(args.name)
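
The permute(3, 2, 0, 1) calls above implement the TF-to-torch weight layout
change described in the comments. A standalone numpy sketch (not part of the
original commit) showing the axis mapping with a stand-in kernel:

import numpy as np

# TF conv2d weights: (kernel_h, kernel_w, in_channel, out_channel);
# torch expects:     (out_channel, in_channel, kernel_h, kernel_w)
w_tf = np.zeros((5, 5, 2, 16))  # stand-in for a kernel like conv2d/kernel
w_torch = w_tf.transpose(3, 2, 0, 1)
assert w_torch.shape == (16, 2, 5, 5)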
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import onnx
import onnxmltools
import torch
from onnxmltools.utils.float16_converter import convert_float_to_float16
from onnxruntime.quantization import QuantType, quantize_dynamic
from unet import UNet


def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)


def add_meta_data(filename, prefix):
    meta_data = {
        "model_type": "spleeter",
        "sample_rate": 44100,
        "version": 1,
        "model_url": "https://github.com/deezer/spleeter",
        "stems": 2,
        "comment": prefix,
        "model_name": "2stems.tar.gz",
    }

    model = onnx.load(filename)
    print(model.metadata_props)

    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    print("--------------------")

    print(model.metadata_props)

    onnx.save(model, filename)


def export(model, prefix):
    num_splits = 1
    x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)

    filename = f"./2stems/{prefix}.onnx"
    torch.onnx.export(
        model,
        x,
        filename,
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "num_splits"},
        },
        opset_version=13,
    )
    add_meta_data(filename, prefix)

    filename_int8 = f"./2stems/{prefix}.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        weight_type=QuantType.QUInt8,
    )

    filename_fp16 = f"./2stems/{prefix}.fp16.onnx"
    export_onnx_fp16(filename, filename_fp16)


@torch.no_grad()
def main():
    vocals = UNet()
    state_dict = torch.load("./2stems/vocals.pt", map_location="cpu")
    vocals.load_state_dict(state_dict)
    vocals.eval()

    accompaniment = UNet()
    state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu")
    accompaniment.load_state_dict(state_dict)
    accompaniment.eval()

    export(vocals, "vocals")
    export(accompaniment, "accompaniment")


if __name__ == "__main__":
    main()
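
A minimal sketch (not part of the original commit) showing how a consumer can
read back the metadata written by add_meta_data() above; note every value is
stored as a string:

import onnxruntime as ort

sess = ort.InferenceSession(
    "./2stems/vocals.onnx", providers=["CPUExecutionProvider"]
)
meta = sess.get_modelmeta().custom_metadata_map
# e.g., prints: spleeter 44100 2
print(meta["model_type"], meta["sample_rate"], meta["stems"])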
... ...
#!/usr/bin/env bash

if [ ! -f 2stems.tar.gz ]; then
  curl -SL -O https://github.com/deezer/spleeter/releases/download/v1.4.0/2stems.tar.gz
fi

if [ ! -d ./2stems ]; then
  mkdir -p 2stems
  cd 2stems
  tar xvf ../2stems.tar.gz
  cd ..
fi

ls -lh
ls -lh 2stems

if [ ! -f 2stems/frozen_vocals_model.pb ]; then
  python3 ./convert_to_pb.py \
    --model-dir ./2stems \
    --output-node-names vocals_spectrogram/mul \
    --output-filename ./2stems/frozen_vocals_model.pb
fi

ls -lh 2stems

if [ ! -f 2stems/frozen_accompaniment_model.pb ]; then
  python3 ./convert_to_pb.py \
    --model-dir ./2stems \
    --output-node-names accompaniment_spectrogram/mul \
    --output-filename ./2stems/frozen_accompaniment_model.pb
fi

ls -lh 2stems

python3 ./convert_to_torch.py --name vocals
python3 ./convert_to_torch.py --name accompaniment

python3 ./export_onnx.py

ls -lh 2stems
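
After run.sh finishes, the six exported files can be validated structurally. A
minimal sketch (not part of the original commit); onnx.checker raises an
exception if a model file is malformed:

import onnx

for f in [
    "./2stems/vocals.onnx",
    "./2stems/accompaniment.onnx",
    "./2stems/vocals.int8.onnx",
    "./2stems/accompaniment.int8.onnx",
    "./2stems/vocals.fp16.onnx",
    "./2stems/accompaniment.fp16.onnx",
]:
    onnx.checker.check_model(f)
    print(f, "ok")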
... ...
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Please see ./run.sh for usage

from typing import Optional

import ffmpeg
import numpy as np
import soundfile as sf
import torch
from pydub import AudioSegment
from unet import UNet


def load_audio(filename, sample_rate: Optional[int] = 44100):
    probe = ffmpeg.probe(filename)
    if "streams" not in probe or len(probe["streams"]) == 0:
        raise ValueError("No stream was found with ffprobe")

    metadata = next(
        stream for stream in probe["streams"] if stream["codec_type"] == "audio"
    )
    n_channels = metadata["channels"]

    if sample_rate is None:
        sample_rate = metadata["sample_rate"]

    process = (
        ffmpeg.input(filename)
        .output("pipe:", format="f32le", ar=sample_rate)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, _ = process.communicate()

    waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
    waveform = torch.from_numpy(np.copy(waveform)).to(torch.float32)

    if n_channels == 1:
        waveform = waveform.tile(1, 2)

    if n_channels > 2:
        waveform = waveform[:, :2]

    return waveform, sample_rate


@torch.no_grad()
def main():
    vocals = UNet()
    vocals.eval()
    state_dict = torch.load("./2stems/vocals.pt", map_location="cpu")
    vocals.load_state_dict(state_dict)

    accompaniment = UNet()
    accompaniment.eval()
    state_dict = torch.load("./2stems/accompaniment.pt", map_location="cpu")
    accompaniment.load_state_dict(state_dict)

    # waveform, sample_rate = load_audio("./audio_example.mp3")

    # You can download the following two mp3 files from
    # https://huggingface.co/spaces/csukuangfj/music-source-separation/tree/main/examples
    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    # waveform, sample_rate = load_audio("./Yesterday_Once_More-Carpenters.mp3")
    assert waveform.shape[1] == 2, waveform.shape

    waveform = torch.nn.functional.pad(waveform, (0, 0, 0, 4096))

    # torch.stft requires a 2-D input of shape (N, T), so we transpose waveform
    stft = torch.stft(
        waveform.t(),
        n_fft=4096,
        hop_length=1024,
        window=torch.hann_window(4096, periodic=True),
        center=False,
        onesided=True,
        return_complex=True,
    )
    print("stft", stft.shape)
    # stft: (2, 2049, 465)
    # stft is a complex tensor

    y = stft.permute(2, 1, 0)
    print("y0", y.shape)
    # (465, 2049, 2)

    y = y[:, :1024, :]
    # (465, 1024, 2)

    tensor_size = y.shape[0] - int(y.shape[0] / 512) * 512
    pad_size = 512 - tensor_size
    y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size))
    # (512, 1024, 2)
    print("y1", y.shape, y.dtype)

    num_splits = int(y.shape[0] / 512)
    y = y.reshape([num_splits, 512] + list(y.shape[1:]))
    # y: (1, 512, 1024, 2)
    print("y2", y.shape, y.dtype)

    y = y.abs()
    y = y.permute(0, 3, 1, 2)
    # (1, 2, 512, 1024)
    print("y3", y.shape, y.dtype)

    vocals_spec = vocals(y)
    accompaniment_spec = accompaniment(y)
    sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
    print(
        "vocals_spec",
        vocals_spec.shape,
        accompaniment_spec.shape,
        sum_spec.shape,
        vocals_spec.dtype,
    )

    vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec
    # (1, 2, 512, 1024)

    accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec
    # (1, 2, 512, 1024)

    for name, spec in zip(
        ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
    ):
        spec = torch.nn.functional.pad(spec, (0, 2049 - 1024, 0, 0, 0, 0, 0, 0))
        # (1, 2, 512, 2049)

        spec = spec.permute(0, 2, 3, 1)
        # (1, 512, 2049, 2)
        print("here00", spec.shape)

        spec = spec.reshape(-1, spec.shape[2], spec.shape[3])
        # (512, 2049, 2)
        print("here2", spec.shape)

        spec = spec[: stft.shape[2], :, :]
        # (465, 2049, 2)
        print("here 3", spec.shape, stft.shape)

        spec = spec.permute(2, 1, 0)
        # (2, 2049, 465)

        masked_stft = spec * stft

        wave = torch.istft(
            masked_stft,
            4096,
            1024,
            window=torch.hann_window(4096, periodic=True),
            onesided=True,
        ) * (2 / 3)
        print(wave.shape, wave.dtype)

        sf.write(f"{name}.wav", wave.t(), 44100)

        wave = (wave.t() * 32768).to(torch.int16)
        sound = AudioSegment(
            data=wave.numpy().tobytes(), sample_width=2, frame_rate=44100, channels=2
        )
        sound.export(f"{name}.mp3", format="mp3", bitrate="128k")


if __name__ == "__main__":
    main()
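
The two masks built above are Wiener-style ratio masks: each model's output is
squared and divided by the sum of both squares (plus epsilon), so the two
masks sum to one in every time-frequency bin. A minimal sketch (with random
stand-in spectrograms, not part of the original commit) checking that property:

import torch

v = torch.rand(1, 2, 512, 1024)  # stand-in for vocals_spec
a = torch.rand(1, 2, 512, 1024)  # stand-in for accompaniment_spec
eps = 1e-10
s = v**2 + a**2 + eps
mask_v = (v**2 + eps / 2) / s
mask_a = (a**2 + eps / 2) / s
# (v^2 + eps/2 + a^2 + eps/2) / (v^2 + a^2 + eps) == 1 exactly
assert torch.allclose(mask_v + mask_a, torch.ones_like(s))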
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import time

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
from separate import load_audio

"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])

----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
"""


class OnnxModel:
    def __init__(self, filename):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts
        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        print(f"----------inputs for {filename}----------")
        for i in self.model.get_inputs():
            print(i)

        print(f"----------outputs for {filename}----------")
        for i in self.model.get_outputs():
            print(i)
        print("--------------------")

    def __call__(self, x):
        """
        Args:
          x: (num_splits, 2, 512, 1024)
        """
        spec = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
            },
        )[0]

        return torch.from_numpy(spec)


def main():
    vocals = OnnxModel("./2stems/vocals.onnx")
    accompaniment = OnnxModel("./2stems/accompaniment.onnx")

    waveform, sample_rate = load_audio("./qi-feng-le.mp3")
    waveform = waveform[: 44100 * 10, :]

    stft_config = knf.StftConfig(
        n_fft=4096,
        hop_length=1024,
        win_length=4096,
        center=False,
        window_type="hann",
    )
    knf_stft = knf.Stft(stft_config)
    knf_istft = knf.IStft(stft_config)

    start = time.time()
    stft_result_c0 = knf_stft(waveform[:, 0].tolist())
    stft_result_c1 = knf_stft(waveform[:, 1].tolist())
    print("c0 stft", stft_result_c0.num_frames)

    orig_real0 = np.array(stft_result_c0.real, dtype=np.float32).reshape(
        stft_result_c0.num_frames, -1
    )
    orig_imag0 = np.array(stft_result_c0.imag, dtype=np.float32).reshape(
        stft_result_c0.num_frames, -1
    )

    orig_real1 = np.array(stft_result_c1.real, dtype=np.float32).reshape(
        stft_result_c1.num_frames, -1
    )
    orig_imag1 = np.array(stft_result_c1.imag, dtype=np.float32).reshape(
        stft_result_c1.num_frames, -1
    )

    real0 = torch.from_numpy(orig_real0)
    imag0 = torch.from_numpy(orig_imag0)
    real1 = torch.from_numpy(orig_real1)
    imag1 = torch.from_numpy(orig_imag1)
    # (num_frames, n_fft/2 + 1)
    print("real0", real0.shape)

    # keep only the first 1024 bins
    real0 = real0[:, :1024]
    imag0 = imag0[:, :1024]

    real1 = real1[:, :1024]
    imag1 = imag1[:, :1024]

    stft0 = (real0.square() + imag0.square()).sqrt()
    stft1 = (real1.square() + imag1.square()).sqrt()

    # pad to a multiple of 512
    padding = 512 - real0.shape[0] % 512
    print("padding", padding)
    if padding > 0:
        stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
        stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))

    stft0 = stft0.reshape(-1, 1, 512, 1024)
    stft1 = stft1.reshape(-1, 1, 512, 1024)

    stft_01 = torch.cat([stft0, stft1], axis=1)
    print("stft_01", stft_01.shape, stft_01.dtype)

    vocals_spec = vocals(stft_01)
    accompaniment_spec = accompaniment(stft_01)
    # (num_splits, num_channels, 512, 1024)

    sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10

    vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec
    accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec

    for name, spec in zip(
        ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
    ):
        spec_c0 = spec[:, 0, :, :]
        spec_c1 = spec[:, 1, :, :]

        spec_c0 = spec_c0.reshape(-1, 1024)
        spec_c1 = spec_c1.reshape(-1, 1024)

        spec_c0 = spec_c0[: stft_result_c0.num_frames, :]
        spec_c1 = spec_c1[: stft_result_c0.num_frames, :]

        spec_c0 = torch.nn.functional.pad(spec_c0, (0, 2049 - 1024, 0, 0))
        spec_c1 = torch.nn.functional.pad(spec_c1, (0, 2049 - 1024, 0, 0))

        spec_c0_real = spec_c0 * orig_real0
        spec_c0_imag = spec_c0 * orig_imag0

        spec_c1_real = spec_c1 * orig_real1
        spec_c1_imag = spec_c1 * orig_imag1

        result0 = knf.StftResult(
            real=spec_c0_real.reshape(-1).tolist(),
            imag=spec_c0_imag.reshape(-1).tolist(),
            num_frames=orig_real0.shape[0],
        )

        result1 = knf.StftResult(
            real=spec_c1_real.reshape(-1).tolist(),
            imag=spec_c1_imag.reshape(-1).tolist(),
            num_frames=orig_real1.shape[0],
        )

        wav0 = knf_istft(result0)
        wav1 = knf_istft(result1)

        wav = np.array([wav0, wav1], dtype=np.float32)
        wav = np.transpose(wav)
        # now wav is (num_samples, num_channels)

        sf.write(f"./onnx-{name}.wav", wav, 44100)
        print(f"Saved to ./onnx-{name}.wav")

    end = time.time()
    elapsed_seconds = end - start

    audio_duration = waveform.shape[0] / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
    print(f"Audio duration in seconds: {audio_duration:.3f}")
    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


if __name__ == "__main__":
    main()
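
Since run.sh also produces int8 and fp16 variants, a quick way to gauge
quantization error is to feed the same random input to two variants. A minimal
sketch (not part of the original commit) reusing the OnnxModel wrapper above:

import torch

fp32 = OnnxModel("./2stems/vocals.onnx")
int8 = OnnxModel("./2stems/vocals.int8.onnx")

x = torch.rand(1, 2, 512, 1024)
# worst-case absolute difference between the fp32 and int8 outputs
print((fp32(x) - int8(x)).abs().max())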
... ...
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

import torch


class UNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(2, 16, kernel_size=5, stride=(2, 2), padding=0)
        self.bn = torch.nn.BatchNorm2d(
            16, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=5, stride=(2, 2), padding=0)
        self.bn1 = torch.nn.BatchNorm2d(
            32, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5, stride=(2, 2), padding=0)
        self.bn2 = torch.nn.BatchNorm2d(
            64, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=5, stride=(2, 2), padding=0)
        self.bn3 = torch.nn.BatchNorm2d(
            128, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=5, stride=(2, 2), padding=0)
        self.bn4 = torch.nn.BatchNorm2d(
            256, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.conv5 = torch.nn.Conv2d(256, 512, kernel_size=5, stride=(2, 2), padding=0)

        self.up1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2)
        self.bn5 = torch.nn.BatchNorm2d(
            256, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up2 = torch.nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2)
        self.bn6 = torch.nn.BatchNorm2d(
            128, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up3 = torch.nn.ConvTranspose2d(256, 64, kernel_size=5, stride=2)
        self.bn7 = torch.nn.BatchNorm2d(
            64, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up4 = torch.nn.ConvTranspose2d(128, 32, kernel_size=5, stride=2)
        self.bn8 = torch.nn.BatchNorm2d(
            32, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up5 = torch.nn.ConvTranspose2d(64, 16, kernel_size=5, stride=2)
        self.bn9 = torch.nn.BatchNorm2d(
            16, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        self.up6 = torch.nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2)
        self.bn10 = torch.nn.BatchNorm2d(
            1, track_running_stats=True, eps=1e-3, momentum=0.01
        )

        # The original model is exported with output logits disabled, i.e.,
        # a final sigmoid is applied, so we need self.up7
        self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)

    def forward(self, x):
        in_x = x
        # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
        x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
        conv1 = self.conv(x)
        batch1 = self.bn(conv1)
        rel1 = torch.nn.functional.leaky_relu(batch1, negative_slope=0.2)

        x = torch.nn.functional.pad(rel1, (1, 2, 1, 2), "constant", 0)
        conv2 = self.conv1(x)  # (3, 32, 128, 256)
        batch2 = self.bn1(conv2)
        rel2 = torch.nn.functional.leaky_relu(
            batch2, negative_slope=0.2
        )  # (3, 32, 128, 256)

        x = torch.nn.functional.pad(rel2, (1, 2, 1, 2), "constant", 0)
        conv3 = self.conv2(x)  # (3, 64, 64, 128)
        batch3 = self.bn2(conv3)
        rel3 = torch.nn.functional.leaky_relu(
            batch3, negative_slope=0.2
        )  # (3, 64, 64, 128)

        x = torch.nn.functional.pad(rel3, (1, 2, 1, 2), "constant", 0)
        conv4 = self.conv3(x)  # (3, 128, 32, 64)
        batch4 = self.bn3(conv4)
        rel4 = torch.nn.functional.leaky_relu(
            batch4, negative_slope=0.2
        )  # (3, 128, 32, 64)

        x = torch.nn.functional.pad(rel4, (1, 2, 1, 2), "constant", 0)
        conv5 = self.conv4(x)  # (3, 256, 16, 32)
        batch5 = self.bn4(conv5)
        rel6 = torch.nn.functional.leaky_relu(
            batch5, negative_slope=0.2
        )  # (3, 256, 16, 32)

        x = torch.nn.functional.pad(rel6, (1, 2, 1, 2), "constant", 0)
        conv6 = self.conv5(x)  # (3, 512, 8, 16)

        up1 = self.up1(conv6)
        up1 = up1[:, :, 1:-2, 1:-2]  # (3, 256, 16, 32)
        up1 = torch.nn.functional.relu(up1)
        batch7 = self.bn5(up1)
        merge1 = torch.cat([conv5, batch7], axis=1)  # (3, 512, 16, 32)

        up2 = self.up2(merge1)
        up2 = up2[:, :, 1:-2, 1:-2]
        up2 = torch.nn.functional.relu(up2)
        batch8 = self.bn6(up2)
        merge2 = torch.cat([conv4, batch8], axis=1)  # (3, 256, 32, 64)

        up3 = self.up3(merge2)
        up3 = up3[:, :, 1:-2, 1:-2]
        up3 = torch.nn.functional.relu(up3)
        batch9 = self.bn7(up3)
        merge3 = torch.cat([conv3, batch9], axis=1)  # (3, 128, 64, 128)

        up4 = self.up4(merge3)
        up4 = up4[:, :, 1:-2, 1:-2]
        up4 = torch.nn.functional.relu(up4)
        batch10 = self.bn8(up4)
        merge4 = torch.cat([conv2, batch10], axis=1)  # (3, 64, 128, 256)

        up5 = self.up5(merge4)
        up5 = up5[:, :, 1:-2, 1:-2]
        up5 = torch.nn.functional.relu(up5)
        batch11 = self.bn9(up5)
        merge5 = torch.cat([conv1, batch11], axis=1)  # (3, 32, 256, 512)

        up6 = self.up6(merge5)
        up6 = up6[:, :, 1:-2, 1:-2]
        up6 = torch.nn.functional.relu(up6)
        batch12 = self.bn10(up6)  # (3, 1, 512, 1024) = (T, 1, 512, 1024)

        up7 = self.up7(batch12)
        up7 = torch.sigmoid(up7)  # (3, 2, 512, 1024)
        return up7 * in_x
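
Because the UNet is a masking network (the final sigmoid is multiplied with
the input), its output shape must equal its input shape. A minimal sanity
check (not part of the original commit):

import torch

net = UNet()
net.eval()
with torch.no_grad():
    x = torch.rand(3, 2, 512, 1024)
    y = net(x)
# the crops after each transposed conv restore the input resolution exactly
assert y.shape == x.shape, (y.shape, x.shape)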
... ...