Fangjun Kuang
Committed by GitHub

Add Python API for source separation (#2283)

@@ -8,6 +8,32 @@ log() { @@ -8,6 +8,32 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +log "test spleeter"
  12 +
  13 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
  14 +tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2
  15 +rm sherpa-onnx-spleeter-2stems-fp16.tar.bz2
  16 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/qi-feng-le-zh.wav
  17 +./python-api-examples/offline-source-separation-spleeter.py
  18 +rm -rf sherpa-onnx-spleeter-2stems-fp16
  19 +rm qi-feng-le-zh.wav
  20 +
  21 +log "test UVR"
  22 +
  23 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/UVR_MDXNET_9482.onnx
  24 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/qi-feng-le-zh.wav
  25 +./python-api-examples/offline-source-separation-uvr.py
  26 +rm UVR_MDXNET_9482.onnx
  27 +rm qi-feng-le-zh.wav
  28 +
  29 +mkdir source-separation
  30 +
  31 +mv spleeter-*.wav source-separation
  32 +mv uvr-*.wav source-separation
  33 +
  34 +ls -lh source-separation
  35 +
  36 +
11 log "test offline dolphin ctc" 37 log "test offline dolphin ctc"
12 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02.tar.bz2 38 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02.tar.bz2
13 tar xvf sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02.tar.bz2 39 tar xvf sherpa-onnx-dolphin-base-ctc-multi-lang-int8-2025-04-02.tar.bz2
@@ -99,5 +99,10 @@ jobs: @@ -99,5 +99,10 @@ jobs:
99 99
100 - uses: actions/upload-artifact@v4 100 - uses: actions/upload-artifact@v4
101 with: 101 with:
  102 + name: source-separation-${{ matrix.os }}-${{ matrix.python-version }}
  103 + path: ./source-separation
  104 +
  105 + - uses: actions/upload-artifact@v4
  106 + with:
102 name: tts-generated-test-files-${{ matrix.os }}-${{ matrix.python-version }} 107 name: tts-generated-test-files-${{ matrix.os }}-${{ matrix.python-version }}
103 path: tts 108 path: tts
@@ -36,22 +36,18 @@ jobs: @@ -36,22 +36,18 @@ jobs:
36 fail-fast: false 36 fail-fast: false
37 matrix: 37 matrix:
38 include: 38 include:
39 - # it fails to install ffmpeg on ubuntu 20.04  
40 - #  
41 - # - os: ubuntu-20.04  
42 - # python-version: "3.7"  
43 - # - os: ubuntu-20.04  
44 - # python-version: "3.8"  
45 - # - os: ubuntu-20.04  
46 - # python-version: "3.9"  
47 -  
48 - - os: ubuntu-22.04 39 + - os: ubuntu-24.04
  40 + python-version: "3.8"
  41 + - os: ubuntu-24.04
  42 + python-version: "3.9"
  43 +
  44 + - os: ubuntu-24.04
49 python-version: "3.10" 45 python-version: "3.10"
50 - - os: ubuntu-22.04 46 + - os: ubuntu-24.04
51 python-version: "3.11" 47 python-version: "3.11"
52 - - os: ubuntu-22.04 48 + - os: ubuntu-24.04
53 python-version: "3.12" 49 python-version: "3.12"
54 - - os: ubuntu-22.04 50 + - os: ubuntu-24.04
55 python-version: "3.13" 51 python-version: "3.13"
56 52
57 steps: 53 steps:
@@ -81,10 +77,12 @@ jobs: @@ -81,10 +77,12 @@ jobs:
81 python3 -m pip install --upgrade pip numpy pypinyin sentencepiece>=0.1.96 soundfile 77 python3 -m pip install --upgrade pip numpy pypinyin sentencepiece>=0.1.96 soundfile
82 python3 -m pip install wheel twine setuptools 78 python3 -m pip install wheel twine setuptools
83 79
84 - - name: Install ffmpeg  
85 - shell: bash  
86 - run: |  
87 - sudo apt-get install ffmpeg 80 + - uses: afoley587/setup-ffmpeg@main
  81 + id: setup-ffmpeg
  82 + with:
  83 + ffmpeg-version: release
  84 + architecture: ''
  85 + github-token: ${{ github.server_url == 'https://github.com' && github.token || '' }}
88 86
89 - name: Install ninja 87 - name: Install ninja
90 shell: bash 88 shell: bash
@@ -191,5 +189,10 @@ jobs: @@ -191,5 +189,10 @@ jobs:
191 189
192 - uses: actions/upload-artifact@v4 190 - uses: actions/upload-artifact@v4
193 with: 191 with:
  192 + name: source-separation-${{ matrix.os }}-${{ matrix.python-version }}-whl
  193 + path: ./source-separation
  194 +
  195 + - uses: actions/upload-artifact@v4
  196 + with:
194 name: tts-generated-test-files-${{ matrix.os }}-${{ matrix.python-version }} 197 name: tts-generated-test-files-${{ matrix.os }}-${{ matrix.python-version }}
195 path: tts 198 path: tts
  1 +#!/usr/bin/env python3
  2 +# Copyright (c) 2025 Xiaomi Corporation
  3 +
  4 +"""
  5 +This file shows how to use spleeter for source separation.
  6 +
  7 +Please first download a spleeter model from
  8 +
  9 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models
  10 +
  11 +The following is an example:
  12 +
  13 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
  14 +
  15 +Please also download a test file
  16 +
  17 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/qi-feng-le-zh.wav
  18 +
  19 +The test wav file is 16-bit encoded with 2 channels. If you have other
  20 +formats, e.g., .mp4 or .mp3, please first use ffmpeg to convert it.
  21 +For instance
  22 +
  23 + ffmpeg -i your.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 out.wav
  24 +
  25 +Then you can use out.wav as input for this example.
  26 +"""
  27 +
  28 +import time
  29 +from pathlib import Path
  30 +
  31 +import numpy as np
  32 +import sherpa_onnx
  33 +import soundfile as sf
  34 +
  35 +
  36 +def create_offline_source_separation():
  37 + # Please read the help message at the beginning of this file
  38 + # to download model files
  39 + vocals = "./sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx"
  40 + accompaniment = "./sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx"
  41 +
  42 + if not Path(vocals).is_file():
  43 + raise ValueError(f"{vocals} does not exist.")
  44 +
  45 + if not Path(accompaniment).is_file():
  46 + raise ValueError(f"{accompaniment} does not exist.")
  47 +
  48 + config = sherpa_onnx.OfflineSourceSeparationConfig(
  49 + model=sherpa_onnx.OfflineSourceSeparationModelConfig(
  50 + spleeter=sherpa_onnx.OfflineSourceSeparationSpleeterModelConfig(
  51 + vocals=vocals,
  52 + accompaniment=accompaniment,
  53 + ),
  54 + num_threads=1,
  55 + debug=False,
  56 + provider="cpu",
  57 + )
  58 + )
  59 + if not config.validate():
  60 + raise ValueError("Please check your config.")
  61 +
  62 + return sherpa_onnx.OfflineSourceSeparation(config)
  63 +
  64 +
  65 +def load_audio():
  66 + # Please read the help message at the beginning of this file to download
  67 + # the following wav_file
  68 + wav_file = "./qi-feng-le-zh.wav"
  69 + if not Path(wav_file).is_file():
  70 + raise ValueError(f"{wav_file} does not exist")
  71 +
  72 + samples, sample_rate = sf.read(wav_file, dtype="float32", always_2d=True)
  73 + samples = np.transpose(samples)
  74 + # now samples is of shape (num_channels, num_samples)
  75 + assert (
  76 + samples.shape[1] > samples.shape[0]
  77 + ), f"You should use (num_channels, num_samples). {samples.shape}"
  78 +
  79 + assert (
  80 + samples.dtype == np.float32
  81 + ), f"Expect np.float32 as dtype. Given: {samples.dtype}"
  82 +
  83 + return samples, sample_rate
  84 +
  85 +
  86 +def main():
  87 + sp = create_offline_source_separation()
  88 + samples, sample_rate = load_audio()
  89 + samples = np.ascontiguousarray(samples)
  90 +
  91 + start = time.time()
  92 + output = sp.process(sample_rate=sample_rate, samples=samples)
  93 + end = time.time()
  94 +
  95 + print("output.sample_rate", output.sample_rate)
  96 +
  97 + assert len(output.stems) == 2, len(output.stems)
  98 +
  99 + vocals = output.stems[0].data
  100 + non_vocals = output.stems[1].data
  101 + # vocals.shape (num_channels, num_samples)
  102 +
  103 + vocals = np.transpose(vocals)
  104 + non_vocals = np.transpose(non_vocals)
  105 +
  106 + # vocals.shape (num_samples, num_channels)
  107 +
  108 + sf.write("./spleeter-vocals.wav", vocals, samplerate=output.sample_rate)
  109 + sf.write("./spleeter-non-vocals.wav", non_vocals, samplerate=output.sample_rate)
  110 +
  111 + elapsed_seconds = end - start
  112 + audio_duration = samples.shape[1] / sample_rate
  113 + real_time_factor = elapsed_seconds / audio_duration
  114 +
  115 + print("Saved to ./spleeter-vocals.wav and ./spleeter-non-vocals.wav")
  116 + print(f"Elapsed seconds: {elapsed_seconds:.3f}")
  117 + print(f"Audio duration in seconds: {audio_duration:.3f}")
  118 + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
  119 +
  120 +
  121 +if __name__ == "__main__":
  122 + main()
  1 +#!/usr/bin/env python3
  2 +# Copyright (c) 2025 Xiaomi Corporation
  3 +
  4 +"""
  5 +This file shows how to use UVR for source separation.
  6 +
  7 +Please first download a UVR model from
  8 +
  9 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models
  10 +
  11 +The following is an example:
  12 +
  13 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/UVR_MDXNET_9482.onnx
  14 +
  15 +Please also download a test file
  16 +
  17 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/qi-feng-le-zh.wav
  18 +
  19 +The test wav file is 16-bit encoded with 2 channels. If you have other
  20 +formats, e.g., .mp4 or .mp3, please first use ffmpeg to convert it.
  21 +For instance
  22 +
  23 + ffmpeg -i your.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 out.wav
  24 +
  25 +Then you can use out.wav as input for this example.
  26 +"""
  27 +
  28 +import time
  29 +from pathlib import Path
  30 +
  31 +import numpy as np
  32 +import sherpa_onnx
  33 +import soundfile as sf
  34 +
  35 +
  36 +def create_offline_source_separation():
  37 + # Please read the help message at the beginning of this file
  38 + # to download model files
  39 + model = "./UVR_MDXNET_9482.onnx"
  40 +
  41 + if not Path(model).is_file():
  42 + raise ValueError(f"{model} does not exist.")
  43 +
  44 + config = sherpa_onnx.OfflineSourceSeparationConfig(
  45 + model=sherpa_onnx.OfflineSourceSeparationModelConfig(
  46 + uvr=sherpa_onnx.OfflineSourceSeparationUvrModelConfig(
  47 + model=model,
  48 + ),
  49 + num_threads=1,
  50 + debug=False,
  51 + provider="cpu",
  52 + )
  53 + )
  54 + if not config.validate():
  55 + raise ValueError("Please check your config.")
  56 +
  57 + return sherpa_onnx.OfflineSourceSeparation(config)
  58 +
  59 +
  60 +def load_audio():
  61 + # Please read the help message at the beginning of this file to download
  62 + # the following wav_file
  63 + wav_file = "./qi-feng-le-zh.wav"
  64 + if not Path(wav_file).is_file():
  65 + raise ValueError(f"{wav_file} does not exist")
  66 +
  67 + samples, sample_rate = sf.read(wav_file, dtype="float32", always_2d=True)
  68 + samples = np.transpose(samples)
  69 + # now samples is of shape (num_channels, num_samples)
  70 + assert (
  71 + samples.shape[1] > samples.shape[0]
  72 + ), f"You should use (num_channels, num_samples). {samples.shape}"
  73 +
  74 + assert (
  75 + samples.dtype == np.float32
  76 + ), f"Expect np.float32 as dtype. Given: {samples.dtype}"
  77 +
  78 + return samples, sample_rate
  79 +
  80 +
  81 +def main():
  82 + sp = create_offline_source_separation()
  83 + samples, sample_rate = load_audio()
  84 + samples = np.ascontiguousarray(samples)
  85 +
  86 + print("Started. Please wait")
  87 + start = time.time()
  88 + output = sp.process(sample_rate=sample_rate, samples=samples)
  89 + end = time.time()
  90 +
  91 + print("output.sample_rate", output.sample_rate)
  92 +
  93 + assert len(output.stems) == 2, len(output.stems)
  94 +
  95 + vocals = output.stems[0].data
  96 + non_vocals = output.stems[1].data
  97 + # vocals.shape (num_channels, num_samples)
  98 +
  99 + vocals = np.transpose(vocals)
  100 + non_vocals = np.transpose(non_vocals)
  101 +
  102 + # vocals.shape (num_samples, num_channels)
  103 +
  104 + sf.write("./uvr-vocals.wav", vocals, samplerate=output.sample_rate)
  105 + sf.write("./uvr-non-vocals.wav", non_vocals, samplerate=output.sample_rate)
  106 +
  107 + elapsed_seconds = end - start
  108 + audio_duration = samples.shape[1] / sample_rate
  109 + real_time_factor = elapsed_seconds / audio_duration
  110 +
  111 + print("Saved to ./uvr-vocals.wav and ./uvr-non-vocals.wav")
  112 + print(f"Elapsed seconds: {elapsed_seconds:.3f}")
  113 + print(f"Audio duration in seconds: {audio_duration:.3f}")
  114 + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
  115 +
  116 +
  117 +if __name__ == "__main__":
  118 + main()
@@ -20,6 +20,10 @@ set(srcs @@ -20,6 +20,10 @@ set(srcs
20 offline-punctuation.cc 20 offline-punctuation.cc
21 offline-recognizer.cc 21 offline-recognizer.cc
22 offline-sense-voice-model-config.cc 22 offline-sense-voice-model-config.cc
  23 + offline-source-separation-model-config.cc
  24 + offline-source-separation-spleeter-model-config.cc
  25 + offline-source-separation-uvr-model-config.cc
  26 + offline-source-separation.cc
23 offline-speech-denoiser-gtcrn-model-config.cc 27 offline-speech-denoiser-gtcrn-model-config.cc
24 offline-speech-denoiser-model-config.cc 28 offline-speech-denoiser-model-config.cc
25 offline-speech-denoiser.cc 29 offline-speech-denoiser.cc
@@ -9,6 +9,8 @@ @@ -9,6 +9,8 @@
9 9
10 #include "sherpa-onnx/csrc/fast-clustering.h" 10 #include "sherpa-onnx/csrc/fast-clustering.h"
11 11
  12 +#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
  13 +
12 namespace sherpa_onnx { 14 namespace sherpa_onnx {
13 15
14 static void PybindFastClusteringConfig(py::module *m) { 16 static void PybindFastClusteringConfig(py::module *m) {
@@ -32,6 +34,12 @@ void PybindFastClustering(py::module *m) { @@ -32,6 +34,12 @@ void PybindFastClustering(py::module *m) {
32 "__call__", 34 "__call__",
33 [](const PyClass &self, 35 [](const PyClass &self,
34 py::array_t<float> features) -> std::vector<int32_t> { 36 py::array_t<float> features) -> std::vector<int32_t> {
  37 + if (!(C_CONTIGUOUS == (features.flags() & C_CONTIGUOUS))) {
  38 + throw py::value_error(
  39 + "input features should be contiguous. Please use "
  40 + "np.ascontiguousarray(features)");
  41 + }
  42 +
35 int num_dim = features.ndim(); 43 int num_dim = features.ndim();
36 if (num_dim != 2) { 44 if (num_dim != 2) {
37 std::ostringstream os; 45 std::ostringstream os;
@@ -59,14 +59,14 @@ void PybindOfflineRecognizer(py::module *m) { @@ -59,14 +59,14 @@ void PybindOfflineRecognizer(py::module *m) {
59 return self.CreateStream(hotwords); 59 return self.CreateStream(hotwords);
60 }, 60 },
61 py::arg("hotwords"), py::call_guard<py::gil_scoped_release>()) 61 py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
62 - .def("decode_stream", &PyClass::DecodeStream, 62 + .def("decode_stream", &PyClass::DecodeStream, py::arg("s"),
63 py::call_guard<py::gil_scoped_release>()) 63 py::call_guard<py::gil_scoped_release>())
64 .def( 64 .def(
65 "decode_streams", 65 "decode_streams",
66 [](const PyClass &self, std::vector<OfflineStream *> ss) { 66 [](const PyClass &self, std::vector<OfflineStream *> ss) {
67 self.DecodeStreams(ss.data(), ss.size()); 67 self.DecodeStreams(ss.data(), ss.size());
68 }, 68 },
69 - py::call_guard<py::gil_scoped_release>()); 69 + py::arg("ss"), py::call_guard<py::gil_scoped_release>());
70 } 70 }
71 71
72 } // namespace sherpa_onnx 72 } // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-source-separation-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-source-separation-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
  10 +#include "sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h"
  11 +#include "sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +void PybindOfflineSourceSeparationModelConfig(py::module *m) {
  16 + PybindOfflineSourceSeparationSpleeterModelConfig(m);
  17 + PybindOfflineSourceSeparationUvrModelConfig(m);
  18 +
  19 + using PyClass = OfflineSourceSeparationModelConfig;
  20 + py::class_<PyClass>(*m, "OfflineSourceSeparationModelConfig")
  21 + .def(py::init<const OfflineSourceSeparationSpleeterModelConfig &,
  22 + const OfflineSourceSeparationUvrModelConfig &, int32_t,
  23 + bool, const std::string &>(),
  24 + py::arg("spleeter") = OfflineSourceSeparationSpleeterModelConfig{},
  25 + py::arg("uvr") = OfflineSourceSeparationUvrModelConfig{},
  26 + py::arg("num_threads") = 1, py::arg("debug") = false,
  27 + py::arg("provider") = "cpu")
  28 + .def_readwrite("spleeter", &PyClass::spleeter)
  29 + .def_readwrite("uvr", &PyClass::uvr)
  30 + .def_readwrite("num_threads", &PyClass::num_threads)
  31 + .def_readwrite("debug", &PyClass::debug)
  32 + .def_readwrite("provider", &PyClass::provider)
  33 + .def("validate", &PyClass::Validate)
  34 + .def("__str__", &PyClass::ToString);
  35 +}
  36 +
  37 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-source-separation-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineSourceSeparationModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
  1 +// sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void PybindOfflineSourceSeparationSpleeterModelConfig(py::module *m) {
  14 + using PyClass = OfflineSourceSeparationSpleeterModelConfig;
  15 + py::class_<PyClass>(*m, "OfflineSourceSeparationSpleeterModelConfig")
  16 + .def(py::init<const std::string &, const std::string &>(),
  17 + py::arg("vocals") = "", py::arg("accompaniment") = "")
  18 + .def_readwrite("vocals", &PyClass::vocals)
  19 + .def_readwrite("accompaniment", &PyClass::accompaniment)
  20 + .def("validate", &PyClass::Validate)
  21 + .def("__str__", &PyClass::ToString);
  22 +}
  23 +
  24 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-source-separation-spleeter-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineSourceSeparationSpleeterModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
  1 +// sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-source-separation-uvr-model-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void PybindOfflineSourceSeparationUvrModelConfig(py::module *m) {
  14 + using PyClass = OfflineSourceSeparationUvrModelConfig;
  15 + py::class_<PyClass>(*m, "OfflineSourceSeparationUvrModelConfig")
  16 + .def(py::init<const std::string &>(), py::arg("model") = "")
  17 + .def_readwrite("model", &PyClass::model)
  18 + .def("validate", &PyClass::Validate)
  19 + .def("__str__", &PyClass::ToString);
  20 +}
  21 +
  22 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-source-separation-uvr-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_UVR_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_UVR_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineSourceSeparationUvrModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_UVR_MODEL_CONFIG_H_
  1 +// sherpa-onnx/python/csrc/offline-source-separation-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-source-separation.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/python/csrc/offline-source-separation-model-config.h"
  10 +#include "sherpa-onnx/python/csrc/offline-source-separation.h"
  11 +
  12 +#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +static void PybindOfflineSourceSeparationConfig(py::module *m) {
  17 + PybindOfflineSourceSeparationModelConfig(m);
  18 +
  19 + using PyClass = OfflineSourceSeparationConfig;
  20 + py::class_<PyClass>(*m, "OfflineSourceSeparationConfig")
  21 + .def(py::init<const OfflineSourceSeparationModelConfig &>(),
  22 + py::arg("model") = OfflineSourceSeparationModelConfig{})
  23 + .def_readwrite("model", &PyClass::model)
  24 + .def("validate", &PyClass::Validate)
  25 + .def("__str__", &PyClass::ToString);
  26 +}
  27 +
  28 +static void PybindMultiChannelSamples(py::module *m) {
  29 + using PyClass = MultiChannelSamples;
  30 +
  31 + py::class_<PyClass>(*m, "MultiChannelSamples")
  32 + .def_property_readonly("data", [](PyClass &self) -> py::object {
  33 + // if data is not empty, return a float array of
  34 + // shape (num_channels, num_samples)
  35 + int32_t num_channels = self.data.size();
  36 + if (num_channels == 0) {
  37 + return py::none();
  38 + }
  39 +
  40 + int32_t num_samples = self.data[0].size();
  41 + if (num_samples == 0) {
  42 + return py::none();
  43 + }
  44 +
  45 + py::array_t<float> ans({num_channels, num_samples});
  46 +
  47 + py::buffer_info buf = ans.request();
  48 + auto p = static_cast<float *>(buf.ptr);
  49 +
  50 + for (int32_t i = 0; i != num_channels; ++i) {
  51 + std::copy(self.data[i].begin(), self.data[i].end(),
  52 + p + i * num_samples);
  53 + }
  54 +
  55 + return ans;
  56 + });
  57 +}
  58 +
  59 +static void PybindOfflineSourceSeparationOutput(py::module *m) {
  60 + using PyClass = OfflineSourceSeparationOutput;
  61 + py::class_<PyClass>(*m, "OfflineSourceSeparationOutput")
  62 + .def_property_readonly(
  63 + "sample_rate", [](const PyClass &self) { return self.sample_rate; })
  64 + .def_property_readonly("stems",
  65 + [](const PyClass &self) { return self.stems; });
  66 +}
  67 +
  68 +void PybindOfflineSourceSeparation(py::module *m) {
  69 + PybindOfflineSourceSeparationConfig(m);
  70 + PybindOfflineSourceSeparationOutput(m);
  71 +
  72 + PybindMultiChannelSamples(m);
  73 +
  74 + using PyClass = OfflineSourceSeparation;
  75 + py::class_<PyClass>(*m, "OfflineSourceSeparation")
  76 + .def(py::init<const OfflineSourceSeparationConfig &>(),
  77 + py::arg("config") = OfflineSourceSeparationConfig{})
  78 + .def(
  79 + "process",
  80 + [](const PyClass &self, int32_t sample_rate,
  81 + const py::array_t<float> &samples) {
  82 + if (!(C_CONTIGUOUS == (samples.flags() & C_CONTIGUOUS))) {
  83 + throw py::value_error(
  84 + "input samples should be contiguous. Please use "
  85 + "np.ascontiguousarray(samples)");
  86 + }
  87 +
  88 + int num_dim = samples.ndim();
  89 + if (samples.ndim() != 2) {
  90 + std::ostringstream os;
  91 + os << "Expect an array of 2 dimensions [num_channels x "
  92 + "num_samples]. "
  93 + "Given dim: "
  94 + << num_dim << "\n";
  95 + throw py::value_error(os.str());
  96 + }
  97 +
  98 + // if num_samples is less than 10, it is very likely the user
  99 + // has swapped num_channels and num_samples.
  100 + if (samples.shape(1) < 10) {
  101 + std::ostringstream os;
  102 + os << "Expect an array of 2 dimensions [num_channels x "
  103 + "num_samples]. "
  104 + "Given ["
  105 + << samples.shape(0) << " x " << samples.shape(1) << "]"
  106 + << "\n";
  107 + throw py::value_error(os.str());
  108 + }
  109 +
  110 + int32_t num_channels = samples.shape(0);
  111 + int32_t num_samples = samples.shape(1);
  112 + const float *p = samples.data();
  113 +
  114 + OfflineSourceSeparationInput input;
  115 +
  116 + input.samples.data.resize(num_channels);
  117 + input.sample_rate = sample_rate;
  118 +
  119 + for (int32_t i = 0; i != num_channels; ++i) {
  120 + input.samples.data[i] = {p + i * num_samples,
  121 + p + (i + 1) * num_samples};
  122 + }
  123 +
  124 + pybind11::gil_scoped_release release;
  125 +
  126 + return self.Process(input);
  127 + },
  128 + py::arg("sample_rate"), py::arg("samples"),
  129 + "samples is of shape (num_channels, num_samples) with dtype "
  130 + "np.float32");
  131 +}
  132 +
  133 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-source-separation-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineSourceSeparation(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SOURCE_SEPARATION_CONFIG_H_
@@ -47,6 +47,7 @@ void PybindOfflineSpeechDenoiser(py::module *m) { @@ -47,6 +47,7 @@ void PybindOfflineSpeechDenoiser(py::module *m) {
47 int32_t sample_rate) { 47 int32_t sample_rate) {
48 return self.Run(samples.data(), samples.size(), sample_rate); 48 return self.Run(samples.data(), samples.size(), sample_rate);
49 }, 49 },
  50 + py::arg("samples"), py::arg("sample_rate"),
50 py::call_guard<py::gil_scoped_release>()) 51 py::call_guard<py::gil_scoped_release>())
51 .def( 52 .def(
52 "run", 53 "run",
@@ -54,6 +55,7 @@ void PybindOfflineSpeechDenoiser(py::module *m) { @@ -54,6 +55,7 @@ void PybindOfflineSpeechDenoiser(py::module *m) {
54 int32_t sample_rate) { 55 int32_t sample_rate) {
55 return self.Run(samples.data(), samples.size(), sample_rate); 56 return self.Run(samples.data(), samples.size(), sample_rate);
56 }, 57 },
  58 + py::arg("samples"), py::arg("sample_rate"),
57 py::call_guard<py::gil_scoped_release>()) 59 py::call_guard<py::gil_scoped_release>())
58 .def_property_readonly("sample_rate", &PyClass::GetSampleRate); 60 .def_property_readonly("sample_rate", &PyClass::GetSampleRate);
59 } 61 }
@@ -109,19 +109,20 @@ void PybindOnlineRecognizer(py::module *m) { @@ -109,19 +109,20 @@ void PybindOnlineRecognizer(py::module *m) {
109 py::arg("hotwords"), py::call_guard<py::gil_scoped_release>()) 109 py::arg("hotwords"), py::call_guard<py::gil_scoped_release>())
110 .def("is_ready", &PyClass::IsReady, 110 .def("is_ready", &PyClass::IsReady,
111 py::call_guard<py::gil_scoped_release>()) 111 py::call_guard<py::gil_scoped_release>())
112 - .def("decode_stream", &PyClass::DecodeStream, 112 + .def("decode_stream", &PyClass::DecodeStream, py::arg("s"),
113 py::call_guard<py::gil_scoped_release>()) 113 py::call_guard<py::gil_scoped_release>())
114 .def( 114 .def(
115 "decode_streams", 115 "decode_streams",
116 [](PyClass &self, std::vector<OnlineStream *> ss) { 116 [](PyClass &self, std::vector<OnlineStream *> ss) {
117 self.DecodeStreams(ss.data(), ss.size()); 117 self.DecodeStreams(ss.data(), ss.size());
118 }, 118 },
  119 + py::arg("ss"), py::call_guard<py::gil_scoped_release>())
  120 + .def("get_result", &PyClass::GetResult, py::arg("s"),
119 py::call_guard<py::gil_scoped_release>()) 121 py::call_guard<py::gil_scoped_release>())
120 - .def("get_result", &PyClass::GetResult, 122 + .def("is_endpoint", &PyClass::IsEndpoint, py::arg("s"),
121 py::call_guard<py::gil_scoped_release>()) 123 py::call_guard<py::gil_scoped_release>())
122 - .def("is_endpoint", &PyClass::IsEndpoint,  
123 - py::call_guard<py::gil_scoped_release>())  
124 - .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>()); 124 + .def("reset", &PyClass::Reset, py::arg("s"),
  125 + py::call_guard<py::gil_scoped_release>());
125 } 126 }
126 127
127 } // namespace sherpa_onnx 128 } // namespace sherpa_onnx
@@ -17,6 +17,7 @@ @@ -17,6 +17,7 @@
17 #include "sherpa-onnx/python/csrc/offline-model-config.h" 17 #include "sherpa-onnx/python/csrc/offline-model-config.h"
18 #include "sherpa-onnx/python/csrc/offline-punctuation.h" 18 #include "sherpa-onnx/python/csrc/offline-punctuation.h"
19 #include "sherpa-onnx/python/csrc/offline-recognizer.h" 19 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
  20 +#include "sherpa-onnx/python/csrc/offline-source-separation.h"
20 #include "sherpa-onnx/python/csrc/offline-speech-denoiser.h" 21 #include "sherpa-onnx/python/csrc/offline-speech-denoiser.h"
21 #include "sherpa-onnx/python/csrc/offline-stream.h" 22 #include "sherpa-onnx/python/csrc/offline-stream.h"
22 #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" 23 #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
@@ -110,6 +111,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -110,6 +111,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
110 111
111 PybindAlsa(&m); 112 PybindAlsa(&m);
112 PybindOfflineSpeechDenoiser(&m); 113 PybindOfflineSpeechDenoiser(&m);
  114 + PybindOfflineSourceSeparation(&m);
113 } 115 }
114 116
115 } // namespace sherpa_onnx 117 } // namespace sherpa_onnx
@@ -11,6 +11,11 @@ from _sherpa_onnx import ( @@ -11,6 +11,11 @@ from _sherpa_onnx import (
11 OfflinePunctuation, 11 OfflinePunctuation,
12 OfflinePunctuationConfig, 12 OfflinePunctuationConfig,
13 OfflinePunctuationModelConfig, 13 OfflinePunctuationModelConfig,
  14 + OfflineSourceSeparation,
  15 + OfflineSourceSeparationConfig,
  16 + OfflineSourceSeparationModelConfig,
  17 + OfflineSourceSeparationSpleeterModelConfig,
  18 + OfflineSourceSeparationUvrModelConfig,
14 OfflineSpeakerDiarization, 19 OfflineSpeakerDiarization,
15 OfflineSpeakerDiarizationConfig, 20 OfflineSpeakerDiarizationConfig,
16 OfflineSpeakerDiarizationResult, 21 OfflineSpeakerDiarizationResult,