Fangjun Kuang
Committed by GitHub

Add C++ runtime for spleeter about source separation (#2242)

@@ -3,7 +3,7 @@ name: export-spleeter-to-onnx @@ -3,7 +3,7 @@ name: export-spleeter-to-onnx
3 on: 3 on:
4 push: 4 push:
5 branches: 5 branches:
6 - - spleeter-2 6 + - spleeter-cpp-2
7 workflow_dispatch: 7 workflow_dispatch:
8 8
9 concurrency: 9 concurrency:
@@ -56,6 +56,7 @@ def get_binaries(): @@ -56,6 +56,7 @@ def get_binaries():
56 "sherpa-onnx-offline-denoiser", 56 "sherpa-onnx-offline-denoiser",
57 "sherpa-onnx-offline-language-identification", 57 "sherpa-onnx-offline-language-identification",
58 "sherpa-onnx-offline-punctuation", 58 "sherpa-onnx-offline-punctuation",
  59 + "sherpa-onnx-offline-source-separation",
59 "sherpa-onnx-offline-speaker-diarization", 60 "sherpa-onnx-offline-speaker-diarization",
60 "sherpa-onnx-offline-tts", 61 "sherpa-onnx-offline-tts",
61 "sherpa-onnx-offline-tts-play", 62 "sherpa-onnx-offline-tts-play",
@@ -217,8 +217,8 @@ def main(name): @@ -217,8 +217,8 @@ def main(name):
217 # for the batchnormalization in torch, 217 # for the batchnormalization in torch,
218 # default input shape is NCHW 218 # default input shape is NCHW
219 219
220 - # NHWC to NCHW  
221 - torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2)) 220 + torch_y1_out = unet(torch.from_numpy(y0_out).permute(3, 0, 1, 2))
  221 + torch_y1_out = torch_y1_out.permute(1, 0, 2, 3)
222 222
223 # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape) 223 # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
224 assert torch.allclose( 224 assert torch.allclose(
@@ -46,7 +46,7 @@ def add_meta_data(filename, prefix): @@ -46,7 +46,7 @@ def add_meta_data(filename, prefix):
46 46
47 def export(model, prefix): 47 def export(model, prefix):
48 num_splits = 1 48 num_splits = 1
49 - x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32) 49 + x = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32)
50 50
51 filename = f"./2stems/{prefix}.onnx" 51 filename = f"./2stems/{prefix}.onnx"
52 torch.onnx.export( 52 torch.onnx.export(
@@ -56,7 +56,7 @@ def export(model, prefix): @@ -56,7 +56,7 @@ def export(model, prefix):
56 input_names=["x"], 56 input_names=["x"],
57 output_names=["y"], 57 output_names=["y"],
58 dynamic_axes={ 58 dynamic_axes={
59 - "x": {0: "num_splits"}, 59 + "x": {1: "num_splits"},
60 }, 60 },
61 opset_version=13, 61 opset_version=13,
62 ) 62 )
@@ -101,13 +101,17 @@ def main(): @@ -101,13 +101,17 @@ def main():
101 print("y2", y.shape, y.dtype) 101 print("y2", y.shape, y.dtype)
102 102
103 y = y.abs() 103 y = y.abs()
104 - y = y.permute(0, 3, 1, 2)  
105 - # (1, 2, 512, 1024) 104 +
  105 + y = y.permute(3, 0, 1, 2)
  106 + # (2, 1, 512, 1024)
106 print("y3", y.shape, y.dtype) 107 print("y3", y.shape, y.dtype)
107 108
108 vocals_spec = vocals(y) 109 vocals_spec = vocals(y)
109 accompaniment_spec = accompaniment(y) 110 accompaniment_spec = accompaniment(y)
110 111
  112 + vocals_spec = vocals_spec.permute(1, 0, 2, 3)
  113 + accompaniment_spec = accompaniment_spec.permute(1, 0, 2, 3)
  114 +
111 sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10 115 sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
112 print( 116 print(
113 "vocals_spec", 117 "vocals_spec",
@@ -12,15 +12,14 @@ from separate import load_audio @@ -12,15 +12,14 @@ from separate import load_audio
12 12
13 """ 13 """
14 ----------inputs for ./2stems/vocals.onnx---------- 14 ----------inputs for ./2stems/vocals.onnx----------
15 -NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024]) 15 +NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
16 ----------outputs for ./2stems/vocals.onnx---------- 16 ----------outputs for ./2stems/vocals.onnx----------
17 -NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024]) 17 +NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
18 18
19 ----------inputs for ./2stems/accompaniment.onnx---------- 19 ----------inputs for ./2stems/accompaniment.onnx----------
20 -NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024]) 20 +NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
21 ----------outputs for ./2stems/accompaniment.onnx---------- 21 ----------outputs for ./2stems/accompaniment.onnx----------
22 -NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])  
23 - 22 +NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
24 """ 23 """
25 24
26 25
@@ -123,16 +122,16 @@ def main(): @@ -123,16 +122,16 @@ def main():
123 if padding > 0: 122 if padding > 0:
124 stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding)) 123 stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
125 stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding)) 124 stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
126 - stft0 = stft0.reshape(-1, 1, 512, 1024)  
127 - stft1 = stft1.reshape(-1, 1, 512, 1024) 125 + stft0 = stft0.reshape(1, -1, 512, 1024)
  126 + stft1 = stft1.reshape(1, -1, 512, 1024)
128 127
129 - stft_01 = torch.cat([stft0, stft1], axis=1) 128 + stft_01 = torch.cat([stft0, stft1], axis=0)
130 129
131 print("stft_01", stft_01.shape, stft_01.dtype) 130 print("stft_01", stft_01.shape, stft_01.dtype)
132 131
133 vocals_spec = vocals(stft_01) 132 vocals_spec = vocals(stft_01)
134 accompaniment_spec = accompaniment(stft_01) 133 accompaniment_spec = accompaniment(stft_01)
135 - # (num_splits, num_channels, 512, 1024) 134 + # (num_channels, num_splits, 512, 1024)
136 135
137 sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10 136 sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10
138 137
@@ -142,8 +141,8 @@ def main(): @@ -142,8 +141,8 @@ def main():
142 for name, spec in zip( 141 for name, spec in zip(
143 ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec] 142 ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
144 ): 143 ):
145 - spec_c0 = spec[:, 0, :, :]  
146 - spec_c1 = spec[:, 1, :, :] 144 + spec_c0 = spec[0]
  145 + spec_c1 = spec[1]
147 146
148 spec_c0 = spec_c0.reshape(-1, 1024) 147 spec_c0 = spec_c0.reshape(-1, 1024)
149 spec_c1 = spec_c1.reshape(-1, 1024) 148 spec_c1 = spec_c1.reshape(-1, 1024)
@@ -67,6 +67,14 @@ class UNet(torch.nn.Module): @@ -67,6 +67,14 @@ class UNet(torch.nn.Module):
67 self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3) 67 self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
68 68
69 def forward(self, x): 69 def forward(self, x):
  70 + """
  71 + Args:
  72 + x: (num_audio_channels, num_splits, 512, 1024)
  73 + Returns:
  74 + y: (num_audio_channels, num_splits, 512, 1024)
  75 + """
  76 + x = x.permute(1, 0, 2, 3)
  77 +
70 in_x = x 78 in_x = x
71 # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024) 79 # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
72 x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0) 80 x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
@@ -147,4 +155,5 @@ class UNet(torch.nn.Module): @@ -147,4 +155,5 @@ class UNet(torch.nn.Module):
147 up7 = self.up7(batch12) 155 up7 = self.up7(batch12)
148 up7 = torch.sigmoid(up7) # (3, 2, 512, 1024) 156 up7 = torch.sigmoid(up7) # (3, 2, 512, 1024)
149 157
150 - return up7 * in_x 158 + ans = up7 * in_x
  159 + return ans.permute(1, 0, 2, 3)
@@ -50,6 +50,13 @@ set(sources @@ -50,6 +50,13 @@ set(sources
50 offline-rnn-lm.cc 50 offline-rnn-lm.cc
51 offline-sense-voice-model-config.cc 51 offline-sense-voice-model-config.cc
52 offline-sense-voice-model.cc 52 offline-sense-voice-model.cc
  53 +
  54 + offline-source-separation-impl.cc
  55 + offline-source-separation-model-config.cc
  56 + offline-source-separation-spleeter-model-config.cc
  57 + offline-source-separation-spleeter-model.cc
  58 + offline-source-separation.cc
  59 +
53 offline-stream.cc 60 offline-stream.cc
54 offline-tdnn-ctc-model.cc 61 offline-tdnn-ctc-model.cc
55 offline-tdnn-model-config.cc 62 offline-tdnn-model-config.cc
@@ -326,6 +333,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -326,6 +333,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
326 add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) 333 add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
327 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) 334 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
328 add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) 335 add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
  336 + add_executable(sherpa-onnx-offline-source-separation sherpa-onnx-offline-source-separation.cc)
329 add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc) 337 add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
330 add_executable(sherpa-onnx-vad sherpa-onnx-vad.cc) 338 add_executable(sherpa-onnx-vad sherpa-onnx-vad.cc)
331 339
@@ -346,6 +354,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -346,6 +354,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
346 sherpa-onnx-offline-language-identification 354 sherpa-onnx-offline-language-identification
347 sherpa-onnx-offline-parallel 355 sherpa-onnx-offline-parallel
348 sherpa-onnx-offline-punctuation 356 sherpa-onnx-offline-punctuation
  357 + sherpa-onnx-offline-source-separation
349 sherpa-onnx-online-punctuation 358 sherpa-onnx-online-punctuation
350 sherpa-onnx-vad 359 sherpa-onnx-vad
351 ) 360 )
  1 +// sherpa-onnx/csrc/offline-source-separation-impl.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
  6 +
  7 +#include <memory>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +std::unique_ptr<OfflineSourceSeparationImpl>
  14 +OfflineSourceSeparationImpl::Create(
  15 + const OfflineSourceSeparationConfig &config) {
  16 + // TODO(fangjun): Support other models
  17 + return std::make_unique<OfflineSourceSeparationSpleeterImpl>(config);
  18 +}
  19 +
  20 +template <typename Manager>
  21 +std::unique_ptr<OfflineSourceSeparationImpl>
  22 +OfflineSourceSeparationImpl::Create(
  23 + Manager *mgr, const OfflineSourceSeparationConfig &config) {
  24 + // TODO(fangjun): Support other models
  25 + return std::make_unique<OfflineSourceSeparationSpleeterImpl>(mgr, config);
  26 +}
  27 +
  28 +#if __ANDROID_API__ >= 9
  29 +template std::unique_ptr<OfflineSourceSeparationImpl>
  30 +OfflineSourceSeparationImpl::Create(
  31 + AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
  32 +#endif
  33 +
  34 +#if __OHOS__
  35 +template std::unique_ptr<OfflineSourceSeparationImpl>
  36 +OfflineSourceSeparationImpl::Create(
  37 + NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
  38 +#endif
  39 +
  40 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-source-separation-impl.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-source-separation.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class OfflineSourceSeparationImpl {
  15 + public:
  16 + static std::unique_ptr<OfflineSourceSeparationImpl> Create(
  17 + const OfflineSourceSeparationConfig &config);
  18 +
  19 + template <typename Manager>
  20 + static std::unique_ptr<OfflineSourceSeparationImpl> Create(
  21 + Manager *mgr, const OfflineSourceSeparationConfig &config);
  22 +
  23 + virtual ~OfflineSourceSeparationImpl() = default;
  24 +
  25 + virtual OfflineSourceSeparationOutput Process(
  26 + const OfflineSourceSeparationInput &input) const = 0;
  27 +
  28 + virtual int32_t GetOutputSampleRate() const = 0;
  29 +
  30 + virtual int32_t GetNumberOfStems() const = 0;
  31 +};
  32 +
  33 +} // namespace sherpa_onnx
  34 +
  35 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-source-separation-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
  6 +
  7 +namespace sherpa_onnx {
  8 +
  9 +void OfflineSourceSeparationModelConfig::Register(ParseOptions *po) {
  10 + spleeter.Register(po);
  11 +
  12 + po->Register("num-threads", &num_threads,
  13 + "Number of threads to run the neural network");
  14 +
  15 + po->Register("debug", &debug,
  16 + "true to print model information while loading it.");
  17 +
  18 + po->Register("provider", &provider,
  19 + "Specify a provider to use: cpu, cuda, coreml");
  20 +}
  21 +
  22 +bool OfflineSourceSeparationModelConfig::Validate() const {
  23 + return spleeter.Validate();
  24 +}
  25 +
  26 +std::string OfflineSourceSeparationModelConfig::ToString() const {
  27 + std::ostringstream os;
  28 +
  29 + os << "OfflineSourceSeparationModelConfig(";
  30 + os << "spleeter=" << spleeter.ToString() << ", ";
  31 + os << "num_threads=" << num_threads << ", ";
  32 + os << "debug=" << (debug ? "True" : "False") << ", ";
  33 + os << "provider=\"" << provider << "\")";
  34 +
  35 + return os.str();
  36 +}
  37 +
  38 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-source-separation-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
  7 +
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
  11 +#include "sherpa-onnx/csrc/parse-options.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +struct OfflineSourceSeparationModelConfig {
  16 + OfflineSourceSeparationSpleeterModelConfig spleeter;
  17 +
  18 + int32_t num_threads = 1;
  19 + bool debug = false;
  20 + std::string provider = "cpu";
  21 +
  22 + OfflineSourceSeparationModelConfig() = default;
  23 +
  24 + OfflineSourceSeparationModelConfig(
  25 + const OfflineSourceSeparationSpleeterModelConfig &spleeter,
  26 + int32_t num_threads, bool debug, const std::string &provider)
  27 + : spleeter(spleeter),
  28 + num_threads(num_threads),
  29 + debug(debug),
  30 + provider(provider) {}
  31 +
  32 + void Register(ParseOptions *po);
  33 +
  34 + bool Validate() const;
  35 +
  36 + std::string ToString() const;
  37 +};
  38 +
  39 +} // namespace sherpa_onnx
  40 +
  41 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
  7 +
#include <algorithm>
#include <array>
#include <memory>
#include <utility>
#include <vector>

#include "Eigen/Dense"
#include "kaldi-native-fbank/csrc/istft.h"
#include "kaldi-native-fbank/csrc/stft.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"
#include "sherpa-onnx/csrc/offline-source-separation.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/resample.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
  19 +class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl {
  20 + public:
  21 + OfflineSourceSeparationSpleeterImpl(
  22 + const OfflineSourceSeparationConfig &config)
  23 + : config_(config), model_(config_.model) {}
  24 +
  25 + template <typename Manager>
  26 + OfflineSourceSeparationSpleeterImpl(
  27 + Manager *mgr, const OfflineSourceSeparationConfig &config)
  28 + : config_(config), model_(mgr, config_.model) {}
  29 +
  30 + OfflineSourceSeparationOutput Process(
  31 + const OfflineSourceSeparationInput &input) const override {
  32 + const OfflineSourceSeparationInput *p_input = &input;
  33 + OfflineSourceSeparationInput tmp_input;
  34 +
  35 + int32_t output_sample_rate = GetOutputSampleRate();
  36 +
  37 + if (input.sample_rate != output_sample_rate) {
  38 + SHERPA_ONNX_LOGE(
  39 + "Creating a resampler:\n"
  40 + " in_sample_rate: %d\n"
  41 + " output_sample_rate: %d\n",
  42 + input.sample_rate, output_sample_rate);
  43 +
  44 + float min_freq = std::min<int32_t>(input.sample_rate, output_sample_rate);
  45 + float lowpass_cutoff = 0.99 * 0.5 * min_freq;
  46 +
  47 + int32_t lowpass_filter_width = 6;
  48 + auto resampler = std::make_unique<LinearResample>(
  49 + input.sample_rate, output_sample_rate, lowpass_cutoff,
  50 + lowpass_filter_width);
  51 +
  52 + std::vector<float> s;
  53 + for (const auto &samples : input.samples.data) {
  54 + resampler->Reset();
  55 + resampler->Resample(samples.data(), samples.size(), true, &s);
  56 + tmp_input.samples.data.push_back(std::move(s));
  57 + }
  58 +
  59 + tmp_input.sample_rate = output_sample_rate;
  60 + p_input = &tmp_input;
  61 + }
  62 +
  63 + if (p_input->samples.data.size() > 1) {
  64 + if (config_.model.debug) {
  65 + SHERPA_ONNX_LOGE("input ch1 samples size: %d",
  66 + static_cast<int32_t>(p_input->samples.data[1].size()));
  67 + }
  68 +
  69 + if (p_input->samples.data[0].size() != p_input->samples.data[1].size()) {
  70 + SHERPA_ONNX_LOGE("ch0 samples size %d vs ch1 samples size %d",
  71 + static_cast<int32_t>(p_input->samples.data[0].size()),
  72 + static_cast<int32_t>(p_input->samples.data[1].size()));
  73 +
  74 + SHERPA_ONNX_EXIT(-1);
  75 + }
  76 + }
  77 +
  78 + auto stft_ch0 = ComputeStft(*p_input, 0);
  79 +
  80 + auto stft_ch1 = ComputeStft(*p_input, 1);
  81 + knf::StftResult *p_stft_ch1 = stft_ch1.real.empty() ? &stft_ch0 : &stft_ch1;
  82 +
  83 + int32_t num_frames = stft_ch0.num_frames;
  84 + int32_t fft_bins = stft_ch0.real.size() / num_frames;
  85 +
  86 + int32_t pad = 512 - (stft_ch0.num_frames % 512);
  87 + if (pad < 512) {
  88 + num_frames += pad;
  89 + }
  90 +
  91 + if (num_frames % 512) {
  92 + SHERPA_ONNX_LOGE("num_frames should be multiple of 512, actual: %d. %d",
  93 + num_frames, num_frames % 512);
  94 + SHERPA_ONNX_EXIT(-1);
  95 + }
  96 +
  97 + Eigen::VectorXf real(2 * num_frames * 1024);
  98 + Eigen::VectorXf imag(2 * num_frames * 1024);
  99 + real.setZero();
  100 + imag.setZero();
  101 +
  102 + float *p_real = &real[0];
  103 + float *p_imag = &imag[0];
  104 +
  105 + // copy stft result of channel 0
  106 + for (int32_t i = 0; i != stft_ch0.num_frames; ++i) {
  107 + std::copy(stft_ch0.real.data() + i * fft_bins,
  108 + stft_ch0.real.data() + i * fft_bins + 1024, p_real + 1024 * i);
  109 +
  110 + std::copy(stft_ch0.imag.data() + i * fft_bins,
  111 + stft_ch0.imag.data() + i * fft_bins + 1024, p_imag + 1024 * i);
  112 + }
  113 +
  114 + p_real += num_frames * 1024;
  115 + p_imag += num_frames * 1024;
  116 +
  117 + // copy stft result of channel 1
  118 + for (int32_t i = 0; i != stft_ch1.num_frames; ++i) {
  119 + std::copy(p_stft_ch1->real.data() + i * fft_bins,
  120 + p_stft_ch1->real.data() + i * fft_bins + 1024,
  121 + p_real + 1024 * i);
  122 +
  123 + std::copy(p_stft_ch1->imag.data() + i * fft_bins,
  124 + p_stft_ch1->imag.data() + i * fft_bins + 1024,
  125 + p_imag + 1024 * i);
  126 + }
  127 +
  128 + Eigen::VectorXf x = (real.array().square() + imag.array().square()).sqrt();
  129 +
  130 + auto memory_info =
  131 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  132 +
  133 + std::array<int64_t, 4> x_shape{2, num_frames / 512, 512, 1024};
  134 + Ort::Value x_tensor = Ort::Value::CreateTensor(
  135 + memory_info, &x[0], x.size(), x_shape.data(), x_shape.size());
  136 +
  137 + Ort::Value vocals_spec_tensor = model_.RunVocals(View(&x_tensor));
  138 + Ort::Value accompaniment_spec_tensor =
  139 + model_.RunAccompaniment(std::move(x_tensor));
  140 +
  141 + Eigen::VectorXf vocals_spec = Eigen::Map<Eigen::VectorXf>(
  142 + vocals_spec_tensor.GetTensorMutableData<float>(), x.size());
  143 +
  144 + Eigen::VectorXf accompaniment_spec = Eigen::Map<Eigen::VectorXf>(
  145 + accompaniment_spec_tensor.GetTensorMutableData<float>(), x.size());
  146 +
  147 + Eigen::VectorXf sum_spec = vocals_spec.array().square() +
  148 + accompaniment_spec.array().square() + 1e-10;
  149 +
  150 + vocals_spec = (vocals_spec.array().square() + 1e-10 / 2) / sum_spec.array();
  151 +
  152 + accompaniment_spec =
  153 + (accompaniment_spec.array().square() + 1e-10 / 2) / sum_spec.array();
  154 +
  155 + auto vocals_samples_ch0 = ProcessSpec(vocals_spec, stft_ch0, 0);
  156 + auto vocals_samples_ch1 = ProcessSpec(vocals_spec, *p_stft_ch1, 1);
  157 +
  158 + auto accompaniment_samples_ch0 =
  159 + ProcessSpec(accompaniment_spec, stft_ch0, 0);
  160 + auto accompaniment_samples_ch1 =
  161 + ProcessSpec(accompaniment_spec, *p_stft_ch1, 1);
  162 +
  163 + OfflineSourceSeparationOutput ans;
  164 + ans.sample_rate = GetOutputSampleRate();
  165 +
  166 + ans.stems.resize(2);
  167 + ans.stems[0].data.reserve(2);
  168 + ans.stems[1].data.reserve(2);
  169 +
  170 + ans.stems[0].data.push_back(std::move(vocals_samples_ch0));
  171 + ans.stems[0].data.push_back(std::move(vocals_samples_ch1));
  172 +
  173 + ans.stems[1].data.push_back(std::move(accompaniment_samples_ch0));
  174 + ans.stems[1].data.push_back(std::move(accompaniment_samples_ch1));
  175 +
  176 + return ans;
  177 + }
  178 +
  179 + int32_t GetOutputSampleRate() const override {
  180 + return model_.GetMetaData().sample_rate;
  181 + }
  182 +
  183 + int32_t GetNumberOfStems() const override {
  184 + return model_.GetMetaData().num_stems;
  185 + }
  186 +
  187 + private:
  188 + // spec is of shape (2, num_chunks, 512, 1024)
  189 + std::vector<float> ProcessSpec(const Eigen::VectorXf &spec,
  190 + const knf::StftResult &stft,
  191 + int32_t channel) const {
  192 + int32_t fft_bins = stft.real.size() / stft.num_frames;
  193 +
  194 + Eigen::VectorXf mask(stft.real.size());
  195 + mask.setZero();
  196 +
  197 + float *p_mask = &mask[0];
  198 +
  199 + // assume there are 2 channels
  200 + const float *p_spec = &spec[0] + (spec.size() / 2) * channel;
  201 +
  202 + for (int32_t i = 0; i != stft.num_frames; ++i) {
  203 + std::copy(p_spec + i * 1024, p_spec + (i + 1) * 1024,
  204 + p_mask + i * fft_bins);
  205 + }
  206 +
  207 + knf::StftResult masked_stft;
  208 +
  209 + masked_stft.num_frames = stft.num_frames;
  210 + masked_stft.real.resize(stft.real.size());
  211 + masked_stft.imag.resize(stft.imag.size());
  212 +
  213 + Eigen::Map<Eigen::VectorXf>(masked_stft.real.data(),
  214 + masked_stft.real.size()) =
  215 + mask.array() *
  216 + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.real.data()),
  217 + stft.real.size())
  218 + .array();
  219 +
  220 + Eigen::Map<Eigen::VectorXf>(masked_stft.imag.data(),
  221 + masked_stft.imag.size()) =
  222 + mask.array() *
  223 + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.imag.data()),
  224 + stft.imag.size())
  225 + .array();
  226 +
  227 + auto stft_config = GetStftConfig();
  228 + knf::IStft istft(stft_config);
  229 +
  230 + return istft.Compute(masked_stft);
  231 + }
  232 +
  233 + knf::StftResult ComputeStft(const OfflineSourceSeparationInput &input,
  234 + int32_t ch) const {
  235 + if (ch >= input.samples.data.size()) {
  236 + SHERPA_ONNX_LOGE("Invalid channel %d. Max %d", ch,
  237 + static_cast<int32_t>(input.samples.data.size()));
  238 + SHERPA_ONNX_EXIT(-1);
  239 + }
  240 +
  241 + if (input.samples.data[ch].empty()) {
  242 + return {};
  243 + }
  244 +
  245 + return ComputeStft(input.samples.data[ch]);
  246 + }
  247 +
  248 + knf::StftResult ComputeStft(const std::vector<float> &samples) const {
  249 + auto stft_config = GetStftConfig();
  250 + knf::Stft stft(stft_config);
  251 +
  252 + return stft.Compute(samples.data(), samples.size());
  253 + }
  254 +
  255 + knf::StftConfig GetStftConfig() const {
  256 + const auto &meta = model_.GetMetaData();
  257 +
  258 + knf::StftConfig stft_config;
  259 + stft_config.n_fft = meta.n_fft;
  260 + stft_config.hop_length = meta.hop_length;
  261 + stft_config.win_length = meta.window_length;
  262 + stft_config.window_type = meta.window_type;
  263 + stft_config.center = meta.center;
  264 + stft_config.center = false;
  265 +
  266 + return stft_config;
  267 + }
  268 +
  269 + private:
  270 + OfflineSourceSeparationConfig config_;
  271 + OfflineSourceSeparationSpleeterModel model_;
  272 +};
  273 +
  274 +} // namespace sherpa_onnx
  275 +
  276 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflineSourceSeparationSpleeterModelConfig::Register(ParseOptions *po) {
  13 + po->Register("spleeter-vocals", &vocals, "Path to the spleeter vocals model");
  14 +
  15 + po->Register("spleeter-accompaniment", &accompaniment,
  16 + "Path to the spleeter accompaniment model");
  17 +}
  18 +
  19 +bool OfflineSourceSeparationSpleeterModelConfig::Validate() const {
  20 + if (vocals.empty()) {
  21 + SHERPA_ONNX_LOGE("Please provide --spleeter-vocals");
  22 + return false;
  23 + }
  24 +
  25 + if (!FileExists(vocals)) {
  26 + SHERPA_ONNX_LOGE("spleeter vocals '%s' does not exist. ", vocals.c_str());
  27 + return false;
  28 + }
  29 +
  30 + if (accompaniment.empty()) {
  31 + SHERPA_ONNX_LOGE("Please provide --spleeter-accompaniment");
  32 + return false;
  33 + }
  34 +
  35 + if (!FileExists(accompaniment)) {
  36 + SHERPA_ONNX_LOGE("spleeter accompaniment '%s' does not exist. ",
  37 + accompaniment.c_str());
  38 + return false;
  39 + }
  40 +
  41 + return true;
  42 +}
  43 +
  44 +std::string OfflineSourceSeparationSpleeterModelConfig::ToString() const {
  45 + std::ostringstream os;
  46 +
  47 + os << "OfflineSourceSeparationSpleeterModelConfig(";
  48 + os << "vocals=\"" << vocals << "\", ";
  49 + os << "accompaniment=\"" << accompaniment << "\")";
  50 +
  51 + return os.str();
  52 +}
  53 +
  54 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
  7 +
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
  11 +#include "sherpa-onnx/csrc/parse-options.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +struct OfflineSourceSeparationSpleeterModelConfig {
  16 + std::string vocals;
  17 +
  18 + std::string accompaniment;
  19 +
  20 + OfflineSourceSeparationSpleeterModelConfig() = default;
  21 +
  22 + OfflineSourceSeparationSpleeterModelConfig(const std::string &vocals,
  23 + const std::string &accompaniment)
  24 + : vocals(vocals), accompaniment(accompaniment) {}
  25 +
  26 + void Register(ParseOptions *po);
  27 +
  28 + bool Validate() const;
  29 +
  30 + std::string ToString() const;
  31 +};
  32 +
  33 +} // namespace sherpa_onnx
  34 +
  35 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
  6 +
  7 +#include <string>
  8 +#include <unordered_map>
  9 +#include <vector>
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +// See also
  14 +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/spleeter/separate_onnx.py
// Meta data of a spleeter model together with its default values.
struct OfflineSourceSeparationSpleeterModelMetaData {
  // Sample rate in Hz.
  int32_t sample_rate{44100};

  // Number of stems.
  int32_t num_stems{2};

  // STFT parameters.
  int32_t n_fft{4096};
  int32_t hop_length{1024};
  int32_t window_length{4096};
  bool center{false};
  std::string window_type = "hann";
};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
  1 +// sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
  17 +#if __OHOS__
  18 +#include "rawfile/raw_file_manager.h"
  19 +#endif
  20 +
  21 +#include "sherpa-onnx/csrc/file-utils.h"
  22 +#include "sherpa-onnx/csrc/onnx-utils.h"
  23 +#include "sherpa-onnx/csrc/session.h"
  24 +#include "sherpa-onnx/csrc/text-utils.h"
  25 +
  26 +namespace sherpa_onnx {
  27 +
  28 +class OfflineSourceSeparationSpleeterModel::Impl {
  29 + public:
  30 + explicit Impl(const OfflineSourceSeparationModelConfig &config)
  31 + : config_(config),
  32 + env_(ORT_LOGGING_LEVEL_ERROR),
  33 + sess_opts_(GetSessionOptions(config)),
  34 + allocator_{} {
  35 + {
  36 + auto buf = ReadFile(config.spleeter.vocals);
  37 + InitVocals(buf.data(), buf.size());
  38 + }
  39 +
  40 + {
  41 + auto buf = ReadFile(config.spleeter.accompaniment);
  42 + InitAccompaniment(buf.data(), buf.size());
  43 + }
  44 + }
  45 +
  46 + template <typename Manager>
  47 + Impl(Manager *mgr, const OfflineSourceSeparationModelConfig &config)
  48 + : config_(config),
  49 + env_(ORT_LOGGING_LEVEL_ERROR),
  50 + sess_opts_(GetSessionOptions(config)),
  51 + allocator_{} {
  52 + {
  53 + auto buf = ReadFile(mgr, config.spleeter.vocals);
  54 + InitVocals(buf.data(), buf.size());
  55 + }
  56 +
  57 + {
  58 + auto buf = ReadFile(mgr, config.spleeter.accompaniment);
  59 + InitAccompaniment(buf.data(), buf.size());
  60 + }
  61 + }
  62 +
  63 + const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const {
  64 + return meta_;
  65 + }
  66 +
  67 + Ort::Value RunVocals(Ort::Value x) const {
  68 + auto out = vocals_sess_->Run({}, vocals_input_names_ptr_.data(), &x, 1,
  69 + vocals_output_names_ptr_.data(),
  70 + vocals_output_names_ptr_.size());
  71 + return std::move(out[0]);
  72 + }
  73 +
  74 + Ort::Value RunAccompaniment(Ort::Value x) const {
  75 + auto out =
  76 + accompaniment_sess_->Run({}, accompaniment_input_names_ptr_.data(), &x,
  77 + 1, accompaniment_output_names_ptr_.data(),
  78 + accompaniment_output_names_ptr_.size());
  79 + return std::move(out[0]);
  80 + }
  81 +
  82 + private:
  83 + void InitVocals(void *model_data, size_t model_data_length) {
  84 + vocals_sess_ = std::make_unique<Ort::Session>(
  85 + env_, model_data, model_data_length, sess_opts_);
  86 +
  87 + GetInputNames(vocals_sess_.get(), &vocals_input_names_,
  88 + &vocals_input_names_ptr_);
  89 +
  90 + GetOutputNames(vocals_sess_.get(), &vocals_output_names_,
  91 + &vocals_output_names_ptr_);
  92 +
  93 + Ort::ModelMetadata meta_data = vocals_sess_->GetModelMetadata();
  94 + if (config_.debug) {
  95 + std::ostringstream os;
  96 + os << "---vocals model---\n";
  97 + PrintModelMetadata(os, meta_data);
  98 +
  99 + os << "----------input names----------\n";
  100 + int32_t i = 0;
  101 + for (const auto &s : vocals_input_names_) {
  102 + os << i << " " << s << "\n";
  103 + ++i;
  104 + }
  105 + os << "----------output names----------\n";
  106 + i = 0;
  107 + for (const auto &s : vocals_output_names_) {
  108 + os << i << " " << s << "\n";
  109 + ++i;
  110 + }
  111 +
  112 +#if __OHOS__
  113 + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
  114 +#else
  115 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  116 +#endif
  117 + }
  118 +
  119 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  120 +
  121 + std::string model_type;
  122 + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
  123 + if (model_type != "spleeter") {
  124 + SHERPA_ONNX_LOGE("Expect model type 'spleeter'. Given: '%s'",
  125 + model_type.c_str());
  126 + SHERPA_ONNX_EXIT(-1);
  127 + }
  128 +
  129 + SHERPA_ONNX_READ_META_DATA(meta_.num_stems, "stems");
  130 + if (meta_.num_stems != 2) {
  131 + SHERPA_ONNX_LOGE("Only 2stems is supported. Given %d stems",
  132 + meta_.num_stems);
  133 + SHERPA_ONNX_EXIT(-1);
  134 + }
  135 + }
  136 +
  137 + void InitAccompaniment(void *model_data, size_t model_data_length) {
  138 + accompaniment_sess_ = std::make_unique<Ort::Session>(
  139 + env_, model_data, model_data_length, sess_opts_);
  140 +
  141 + GetInputNames(accompaniment_sess_.get(), &accompaniment_input_names_,
  142 + &accompaniment_input_names_ptr_);
  143 +
  144 + GetOutputNames(accompaniment_sess_.get(), &accompaniment_output_names_,
  145 + &accompaniment_output_names_ptr_);
  146 + }
  147 +
  148 + private:
  149 + OfflineSourceSeparationModelConfig config_;
  150 + OfflineSourceSeparationSpleeterModelMetaData meta_;
  151 +
  152 + Ort::Env env_;
  153 + Ort::SessionOptions sess_opts_;
  154 + Ort::AllocatorWithDefaultOptions allocator_;
  155 +
  156 + std::unique_ptr<Ort::Session> vocals_sess_;
  157 +
  158 + std::vector<std::string> vocals_input_names_;
  159 + std::vector<const char *> vocals_input_names_ptr_;
  160 +
  161 + std::vector<std::string> vocals_output_names_;
  162 + std::vector<const char *> vocals_output_names_ptr_;
  163 +
  164 + std::unique_ptr<Ort::Session> accompaniment_sess_;
  165 +
  166 + std::vector<std::string> accompaniment_input_names_;
  167 + std::vector<const char *> accompaniment_input_names_ptr_;
  168 +
  169 + std::vector<std::string> accompaniment_output_names_;
  170 + std::vector<const char *> accompaniment_output_names_ptr_;
  171 +};
  172 +
  173 +OfflineSourceSeparationSpleeterModel::~OfflineSourceSeparationSpleeterModel() =
  174 + default;
  175 +
  176 +OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel(
  177 + const OfflineSourceSeparationModelConfig &config)
  178 + : impl_(std::make_unique<Impl>(config)) {}
  179 +
  180 +template <typename Manager>
  181 +OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel(
  182 + Manager *mgr, const OfflineSourceSeparationModelConfig &config)
  183 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  184 +
  185 +Ort::Value OfflineSourceSeparationSpleeterModel::RunVocals(Ort::Value x) const {
  186 + return impl_->RunVocals(std::move(x));
  187 +}
  188 +
  189 +Ort::Value OfflineSourceSeparationSpleeterModel::RunAccompaniment(
  190 + Ort::Value x) const {
  191 + return impl_->RunAccompaniment(std::move(x));
  192 +}
  193 +
  194 +const OfflineSourceSeparationSpleeterModelMetaData &
  195 +OfflineSourceSeparationSpleeterModel::GetMetaData() const {
  196 + return impl_->GetMetaData();
  197 +}
  198 +
  199 +#if __ANDROID_API__ >= 9
  200 +template OfflineSourceSeparationSpleeterModel::
  201 + OfflineSourceSeparationSpleeterModel(
  202 + AAssetManager *mgr, const OfflineSourceSeparationModelConfig &config);
  203 +#endif
  204 +
  205 +#if __OHOS__
  206 +template OfflineSourceSeparationSpleeterModel::
  207 + OfflineSourceSeparationSpleeterModel(
  208 + NativeResourceManager *mgr,
  209 + const OfflineSourceSeparationModelConfig &config);
  210 +#endif
  211 +
  212 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
  6 +#include <memory>
  7 +
  8 +#include "onnxruntime_cxx_api.h" // NOLINT
  9 +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
  10 +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class OfflineSourceSeparationSpleeterModel {
  15 + public:
  16 + ~OfflineSourceSeparationSpleeterModel();
  17 +
  18 + explicit OfflineSourceSeparationSpleeterModel(
  19 + const OfflineSourceSeparationModelConfig &config);
  20 +
  21 + template <typename Manager>
  22 + OfflineSourceSeparationSpleeterModel(
  23 + Manager *mgr, const OfflineSourceSeparationModelConfig &config);
  24 +
  25 + Ort::Value RunVocals(Ort::Value x) const;
  26 + Ort::Value RunAccompaniment(Ort::Value x) const;
  27 +
  28 + const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const;
  29 +
  30 + private:
  31 + class Impl;
  32 + std::unique_ptr<Impl> impl_;
  33 +};
  34 +
  35 +} // namespace sherpa_onnx
  36 +
  37 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
  1 +// sherpa-onnx/csrc/offline-source-separation.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-source-separation.h"
  6 +
  7 +#include <memory>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
  10 +
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
  16 +#if __OHOS__
  17 +#include "rawfile/raw_file_manager.h"
  18 +#endif
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +void OfflineSourceSeparationConfig::Register(ParseOptions *po) {
  23 + model.Register(po);
  24 +}
  25 +
  26 +bool OfflineSourceSeparationConfig::Validate() const {
  27 + return model.Validate();
  28 +}
  29 +
  30 +std::string OfflineSourceSeparationConfig::ToString() const {
  31 + std::ostringstream os;
  32 +
  33 + os << "OfflineSourceSeparationConfig(";
  34 + os << "model=" << model.ToString() << ")";
  35 +
  36 + return os.str();
  37 +}
  38 +
  39 +template <typename Manager>
  40 +OfflineSourceSeparation::OfflineSourceSeparation(
  41 + Manager *mgr, const OfflineSourceSeparationConfig &config)
  42 + : impl_(OfflineSourceSeparationImpl::Create(mgr, config)) {}
  43 +
  44 +OfflineSourceSeparation::OfflineSourceSeparation(
  45 + const OfflineSourceSeparationConfig &config)
  46 + : impl_(OfflineSourceSeparationImpl::Create(config)) {}
  47 +
  48 +OfflineSourceSeparation::~OfflineSourceSeparation() = default;
  49 +
  50 +OfflineSourceSeparationOutput OfflineSourceSeparation::Process(
  51 + const OfflineSourceSeparationInput &input) const {
  52 + return impl_->Process(input);
  53 +}
  54 +
  55 +int32_t OfflineSourceSeparation::GetOutputSampleRate() const {
  56 + return impl_->GetOutputSampleRate();
  57 +}
  58 +
  59 +// e.g., it is 2 for 2stems from spleeter
  60 +int32_t OfflineSourceSeparation::GetNumberOfStems() const {
  61 + return impl_->GetNumberOfStems();
  62 +}
  63 +
  64 +#if __ANDROID_API__ >= 9
  65 +template OfflineSourceSeparation::OfflineSourceSeparation(
  66 + AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
  67 +#endif
  68 +
  69 +#if __OHOS__
  70 +template OfflineSourceSeparation::OfflineSourceSeparation(
  71 + NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
  72 +#endif
  73 +
  74 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-source-separation.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
  13 +#include "sherpa-onnx/csrc/parse-options.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +struct OfflineSourceSeparationConfig {
  18 + OfflineSourceSeparationModelConfig model;
  19 +
  20 + OfflineSourceSeparationConfig() = default;
  21 +
  22 + OfflineSourceSeparationConfig(const OfflineSourceSeparationModelConfig &model)
  23 + : model(model) {}
  24 +
  25 + void Register(ParseOptions *po);
  26 +
  27 + bool Validate() const;
  28 +
  29 + std::string ToString() const;
  30 +};
  31 +
  32 +struct MultiChannelSamples {
  33 + // data[i] is for the i-th channel
  34 + //
  35 + // each sample is in the range [-1, 1]
  36 + std::vector<std::vector<float>> data;
  37 +};
  38 +
  39 +struct OfflineSourceSeparationInput {
  40 + MultiChannelSamples samples;
  41 +
  42 + int32_t sample_rate;
  43 +};
  44 +
  45 +struct OfflineSourceSeparationOutput {
  46 + std::vector<MultiChannelSamples> stems;
  47 +
  48 + int32_t sample_rate;
  49 +};
  50 +
  51 +class OfflineSourceSeparationImpl;
  52 +
  53 +class OfflineSourceSeparation {
  54 + public:
  55 + ~OfflineSourceSeparation();
  56 +
  57 + OfflineSourceSeparation(const OfflineSourceSeparationConfig &config);
  58 +
  59 + template <typename Manager>
  60 + OfflineSourceSeparation(Manager *mgr,
  61 + const OfflineSourceSeparationConfig &config);
  62 +
  63 + OfflineSourceSeparationOutput Process(
  64 + const OfflineSourceSeparationInput &input) const;
  65 +
  66 + int32_t GetOutputSampleRate() const;
  67 +
  68 + // e.g., it is 2 for 2stems from spleeter
  69 + int32_t GetNumberOfStems() const;
  70 +
  71 + private:
  72 + std::unique_ptr<OfflineSourceSeparationImpl> impl_;
  73 +};
  74 +
  75 +} // namespace sherpa_onnx
  76 +
  77 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
@@ -12,7 +12,7 @@ @@ -12,7 +12,7 @@
12 namespace sherpa_onnx { 12 namespace sherpa_onnx {
13 13
14 // please refer to 14 // please refer to
15 -// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py 15 +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/gtcrn/add_meta_data.py
16 struct OfflineSpeechDenoiserGtcrnModelMetaData { 16 struct OfflineSpeechDenoiserGtcrnModelMetaData {
17 int32_t sample_rate = 0; 17 int32_t sample_rate = 0;
18 int32_t version = 1; 18 int32_t version = 1;
@@ -11,7 +11,7 @@ @@ -11,7 +11,7 @@
11 11
12 int main(int32_t argc, char *argv[]) { 12 int main(int32_t argc, char *argv[]) {
13 const char *kUsageMessage = R"usage( 13 const char *kUsageMessage = R"usage(
14 -Non-stremaing speech denoising with sherpa-onnx. 14 +Non-streaming speech denoising with sherpa-onnx.
15 15
16 Please visit 16 Please visit
17 https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models 17 https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
  1 +// sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#include <stdio.h>
  5 +
  6 +#include <chrono> // NOLINT
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-source-separation.h"
  10 +#include "sherpa-onnx/csrc/wave-reader.h"
  11 +#include "sherpa-onnx/csrc/wave-writer.h"
  12 +
  13 +int main(int32_t argc, char *argv[]) {
  14 + const char *kUsageMessage = R"usage(
  15 +Non-streaming source separation with sherpa-onnx.
  16 +
  17 +Please visit
  18 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models
  19 +to download models.
  20 +
  21 +Usage:
  22 +
  23 +(1) Use spleeter models
  24 +
  25 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
  26 +tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2
  27 +
  28 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav
  29 +
  30 +./bin/sherpa-onnx-offline-source-separation \
  31 + --spleeter-vocals=sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx \
  32 + --spleeter-accompaniment=sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx \
  33 + --input-wav=audio_example.wav \
  34 + --output-vocals-wav=output_vocals.wav \
  35 + --output-accompaniment-wav=output_accompaniment.wav
  36 +)usage";
  37 +
  38 + sherpa_onnx::ParseOptions po(kUsageMessage);
  39 + sherpa_onnx::OfflineSourceSeparationConfig config;
  40 +
  41 + std::string input_wave;
  42 + std::string output_vocals_wave;
  43 + std::string output_accompaniment_wave;
  44 +
  45 + config.Register(&po);
  46 + po.Register("input-wav", &input_wave, "Path to input wav.");
  47 + po.Register("output-vocals-wav", &output_vocals_wave,
  48 + "Path to output vocals wav");
  49 + po.Register("output-accompaniment-wav", &output_accompaniment_wave,
  50 + "Path to output accompaniment wav");
  51 +
  52 + po.Read(argc, argv);
  53 + if (po.NumArgs() != 0) {
  54 + fprintf(stderr, "Please don't give positional arguments\n");
  55 + po.PrintUsage();
  56 + exit(EXIT_FAILURE);
  57 + }
  58 + fprintf(stderr, "%s\n", config.ToString().c_str());
  59 +
  60 + if (input_wave.empty()) {
  61 + fprintf(stderr, "Please provide --input-wav\n");
  62 + po.PrintUsage();
  63 + exit(EXIT_FAILURE);
  64 + }
  65 +
  66 + if (output_vocals_wave.empty()) {
  67 + fprintf(stderr, "Please provide --output-vocals-wav\n");
  68 + po.PrintUsage();
  69 + exit(EXIT_FAILURE);
  70 + }
  71 +
  72 + if (output_accompaniment_wave.empty()) {
  73 + fprintf(stderr, "Please provide --output-accompaniment-wav\n");
  74 + po.PrintUsage();
  75 + exit(EXIT_FAILURE);
  76 + }
  77 +
  78 + if (!config.Validate()) {
  79 + fprintf(stderr, "Errors in config!\n");
  80 + exit(EXIT_FAILURE);
  81 + }
  82 +
  83 + bool is_ok = false;
  84 + sherpa_onnx::OfflineSourceSeparationInput input;
  85 + input.samples.data =
  86 + sherpa_onnx::ReadWaveMultiChannel(input_wave, &input.sample_rate, &is_ok);
  87 + if (!is_ok) {
  88 + fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str());
  89 + return -1;
  90 + }
  91 +
  92 + fprintf(stderr, "Started\n");
  93 +
  94 + sherpa_onnx::OfflineSourceSeparation sp(config);
  95 +
  96 + const auto begin = std::chrono::steady_clock::now();
  97 + auto output = sp.Process(input);
  98 + const auto end = std::chrono::steady_clock::now();
  99 +
  100 + float elapsed_seconds =
  101 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  102 + .count() /
  103 + 1000.;
  104 +
  105 + is_ok = sherpa_onnx::WriteWave(
  106 + output_vocals_wave, output.sample_rate, output.stems[0].data[0].data(),
  107 + output.stems[0].data[1].data(), output.stems[0].data[0].size());
  108 +
  109 + if (!is_ok) {
  110 + fprintf(stderr, "Failed to write to '%s'\n", output_vocals_wave.c_str());
  111 + exit(EXIT_FAILURE);
  112 + }
  113 +
  114 + is_ok = sherpa_onnx::WriteWave(output_accompaniment_wave, output.sample_rate,
  115 + output.stems[1].data[0].data(),
  116 + output.stems[1].data[1].data(),
  117 + output.stems[1].data[0].size());
  118 +
  119 + if (!is_ok) {
  120 + fprintf(stderr, "Failed to write to '%s'\n",
  121 + output_accompaniment_wave.c_str());
  122 + exit(EXIT_FAILURE);
  123 + }
  124 +
  125 + fprintf(stderr, "Done\n");
  126 + fprintf(stderr, "Saved to write to '%s' and '%s'\n",
  127 + output_vocals_wave.c_str(), output_accompaniment_wave.c_str());
  128 +
  129 + float duration =
  130 + input.samples.data[0].size() / static_cast<float>(input.sample_rate);
  131 + fprintf(stderr, "num threads: %d\n", config.model.num_threads);
  132 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  133 + float rtf = elapsed_seconds / duration;
  134 + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
  135 + elapsed_seconds, duration, rtf);
  136 +
  137 + return 0;
  138 +}
@@ -63,8 +63,9 @@ in sherpa-onnx. @@ -63,8 +63,9 @@ in sherpa-onnx.
63 63
64 // Read a wave file of mono-channel. 64 // Read a wave file of mono-channel.
65 // Return its samples normalized to the range [-1, 1). 65 // Return its samples normalized to the range [-1, 1).
66 -std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,  
67 - bool *is_ok) { 66 +std::vector<std::vector<float>> ReadWaveImpl(std::istream &is,
  67 + int32_t *sampling_rate,
  68 + bool *is_ok) {
68 WaveHeader header{}; 69 WaveHeader header{};
69 is.read(reinterpret_cast<char *>(&header.chunk_id), sizeof(header.chunk_id)); 70 is.read(reinterpret_cast<char *>(&header.chunk_id), sizeof(header.chunk_id));
70 71
@@ -144,12 +145,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -144,12 +145,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
144 is.read(reinterpret_cast<char *>(&header.num_channels), 145 is.read(reinterpret_cast<char *>(&header.num_channels),
145 sizeof(header.num_channels)); 146 sizeof(header.num_channels));
146 147
147 - if (header.num_channels != 1) { // we support only single channel for now  
148 - SHERPA_ONNX_LOGE(  
149 - "Warning: %d channels are found. We only use the first channel.\n",  
150 - header.num_channels);  
151 - }  
152 -  
153 is.read(reinterpret_cast<char *>(&header.sample_rate), 148 is.read(reinterpret_cast<char *>(&header.sample_rate),
154 sizeof(header.sample_rate)); 149 sizeof(header.sample_rate));
155 150
@@ -219,7 +214,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -219,7 +214,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
219 214
220 *sampling_rate = header.sample_rate; 215 *sampling_rate = header.sample_rate;
221 216
222 - std::vector<float> ans; 217 + std::vector<std::vector<float>> ans(header.num_channels);
223 218
224 if (header.bits_per_sample == 16 && header.audio_format == 1) { 219 if (header.bits_per_sample == 16 && header.audio_format == 1) {
225 // header.subchunk2_size contains the number of bytes in the data. 220 // header.subchunk2_size contains the number of bytes in the data.
@@ -233,11 +228,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -233,11 +228,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
233 return {}; 228 return {};
234 } 229 }
235 230
236 - ans.resize(samples.size() / header.num_channels); 231 + for (auto &v : ans) {
  232 + v.resize(samples.size() / header.num_channels);
  233 + }
237 234
238 // samples are interleaved 235 // samples are interleaved
239 - for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {  
240 - ans[i] = samples[i * header.num_channels] / 32768.; 236 + for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
  237 + i += header.num_channels, ++k) {
  238 + for (int32_t c = 0; c != header.num_channels; ++c) {
  239 + ans[c][k] = samples[i + c] / 32768.;
  240 + }
241 } 241 }
242 } else if (header.bits_per_sample == 8 && header.audio_format == 1) { 242 } else if (header.bits_per_sample == 8 && header.audio_format == 1) {
243 // number of samples == number of bytes for 8-bit encoded samples 243 // number of samples == number of bytes for 8-bit encoded samples
@@ -252,14 +252,21 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -252,14 +252,21 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
252 return {}; 252 return {};
253 } 253 }
254 254
255 - ans.resize(samples.size() / header.num_channels);  
256 - for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {  
257 - // Note(fangjun): We want to normalize each sample into the range [-1, 1]  
258 - // Since each original sample is in the range [0, 256], dividing  
259 - // them by 128 converts them to the range [0, 2];  
260 - // so after subtracting 1, we get the range [-1, 1]  
261 - //  
262 - ans[i] = samples[i * header.num_channels] / 128. - 1; 255 + for (auto &v : ans) {
  256 + v.resize(samples.size() / header.num_channels);
  257 + }
  258 +
  259 + // samples are interleaved
  260 + for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
  261 + i += header.num_channels, ++k) {
  262 + for (int32_t c = 0; c != header.num_channels; ++c) {
  263 + // Note(fangjun): We want to normalize each sample into the range [-1,
  264 + // 1] Since each original sample is in the range [0, 256], dividing them
  265 + // by 128 converts them to the range [0, 2]; so after subtracting 1, we
  266 + // get the range [-1, 1]
  267 + //
  268 + ans[c][k] = samples[i + c] / 128. - 1;
  269 + }
263 } 270 }
264 } else if (header.bits_per_sample == 32 && header.audio_format == 1) { 271 } else if (header.bits_per_sample == 32 && header.audio_format == 1) {
265 // 32 here is for int32 272 // 32 here is for int32
@@ -275,9 +282,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -275,9 +282,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
275 return {}; 282 return {};
276 } 283 }
277 284
278 - ans.resize(samples.size() / header.num_channels);  
279 - for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {  
280 - ans[i] = static_cast<float>(samples[i * header.num_channels]) / (1 << 31); 285 + for (auto &v : ans) {
  286 + v.resize(samples.size() / header.num_channels);
  287 + }
  288 +
  289 + // samples are interleaved
  290 + for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
  291 + i += header.num_channels, ++k) {
  292 + for (int32_t c = 0; c != header.num_channels; ++c) {
  293 + ans[c][k] = static_cast<float>(samples[i + c]) / (1 << 31);
  294 + }
281 } 295 }
282 } else if (header.bits_per_sample == 32 && header.audio_format == 3) { 296 } else if (header.bits_per_sample == 32 && header.audio_format == 3) {
283 // 32 here is for float32 297 // 32 here is for float32
@@ -293,9 +307,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, @@ -293,9 +307,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
293 return {}; 307 return {};
294 } 308 }
295 309
296 - ans.resize(samples.size() / header.num_channels);  
297 - for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {  
298 - ans[i] = samples[i * header.num_channels]; 310 + for (auto &v : ans) {
  311 + v.resize(samples.size() / header.num_channels);
  312 + }
  313 +
  314 + // samples are interleaved
  315 + for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
  316 + i += header.num_channels, ++k) {
  317 + for (int32_t c = 0; c != header.num_channels; ++c) {
  318 + ans[c][k] = samples[i + c];
  319 + }
299 } 320 }
300 } else { 321 } else {
301 SHERPA_ONNX_LOGE( 322 SHERPA_ONNX_LOGE(
@@ -321,7 +342,27 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate, @@ -321,7 +342,27 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
321 std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate, 342 std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
322 bool *is_ok) { 343 bool *is_ok) {
323 auto samples = ReadWaveImpl(is, sampling_rate, is_ok); 344 auto samples = ReadWaveImpl(is, sampling_rate, is_ok);
  345 +
  346 + if (samples.size() > 1) {
  347 + SHERPA_ONNX_LOGE(
  348 + "Warning: %d channels are found. We only use the first channel.\n",
  349 + static_cast<int32_t>(samples.size()));
  350 + }
  351 +
  352 + return samples[0];
  353 +}
  354 +
  355 +std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is,
  356 + int32_t *sampling_rate,
  357 + bool *is_ok) {
  358 + auto samples = ReadWaveImpl(is, sampling_rate, is_ok);
324 return samples; 359 return samples;
325 } 360 }
326 361
  362 +std::vector<std::vector<float>> ReadWaveMultiChannel(
  363 + const std::string &filename, int32_t *sampling_rate, bool *is_ok) {
  364 + std::ifstream is(filename, std::ifstream::binary);
  365 + return ReadWaveMultiChannel(is, sampling_rate, is_ok);
  366 +}
  367 +
327 } // namespace sherpa_onnx 368 } // namespace sherpa_onnx
@@ -26,6 +26,13 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate, @@ -26,6 +26,13 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
26 std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate, 26 std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
27 bool *is_ok); 27 bool *is_ok);
28 28
  29 +std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is,
  30 + int32_t *sampling_rate,
  31 + bool *is_ok);
  32 +
  33 +std::vector<std::vector<float>> ReadWaveMultiChannel(
  34 + const std::string &filename, int32_t *sampling_rate, bool *is_ok);
  35 +
29 } // namespace sherpa_onnx 36 } // namespace sherpa_onnx
30 37
31 #endif // SHERPA_ONNX_CSRC_WAVE_READER_H_ 38 #endif // SHERPA_ONNX_CSRC_WAVE_READER_H_
@@ -4,6 +4,7 @@ @@ -4,6 +4,7 @@
4 4
5 #include "sherpa-onnx/csrc/wave-writer.h" 5 #include "sherpa-onnx/csrc/wave-writer.h"
6 6
  7 +#include <algorithm>
7 #include <cstring> 8 #include <cstring>
8 #include <fstream> 9 #include <fstream>
9 #include <string> 10 #include <string>
@@ -36,12 +37,44 @@ struct WaveHeader { @@ -36,12 +37,44 @@ struct WaveHeader {
36 37
37 } // namespace 38 } // namespace
38 39
39 -int64_t WaveFileSize(int32_t n_samples) {  
40 - return sizeof(WaveHeader) + n_samples * sizeof(int16_t); 40 +int64_t WaveFileSize(int32_t n_samples, int32_t num_channels /*= 1*/) {
  41 + return sizeof(WaveHeader) + n_samples * sizeof(int16_t) * num_channels;
41 } 42 }
42 43
43 void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, 44 void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
44 int32_t n) { 45 int32_t n) {
  46 + WriteWave(buffer, sampling_rate, samples, nullptr, n);
  47 +}
  48 +
  49 +bool WriteWave(const std::string &filename, int32_t sampling_rate,
  50 + const float *samples, int32_t n) {
  51 + return WriteWave(filename, sampling_rate, samples, nullptr, n);
  52 +}
  53 +
  54 +bool WriteWave(const std::string &filename, int32_t sampling_rate,
  55 + const float *samples_ch0, const float *samples_ch1, int32_t n) {
  56 + std::string buffer;
  57 + buffer.resize(WaveFileSize(n, samples_ch1 == nullptr ? 1 : 2));
  58 +
  59 + WriteWave(buffer.data(), sampling_rate, samples_ch0, samples_ch1, n);
  60 +
  61 + std::ofstream os(filename, std::ios::binary);
  62 + if (!os) {
  63 + SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str());
  64 + return false;
  65 + }
  66 +
  67 + os << buffer;
  68 + if (!os) {
  69 + SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str());
  70 + return false;
  71 + }
  72 +
  73 + return true;
  74 +}
  75 +
  76 +void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0,
  77 + const float *samples_ch1, int32_t n) {
45 WaveHeader header{}; 78 WaveHeader header{};
46 header.chunk_id = 0x46464952; // FFIR 79 header.chunk_id = 0x46464952; // FFIR
47 header.format = 0x45564157; // EVAW 80 header.format = 0x45564157; // EVAW
@@ -49,8 +82,9 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, @@ -49,8 +82,9 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
49 header.subchunk1_size = 16; // 16 for PCM 82 header.subchunk1_size = 16; // 16 for PCM
50 header.audio_format = 1; // PCM =1 83 header.audio_format = 1; // PCM =1
51 84
52 - int32_t num_channels = 1; 85 + int32_t num_channels = samples_ch1 == nullptr ? 1 : 2;
53 int32_t bits_per_sample = 16; // int16_t 86 int32_t bits_per_sample = 16; // int16_t
  87 +
54 header.num_channels = num_channels; 88 header.num_channels = num_channels;
55 header.sample_rate = sampling_rate; 89 header.sample_rate = sampling_rate;
56 header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8; 90 header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8;
@@ -61,32 +95,32 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, @@ -61,32 +95,32 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
61 95
62 header.chunk_size = 36 + header.subchunk2_size; 96 header.chunk_size = 36 + header.subchunk2_size;
63 97
64 - std::vector<int16_t> samples_int16(n); 98 + std::vector<int16_t> samples_int16_ch0(n);
65 for (int32_t i = 0; i != n; ++i) { 99 for (int32_t i = 0; i != n; ++i) {
66 - samples_int16[i] = samples[i] * 32767; 100 + samples_int16_ch0[i] = std::min<int32_t>(samples_ch0[i] * 32767, 32767);
  101 + }
  102 +
  103 + std::vector<int16_t> samples_int16_ch1;
  104 + if (samples_ch1) {
  105 + samples_int16_ch1.resize(n);
  106 + for (int32_t i = 0; i != n; ++i) {
  107 + samples_int16_ch1[i] = std::min<int32_t>(samples_ch1[i] * 32767, 32767);
  108 + }
67 } 109 }
68 110
69 memcpy(buffer, &header, sizeof(WaveHeader)); 111 memcpy(buffer, &header, sizeof(WaveHeader));
70 - memcpy(buffer + sizeof(WaveHeader), samples_int16.data(),  
71 - n * sizeof(int16_t));  
72 -}  
73 112
74 -bool WriteWave(const std::string &filename, int32_t sampling_rate,  
75 - const float *samples, int32_t n) {  
76 - std::string buffer;  
77 - buffer.resize(WaveFileSize(n));  
78 - WriteWave(buffer.data(), sampling_rate, samples, n);  
79 - std::ofstream os(filename, std::ios::binary);  
80 - if (!os) {  
81 - SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str());  
82 - return false;  
83 - }  
84 - os << buffer;  
85 - if (!os) {  
86 - SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str());  
87 - return false; 113 + if (samples_ch1 == nullptr) {
  114 + memcpy(buffer + sizeof(WaveHeader), samples_int16_ch0.data(),
  115 + n * sizeof(int16_t));
  116 + } else {
  117 + auto p = reinterpret_cast<int16_t *>(buffer + sizeof(WaveHeader));
  118 +
  119 + for (int32_t i = 0; i != n; ++i) {
  120 + p[2 * i] = samples_int16_ch0[i];
  121 + p[2 * i + 1] = samples_int16_ch1[i];
  122 + }
88 } 123 }
89 - return true;  
90 } 124 }
91 125
92 } // namespace sherpa_onnx 126 } // namespace sherpa_onnx
@@ -25,7 +25,13 @@ bool WriteWave(const std::string &filename, int32_t sampling_rate, @@ -25,7 +25,13 @@ bool WriteWave(const std::string &filename, int32_t sampling_rate,
25 void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, 25 void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
26 int32_t n); 26 int32_t n);
27 27
// Writes a 16-bit PCM wave file with one or two channels.
//
// @param samples_ch0  n samples of the first channel.
// @param samples_ch1  n samples of the second channel; pass nullptr to write
//                     a single-channel file.
// @param n            Number of samples per channel.
// @return true on success, false otherwise.
bool WriteWave(const std::string &filename, int32_t sampling_rate,
               const float *samples_ch0, const float *samples_ch1, int32_t n);

// Same as above, but serializes the wave data into `buffer`. With two
// channels the samples are interleaved.
// NOTE(review): `buffer` presumably has to provide at least
// WaveFileSize(n, <number of channels>) bytes -- confirm.
void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0,
               const float *samples_ch1, int32_t n);

// Size in bytes of a 16-bit wave file holding n_samples samples for each of
// num_channels channels.
int64_t WaveFileSize(int32_t n_samples, int32_t num_channels = 1);
29 35
30 } // namespace sherpa_onnx 36 } // namespace sherpa_onnx
31 37
@@ -77,7 +77,7 @@ fileInput.addEventListener('change', function(event) { @@ -77,7 +77,7 @@ fileInput.addEventListener('change', function(event) {
77 console.log('ArrayBuffer length:', arrayBuffer.byteLength); 77 console.log('ArrayBuffer length:', arrayBuffer.byteLength);
78 78
79 const uint8Array = new Uint8Array(arrayBuffer); 79 const uint8Array = new Uint8Array(arrayBuffer);
80 - const wave = readWaveFromBinaryData(uint8Array); 80 + const wave = readWaveFromBinaryData(uint8Array, Module);
81 if (wave == null) { 81 if (wave == null) {
82 alert( 82 alert(
83 `${file.name} is not a valid .wav file. Please select a *.wav file`); 83 `${file.name} is not a valid .wav file. Please select a *.wav file`);