Committed by
GitHub
Add C++ runtime for spleeter about source separation (#2242)
正在显示
28 个修改的文件
包含
1267 行增加
和
72 行删除
| @@ -56,6 +56,7 @@ def get_binaries(): | @@ -56,6 +56,7 @@ def get_binaries(): | ||
| 56 | "sherpa-onnx-offline-denoiser", | 56 | "sherpa-onnx-offline-denoiser", |
| 57 | "sherpa-onnx-offline-language-identification", | 57 | "sherpa-onnx-offline-language-identification", |
| 58 | "sherpa-onnx-offline-punctuation", | 58 | "sherpa-onnx-offline-punctuation", |
| 59 | + "sherpa-onnx-offline-source-separation", | ||
| 59 | "sherpa-onnx-offline-speaker-diarization", | 60 | "sherpa-onnx-offline-speaker-diarization", |
| 60 | "sherpa-onnx-offline-tts", | 61 | "sherpa-onnx-offline-tts", |
| 61 | "sherpa-onnx-offline-tts-play", | 62 | "sherpa-onnx-offline-tts-play", |
| @@ -217,8 +217,8 @@ def main(name): | @@ -217,8 +217,8 @@ def main(name): | ||
| 217 | # for the batchnormalization in torch, | 217 | # for the batchnormalization in torch, |
| 218 | # default input shape is NCHW | 218 | # default input shape is NCHW |
| 219 | 219 | ||
| 220 | - # NHWC to NCHW | ||
| 221 | - torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2)) | 220 | + torch_y1_out = unet(torch.from_numpy(y0_out).permute(3, 0, 1, 2)) |
| 221 | + torch_y1_out = torch_y1_out.permute(1, 0, 2, 3) | ||
| 222 | 222 | ||
| 223 | # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape) | 223 | # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape) |
| 224 | assert torch.allclose( | 224 | assert torch.allclose( |
| @@ -46,7 +46,7 @@ def add_meta_data(filename, prefix): | @@ -46,7 +46,7 @@ def add_meta_data(filename, prefix): | ||
| 46 | 46 | ||
| 47 | def export(model, prefix): | 47 | def export(model, prefix): |
| 48 | num_splits = 1 | 48 | num_splits = 1 |
| 49 | - x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32) | 49 | + x = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32) |
| 50 | 50 | ||
| 51 | filename = f"./2stems/{prefix}.onnx" | 51 | filename = f"./2stems/{prefix}.onnx" |
| 52 | torch.onnx.export( | 52 | torch.onnx.export( |
| @@ -56,7 +56,7 @@ def export(model, prefix): | @@ -56,7 +56,7 @@ def export(model, prefix): | ||
| 56 | input_names=["x"], | 56 | input_names=["x"], |
| 57 | output_names=["y"], | 57 | output_names=["y"], |
| 58 | dynamic_axes={ | 58 | dynamic_axes={ |
| 59 | - "x": {0: "num_splits"}, | 59 | + "x": {1: "num_splits"}, |
| 60 | }, | 60 | }, |
| 61 | opset_version=13, | 61 | opset_version=13, |
| 62 | ) | 62 | ) |
| @@ -101,13 +101,17 @@ def main(): | @@ -101,13 +101,17 @@ def main(): | ||
| 101 | print("y2", y.shape, y.dtype) | 101 | print("y2", y.shape, y.dtype) |
| 102 | 102 | ||
| 103 | y = y.abs() | 103 | y = y.abs() |
| 104 | - y = y.permute(0, 3, 1, 2) | ||
| 105 | - # (1, 2, 512, 1024) | 104 | + |
| 105 | + y = y.permute(3, 0, 1, 2) | ||
| 106 | + # (2, 1, 512, 1024) | ||
| 106 | print("y3", y.shape, y.dtype) | 107 | print("y3", y.shape, y.dtype) |
| 107 | 108 | ||
| 108 | vocals_spec = vocals(y) | 109 | vocals_spec = vocals(y) |
| 109 | accompaniment_spec = accompaniment(y) | 110 | accompaniment_spec = accompaniment(y) |
| 110 | 111 | ||
| 112 | + vocals_spec = vocals_spec.permute(1, 0, 2, 3) | ||
| 113 | + accompaniment_spec = accompaniment_spec.permute(1, 0, 2, 3) | ||
| 114 | + | ||
| 111 | sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10 | 115 | sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10 |
| 112 | print( | 116 | print( |
| 113 | "vocals_spec", | 117 | "vocals_spec", |
| @@ -12,15 +12,14 @@ from separate import load_audio | @@ -12,15 +12,14 @@ from separate import load_audio | ||
| 12 | 12 | ||
| 13 | """ | 13 | """ |
| 14 | ----------inputs for ./2stems/vocals.onnx---------- | 14 | ----------inputs for ./2stems/vocals.onnx---------- |
| 15 | -NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024]) | 15 | +NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024]) |
| 16 | ----------outputs for ./2stems/vocals.onnx---------- | 16 | ----------outputs for ./2stems/vocals.onnx---------- |
| 17 | -NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024]) | 17 | +NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024]) |
| 18 | 18 | ||
| 19 | ----------inputs for ./2stems/accompaniment.onnx---------- | 19 | ----------inputs for ./2stems/accompaniment.onnx---------- |
| 20 | -NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024]) | 20 | +NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024]) |
| 21 | ----------outputs for ./2stems/accompaniment.onnx---------- | 21 | ----------outputs for ./2stems/accompaniment.onnx---------- |
| 22 | -NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024]) | ||
| 23 | - | 22 | +NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024]) |
| 24 | """ | 23 | """ |
| 25 | 24 | ||
| 26 | 25 | ||
| @@ -123,16 +122,16 @@ def main(): | @@ -123,16 +122,16 @@ def main(): | ||
| 123 | if padding > 0: | 122 | if padding > 0: |
| 124 | stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding)) | 123 | stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding)) |
| 125 | stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding)) | 124 | stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding)) |
| 126 | - stft0 = stft0.reshape(-1, 1, 512, 1024) | ||
| 127 | - stft1 = stft1.reshape(-1, 1, 512, 1024) | 125 | + stft0 = stft0.reshape(1, -1, 512, 1024) |
| 126 | + stft1 = stft1.reshape(1, -1, 512, 1024) | ||
| 128 | 127 | ||
| 129 | - stft_01 = torch.cat([stft0, stft1], axis=1) | 128 | + stft_01 = torch.cat([stft0, stft1], axis=0) |
| 130 | 129 | ||
| 131 | print("stft_01", stft_01.shape, stft_01.dtype) | 130 | print("stft_01", stft_01.shape, stft_01.dtype) |
| 132 | 131 | ||
| 133 | vocals_spec = vocals(stft_01) | 132 | vocals_spec = vocals(stft_01) |
| 134 | accompaniment_spec = accompaniment(stft_01) | 133 | accompaniment_spec = accompaniment(stft_01) |
| 135 | - # (num_splits, num_channels, 512, 1024) | 134 | + # (num_channels, num_splits, 512, 1024) |
| 136 | 135 | ||
| 137 | sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10 | 136 | sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10 |
| 138 | 137 | ||
| @@ -142,8 +141,8 @@ def main(): | @@ -142,8 +141,8 @@ def main(): | ||
| 142 | for name, spec in zip( | 141 | for name, spec in zip( |
| 143 | ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec] | 142 | ["vocals", "accompaniment"], [vocals_spec, accompaniment_spec] |
| 144 | ): | 143 | ): |
| 145 | - spec_c0 = spec[:, 0, :, :] | ||
| 146 | - spec_c1 = spec[:, 1, :, :] | 144 | + spec_c0 = spec[0] |
| 145 | + spec_c1 = spec[1] | ||
| 147 | 146 | ||
| 148 | spec_c0 = spec_c0.reshape(-1, 1024) | 147 | spec_c0 = spec_c0.reshape(-1, 1024) |
| 149 | spec_c1 = spec_c1.reshape(-1, 1024) | 148 | spec_c1 = spec_c1.reshape(-1, 1024) |
| @@ -67,6 +67,14 @@ class UNet(torch.nn.Module): | @@ -67,6 +67,14 @@ class UNet(torch.nn.Module): | ||
| 67 | self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3) | 67 | self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3) |
| 68 | 68 | ||
| 69 | def forward(self, x): | 69 | def forward(self, x): |
| 70 | + """ | ||
| 71 | + Args: | ||
| 72 | + x: (num_audio_channels, num_splits, 512, 1024) | ||
| 73 | + Returns: | ||
| 74 | + y: (num_audio_channels, num_splits, 512, 1024) | ||
| 75 | + """ | ||
| 76 | + x = x.permute(1, 0, 2, 3) | ||
| 77 | + | ||
| 70 | in_x = x | 78 | in_x = x |
| 71 | # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024) | 79 | # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024) |
| 72 | x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0) | 80 | x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0) |
| @@ -147,4 +155,5 @@ class UNet(torch.nn.Module): | @@ -147,4 +155,5 @@ class UNet(torch.nn.Module): | ||
| 147 | up7 = self.up7(batch12) | 155 | up7 = self.up7(batch12) |
| 148 | up7 = torch.sigmoid(up7) # (3, 2, 512, 1024) | 156 | up7 = torch.sigmoid(up7) # (3, 2, 512, 1024) |
| 149 | 157 | ||
| 150 | - return up7 * in_x | 158 | + ans = up7 * in_x |
| 159 | + return ans.permute(1, 0, 2, 3) |
| @@ -50,6 +50,13 @@ set(sources | @@ -50,6 +50,13 @@ set(sources | ||
| 50 | offline-rnn-lm.cc | 50 | offline-rnn-lm.cc |
| 51 | offline-sense-voice-model-config.cc | 51 | offline-sense-voice-model-config.cc |
| 52 | offline-sense-voice-model.cc | 52 | offline-sense-voice-model.cc |
| 53 | + | ||
| 54 | + offline-source-separation-impl.cc | ||
| 55 | + offline-source-separation-model-config.cc | ||
| 56 | + offline-source-separation-spleeter-model-config.cc | ||
| 57 | + offline-source-separation-spleeter-model.cc | ||
| 58 | + offline-source-separation.cc | ||
| 59 | + | ||
| 53 | offline-stream.cc | 60 | offline-stream.cc |
| 54 | offline-tdnn-ctc-model.cc | 61 | offline-tdnn-ctc-model.cc |
| 55 | offline-tdnn-model-config.cc | 62 | offline-tdnn-model-config.cc |
| @@ -326,6 +333,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -326,6 +333,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 326 | add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) | 333 | add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) |
| 327 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) | 334 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) |
| 328 | add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) | 335 | add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) |
| 336 | + add_executable(sherpa-onnx-offline-source-separation sherpa-onnx-offline-source-separation.cc) | ||
| 329 | add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc) | 337 | add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc) |
| 330 | add_executable(sherpa-onnx-vad sherpa-onnx-vad.cc) | 338 | add_executable(sherpa-onnx-vad sherpa-onnx-vad.cc) |
| 331 | 339 | ||
| @@ -346,6 +354,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -346,6 +354,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 346 | sherpa-onnx-offline-language-identification | 354 | sherpa-onnx-offline-language-identification |
| 347 | sherpa-onnx-offline-parallel | 355 | sherpa-onnx-offline-parallel |
| 348 | sherpa-onnx-offline-punctuation | 356 | sherpa-onnx-offline-punctuation |
| 357 | + sherpa-onnx-offline-source-separation | ||
| 349 | sherpa-onnx-online-punctuation | 358 | sherpa-onnx-online-punctuation |
| 350 | sherpa-onnx-vad | 359 | sherpa-onnx-vad |
| 351 | ) | 360 | ) |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-source-separation-impl.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +std::unique_ptr<OfflineSourceSeparationImpl> | ||
| 14 | +OfflineSourceSeparationImpl::Create( | ||
| 15 | + const OfflineSourceSeparationConfig &config) { | ||
| 16 | + // TODO(fangjun): Support other models | ||
| 17 | + return std::make_unique<OfflineSourceSeparationSpleeterImpl>(config); | ||
| 18 | +} | ||
| 19 | + | ||
| 20 | +template <typename Manager> | ||
| 21 | +std::unique_ptr<OfflineSourceSeparationImpl> | ||
| 22 | +OfflineSourceSeparationImpl::Create( | ||
| 23 | + Manager *mgr, const OfflineSourceSeparationConfig &config) { | ||
| 24 | + // TODO(fangjun): Support other models | ||
| 25 | + return std::make_unique<OfflineSourceSeparationSpleeterImpl>(mgr, config); | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +#if __ANDROID_API__ >= 9 | ||
| 29 | +template std::unique_ptr<OfflineSourceSeparationImpl> | ||
| 30 | +OfflineSourceSeparationImpl::Create( | ||
| 31 | + AAssetManager *mgr, const OfflineSourceSeparationConfig &config); | ||
| 32 | +#endif | ||
| 33 | + | ||
| 34 | +#if __OHOS__ | ||
| 35 | +template std::unique_ptr<OfflineSourceSeparationImpl> | ||
| 36 | +OfflineSourceSeparationImpl::Create( | ||
| 37 | + NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config); | ||
| 38 | +#endif | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-source-separation.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class OfflineSourceSeparationImpl { | ||
| 15 | + public: | ||
| 16 | + static std::unique_ptr<OfflineSourceSeparationImpl> Create( | ||
| 17 | + const OfflineSourceSeparationConfig &config); | ||
| 18 | + | ||
| 19 | + template <typename Manager> | ||
| 20 | + static std::unique_ptr<OfflineSourceSeparationImpl> Create( | ||
| 21 | + Manager *mgr, const OfflineSourceSeparationConfig &config); | ||
| 22 | + | ||
| 23 | + virtual ~OfflineSourceSeparationImpl() = default; | ||
| 24 | + | ||
| 25 | + virtual OfflineSourceSeparationOutput Process( | ||
| 26 | + const OfflineSourceSeparationInput &input) const = 0; | ||
| 27 | + | ||
| 28 | + virtual int32_t GetOutputSampleRate() const = 0; | ||
| 29 | + | ||
| 30 | + virtual int32_t GetNumberOfStems() const = 0; | ||
| 31 | +}; | ||
| 32 | + | ||
| 33 | +} // namespace sherpa_onnx | ||
| 34 | + | ||
| 35 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h" | ||
| 6 | + | ||
| 7 | +namespace sherpa_onnx { | ||
| 8 | + | ||
| 9 | +void OfflineSourceSeparationModelConfig::Register(ParseOptions *po) { | ||
| 10 | + spleeter.Register(po); | ||
| 11 | + | ||
| 12 | + po->Register("num-threads", &num_threads, | ||
| 13 | + "Number of threads to run the neural network"); | ||
| 14 | + | ||
| 15 | + po->Register("debug", &debug, | ||
| 16 | + "true to print model information while loading it."); | ||
| 17 | + | ||
| 18 | + po->Register("provider", &provider, | ||
| 19 | + "Specify a provider to use: cpu, cuda, coreml"); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +bool OfflineSourceSeparationModelConfig::Validate() const { | ||
| 23 | + return spleeter.Validate(); | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +std::string OfflineSourceSeparationModelConfig::ToString() const { | ||
| 27 | + std::ostringstream os; | ||
| 28 | + | ||
| 29 | + os << "OfflineSourceSeparationModelConfig("; | ||
| 30 | + os << "spleeter=" << spleeter.ToString() << ", "; | ||
| 31 | + os << "num_threads=" << num_threads << ", "; | ||
| 32 | + os << "debug=" << (debug ? "True" : "False") << ", "; | ||
| 33 | + os << "provider=\"" << provider << "\")"; | ||
| 34 | + | ||
| 35 | + return os.str(); | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h" | ||
| 11 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +struct OfflineSourceSeparationModelConfig { | ||
| 16 | + OfflineSourceSeparationSpleeterModelConfig spleeter; | ||
| 17 | + | ||
| 18 | + int32_t num_threads = 1; | ||
| 19 | + bool debug = false; | ||
| 20 | + std::string provider = "cpu"; | ||
| 21 | + | ||
| 22 | + OfflineSourceSeparationModelConfig() = default; | ||
| 23 | + | ||
| 24 | + OfflineSourceSeparationModelConfig( | ||
| 25 | + const OfflineSourceSeparationSpleeterModelConfig &spleeter, | ||
| 26 | + int32_t num_threads, bool debug, const std::string &provider) | ||
| 27 | + : spleeter(spleeter), | ||
| 28 | + num_threads(num_threads), | ||
| 29 | + debug(debug), | ||
| 30 | + provider(provider) {} | ||
| 31 | + | ||
| 32 | + void Register(ParseOptions *po); | ||
| 33 | + | ||
| 34 | + bool Validate() const; | ||
| 35 | + | ||
| 36 | + std::string ToString() const; | ||
| 37 | +}; | ||
| 38 | + | ||
| 39 | +} // namespace sherpa_onnx | ||
| 40 | + | ||
| 41 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include "Eigen/Dense" | ||
| 9 | +#include "kaldi-native-fbank/csrc/istft.h" | ||
| 10 | +#include "kaldi-native-fbank/csrc/stft.h" | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-source-separation.h" | ||
| 14 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 15 | +#include "sherpa-onnx/csrc/resample.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl { | ||
| 20 | + public: | ||
| 21 | + OfflineSourceSeparationSpleeterImpl( | ||
| 22 | + const OfflineSourceSeparationConfig &config) | ||
| 23 | + : config_(config), model_(config_.model) {} | ||
| 24 | + | ||
| 25 | + template <typename Manager> | ||
| 26 | + OfflineSourceSeparationSpleeterImpl( | ||
| 27 | + Manager *mgr, const OfflineSourceSeparationConfig &config) | ||
| 28 | + : config_(config), model_(mgr, config_.model) {} | ||
| 29 | + | ||
| 30 | + OfflineSourceSeparationOutput Process( | ||
| 31 | + const OfflineSourceSeparationInput &input) const override { | ||
| 32 | + const OfflineSourceSeparationInput *p_input = &input; | ||
| 33 | + OfflineSourceSeparationInput tmp_input; | ||
| 34 | + | ||
| 35 | + int32_t output_sample_rate = GetOutputSampleRate(); | ||
| 36 | + | ||
| 37 | + if (input.sample_rate != output_sample_rate) { | ||
| 38 | + SHERPA_ONNX_LOGE( | ||
| 39 | + "Creating a resampler:\n" | ||
| 40 | + " in_sample_rate: %d\n" | ||
| 41 | + " output_sample_rate: %d\n", | ||
| 42 | + input.sample_rate, output_sample_rate); | ||
| 43 | + | ||
| 44 | + float min_freq = std::min<int32_t>(input.sample_rate, output_sample_rate); | ||
| 45 | + float lowpass_cutoff = 0.99 * 0.5 * min_freq; | ||
| 46 | + | ||
| 47 | + int32_t lowpass_filter_width = 6; | ||
| 48 | + auto resampler = std::make_unique<LinearResample>( | ||
| 49 | + input.sample_rate, output_sample_rate, lowpass_cutoff, | ||
| 50 | + lowpass_filter_width); | ||
| 51 | + | ||
| 52 | + std::vector<float> s; | ||
| 53 | + for (const auto &samples : input.samples.data) { | ||
| 54 | + resampler->Reset(); | ||
| 55 | + resampler->Resample(samples.data(), samples.size(), true, &s); | ||
| 56 | + tmp_input.samples.data.push_back(std::move(s)); | ||
| 57 | + } | ||
| 58 | + | ||
| 59 | + tmp_input.sample_rate = output_sample_rate; | ||
| 60 | + p_input = &tmp_input; | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + if (p_input->samples.data.size() > 1) { | ||
| 64 | + if (config_.model.debug) { | ||
| 65 | + SHERPA_ONNX_LOGE("input ch1 samples size: %d", | ||
| 66 | + static_cast<int32_t>(p_input->samples.data[1].size())); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + if (p_input->samples.data[0].size() != p_input->samples.data[1].size()) { | ||
| 70 | + SHERPA_ONNX_LOGE("ch0 samples size %d vs ch1 samples size %d", | ||
| 71 | + static_cast<int32_t>(p_input->samples.data[0].size()), | ||
| 72 | + static_cast<int32_t>(p_input->samples.data[1].size())); | ||
| 73 | + | ||
| 74 | + SHERPA_ONNX_EXIT(-1); | ||
| 75 | + } | ||
| 76 | + } | ||
| 77 | + | ||
| 78 | + auto stft_ch0 = ComputeStft(*p_input, 0); | ||
| 79 | + | ||
| 80 | + auto stft_ch1 = ComputeStft(*p_input, 1); | ||
| 81 | + knf::StftResult *p_stft_ch1 = stft_ch1.real.empty() ? &stft_ch0 : &stft_ch1; | ||
| 82 | + | ||
| 83 | + int32_t num_frames = stft_ch0.num_frames; | ||
| 84 | + int32_t fft_bins = stft_ch0.real.size() / num_frames; | ||
| 85 | + | ||
| 86 | + int32_t pad = 512 - (stft_ch0.num_frames % 512); | ||
| 87 | + if (pad < 512) { | ||
| 88 | + num_frames += pad; | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + if (num_frames % 512) { | ||
| 92 | + SHERPA_ONNX_LOGE("num_frames should be multiple of 512, actual: %d. %d", | ||
| 93 | + num_frames, num_frames % 512); | ||
| 94 | + SHERPA_ONNX_EXIT(-1); | ||
| 95 | + } | ||
| 96 | + | ||
| 97 | + Eigen::VectorXf real(2 * num_frames * 1024); | ||
| 98 | + Eigen::VectorXf imag(2 * num_frames * 1024); | ||
| 99 | + real.setZero(); | ||
| 100 | + imag.setZero(); | ||
| 101 | + | ||
| 102 | + float *p_real = &real[0]; | ||
| 103 | + float *p_imag = &imag[0]; | ||
| 104 | + | ||
| 105 | + // copy stft result of channel 0 | ||
| 106 | + for (int32_t i = 0; i != stft_ch0.num_frames; ++i) { | ||
| 107 | + std::copy(stft_ch0.real.data() + i * fft_bins, | ||
| 108 | + stft_ch0.real.data() + i * fft_bins + 1024, p_real + 1024 * i); | ||
| 109 | + | ||
| 110 | + std::copy(stft_ch0.imag.data() + i * fft_bins, | ||
| 111 | + stft_ch0.imag.data() + i * fft_bins + 1024, p_imag + 1024 * i); | ||
| 112 | + } | ||
| 113 | + | ||
| 114 | + p_real += num_frames * 1024; | ||
| 115 | + p_imag += num_frames * 1024; | ||
| 116 | + | ||
| 117 | + // copy stft result of channel 1 | ||
| 118 | + for (int32_t i = 0; i != stft_ch1.num_frames; ++i) { | ||
| 119 | + std::copy(p_stft_ch1->real.data() + i * fft_bins, | ||
| 120 | + p_stft_ch1->real.data() + i * fft_bins + 1024, | ||
| 121 | + p_real + 1024 * i); | ||
| 122 | + | ||
| 123 | + std::copy(p_stft_ch1->imag.data() + i * fft_bins, | ||
| 124 | + p_stft_ch1->imag.data() + i * fft_bins + 1024, | ||
| 125 | + p_imag + 1024 * i); | ||
| 126 | + } | ||
| 127 | + | ||
| 128 | + Eigen::VectorXf x = (real.array().square() + imag.array().square()).sqrt(); | ||
| 129 | + | ||
| 130 | + auto memory_info = | ||
| 131 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 132 | + | ||
| 133 | + std::array<int64_t, 4> x_shape{2, num_frames / 512, 512, 1024}; | ||
| 134 | + Ort::Value x_tensor = Ort::Value::CreateTensor( | ||
| 135 | + memory_info, &x[0], x.size(), x_shape.data(), x_shape.size()); | ||
| 136 | + | ||
| 137 | + Ort::Value vocals_spec_tensor = model_.RunVocals(View(&x_tensor)); | ||
| 138 | + Ort::Value accompaniment_spec_tensor = | ||
| 139 | + model_.RunAccompaniment(std::move(x_tensor)); | ||
| 140 | + | ||
| 141 | + Eigen::VectorXf vocals_spec = Eigen::Map<Eigen::VectorXf>( | ||
| 142 | + vocals_spec_tensor.GetTensorMutableData<float>(), x.size()); | ||
| 143 | + | ||
| 144 | + Eigen::VectorXf accompaniment_spec = Eigen::Map<Eigen::VectorXf>( | ||
| 145 | + accompaniment_spec_tensor.GetTensorMutableData<float>(), x.size()); | ||
| 146 | + | ||
| 147 | + Eigen::VectorXf sum_spec = vocals_spec.array().square() + | ||
| 148 | + accompaniment_spec.array().square() + 1e-10; | ||
| 149 | + | ||
| 150 | + vocals_spec = (vocals_spec.array().square() + 1e-10 / 2) / sum_spec.array(); | ||
| 151 | + | ||
| 152 | + accompaniment_spec = | ||
| 153 | + (accompaniment_spec.array().square() + 1e-10 / 2) / sum_spec.array(); | ||
| 154 | + | ||
| 155 | + auto vocals_samples_ch0 = ProcessSpec(vocals_spec, stft_ch0, 0); | ||
| 156 | + auto vocals_samples_ch1 = ProcessSpec(vocals_spec, *p_stft_ch1, 1); | ||
| 157 | + | ||
| 158 | + auto accompaniment_samples_ch0 = | ||
| 159 | + ProcessSpec(accompaniment_spec, stft_ch0, 0); | ||
| 160 | + auto accompaniment_samples_ch1 = | ||
| 161 | + ProcessSpec(accompaniment_spec, *p_stft_ch1, 1); | ||
| 162 | + | ||
| 163 | + OfflineSourceSeparationOutput ans; | ||
| 164 | + ans.sample_rate = GetOutputSampleRate(); | ||
| 165 | + | ||
| 166 | + ans.stems.resize(2); | ||
| 167 | + ans.stems[0].data.reserve(2); | ||
| 168 | + ans.stems[1].data.reserve(2); | ||
| 169 | + | ||
| 170 | + ans.stems[0].data.push_back(std::move(vocals_samples_ch0)); | ||
| 171 | + ans.stems[0].data.push_back(std::move(vocals_samples_ch1)); | ||
| 172 | + | ||
| 173 | + ans.stems[1].data.push_back(std::move(accompaniment_samples_ch0)); | ||
| 174 | + ans.stems[1].data.push_back(std::move(accompaniment_samples_ch1)); | ||
| 175 | + | ||
| 176 | + return ans; | ||
| 177 | + } | ||
| 178 | + | ||
| 179 | + int32_t GetOutputSampleRate() const override { | ||
| 180 | + return model_.GetMetaData().sample_rate; | ||
| 181 | + } | ||
| 182 | + | ||
| 183 | + int32_t GetNumberOfStems() const override { | ||
| 184 | + return model_.GetMetaData().num_stems; | ||
| 185 | + } | ||
| 186 | + | ||
| 187 | + private: | ||
| 188 | + // spec is of shape (2, num_chunks, 512, 1024) | ||
| 189 | + std::vector<float> ProcessSpec(const Eigen::VectorXf &spec, | ||
| 190 | + const knf::StftResult &stft, | ||
| 191 | + int32_t channel) const { | ||
| 192 | + int32_t fft_bins = stft.real.size() / stft.num_frames; | ||
| 193 | + | ||
| 194 | + Eigen::VectorXf mask(stft.real.size()); | ||
| 195 | + mask.setZero(); | ||
| 196 | + | ||
| 197 | + float *p_mask = &mask[0]; | ||
| 198 | + | ||
| 199 | + // assume there are 2 channels | ||
| 200 | + const float *p_spec = &spec[0] + (spec.size() / 2) * channel; | ||
| 201 | + | ||
| 202 | + for (int32_t i = 0; i != stft.num_frames; ++i) { | ||
| 203 | + std::copy(p_spec + i * 1024, p_spec + (i + 1) * 1024, | ||
| 204 | + p_mask + i * fft_bins); | ||
| 205 | + } | ||
| 206 | + | ||
| 207 | + knf::StftResult masked_stft; | ||
| 208 | + | ||
| 209 | + masked_stft.num_frames = stft.num_frames; | ||
| 210 | + masked_stft.real.resize(stft.real.size()); | ||
| 211 | + masked_stft.imag.resize(stft.imag.size()); | ||
| 212 | + | ||
| 213 | + Eigen::Map<Eigen::VectorXf>(masked_stft.real.data(), | ||
| 214 | + masked_stft.real.size()) = | ||
| 215 | + mask.array() * | ||
| 216 | + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.real.data()), | ||
| 217 | + stft.real.size()) | ||
| 218 | + .array(); | ||
| 219 | + | ||
| 220 | + Eigen::Map<Eigen::VectorXf>(masked_stft.imag.data(), | ||
| 221 | + masked_stft.imag.size()) = | ||
| 222 | + mask.array() * | ||
| 223 | + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.imag.data()), | ||
| 224 | + stft.imag.size()) | ||
| 225 | + .array(); | ||
| 226 | + | ||
| 227 | + auto stft_config = GetStftConfig(); | ||
| 228 | + knf::IStft istft(stft_config); | ||
| 229 | + | ||
| 230 | + return istft.Compute(masked_stft); | ||
| 231 | + } | ||
| 232 | + | ||
| 233 | + knf::StftResult ComputeStft(const OfflineSourceSeparationInput &input, | ||
| 234 | + int32_t ch) const { | ||
| 235 | + if (ch >= input.samples.data.size()) { | ||
| 236 | + SHERPA_ONNX_LOGE("Invalid channel %d. Max %d", ch, | ||
| 237 | + static_cast<int32_t>(input.samples.data.size())); | ||
| 238 | + SHERPA_ONNX_EXIT(-1); | ||
| 239 | + } | ||
| 240 | + | ||
| 241 | + if (input.samples.data[ch].empty()) { | ||
| 242 | + return {}; | ||
| 243 | + } | ||
| 244 | + | ||
| 245 | + return ComputeStft(input.samples.data[ch]); | ||
| 246 | + } | ||
| 247 | + | ||
| 248 | + knf::StftResult ComputeStft(const std::vector<float> &samples) const { | ||
| 249 | + auto stft_config = GetStftConfig(); | ||
| 250 | + knf::Stft stft(stft_config); | ||
| 251 | + | ||
| 252 | + return stft.Compute(samples.data(), samples.size()); | ||
| 253 | + } | ||
| 254 | + | ||
| 255 | + knf::StftConfig GetStftConfig() const { | ||
| 256 | + const auto &meta = model_.GetMetaData(); | ||
| 257 | + | ||
| 258 | + knf::StftConfig stft_config; | ||
| 259 | + stft_config.n_fft = meta.n_fft; | ||
| 260 | + stft_config.hop_length = meta.hop_length; | ||
| 261 | + stft_config.win_length = meta.window_length; | ||
| 262 | + stft_config.window_type = meta.window_type; | ||
| 263 | + stft_config.center = meta.center; | ||
| 264 | + stft_config.center = false; | ||
| 265 | + | ||
| 266 | + return stft_config; | ||
| 267 | + } | ||
| 268 | + | ||
| 269 | + private: | ||
| 270 | + OfflineSourceSeparationConfig config_; | ||
| 271 | + OfflineSourceSeparationSpleeterModel model_; | ||
| 272 | +}; | ||
| 273 | + | ||
| 274 | +} // namespace sherpa_onnx | ||
| 275 | + | ||
| 276 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-spleeter_model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineSourceSeparationSpleeterModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("spleeter-vocals", &vocals, "Path to the spleeter vocals model"); | ||
| 14 | + | ||
| 15 | + po->Register("spleeter-accompaniment", &accompaniment, | ||
| 16 | + "Path to the spleeter accompaniment model"); | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +bool OfflineSourceSeparationSpleeterModelConfig::Validate() const { | ||
| 20 | + if (vocals.empty()) { | ||
| 21 | + SHERPA_ONNX_LOGE("Please provide --spleeter-vocals"); | ||
| 22 | + return false; | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + if (!FileExists(vocals)) { | ||
| 26 | + SHERPA_ONNX_LOGE("spleeter vocals '%s' does not exist. ", vocals.c_str()); | ||
| 27 | + return false; | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + if (accompaniment.empty()) { | ||
| 31 | + SHERPA_ONNX_LOGE("Please provide --spleeter-accompaniment"); | ||
| 32 | + return false; | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + if (!FileExists(accompaniment)) { | ||
| 36 | + SHERPA_ONNX_LOGE("spleeter accompaniment '%s' does not exist. ", | ||
| 37 | + accompaniment.c_str()); | ||
| 38 | + return false; | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + return true; | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +std::string OfflineSourceSeparationSpleeterModelConfig::ToString() const { | ||
| 45 | + std::ostringstream os; | ||
| 46 | + | ||
| 47 | + os << "OfflineSourceSeparationSpleeterModelConfig("; | ||
| 48 | + os << "vocals=\"" << vocals << "\", "; | ||
| 49 | + os << "accompaniment=\"" << accompaniment << "\")"; | ||
| 50 | + | ||
| 51 | + return os.str(); | ||
| 52 | +} | ||
| 53 | + | ||
| 54 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-spleeter_model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h" | ||
| 11 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +struct OfflineSourceSeparationSpleeterModelConfig { | ||
| 16 | + std::string vocals; | ||
| 17 | + | ||
| 18 | + std::string accompaniment; | ||
| 19 | + | ||
| 20 | + OfflineSourceSeparationSpleeterModelConfig() = default; | ||
| 21 | + | ||
| 22 | + OfflineSourceSeparationSpleeterModelConfig(const std::string &vocals, | ||
| 23 | + const std::string &accompaniment) | ||
| 24 | + : vocals(vocals), accompaniment(accompaniment) {} | ||
| 25 | + | ||
| 26 | + void Register(ParseOptions *po); | ||
| 27 | + | ||
| 28 | + bool Validate() const; | ||
| 29 | + | ||
| 30 | + std::string ToString() const; | ||
| 31 | +}; | ||
| 32 | + | ||
| 33 | +} // namespace sherpa_onnx | ||
| 34 | + | ||
| 35 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <unordered_map> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +// See also | ||
| 14 | +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/spleeter/separate_onnx.py | ||
| 15 | +struct OfflineSourceSeparationSpleeterModelMetaData { | ||
| 16 | + int32_t sample_rate = 44100; | ||
| 17 | + int32_t num_stems = 2; | ||
| 18 | + | ||
| 19 | + int32_t n_fft = 4096; | ||
| 20 | + int32_t hop_length = 1024; | ||
| 21 | + int32_t window_length = 4096; | ||
| 22 | + bool center = false; | ||
| 23 | + std::string window_type = "hann"; | ||
| 24 | +}; | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_ |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 17 | +#if __OHOS__ | ||
| 18 | +#include "rawfile/raw_file_manager.h" | ||
| 19 | +#endif | ||
| 20 | + | ||
| 21 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 22 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 23 | +#include "sherpa-onnx/csrc/session.h" | ||
| 24 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 25 | + | ||
| 26 | +namespace sherpa_onnx { | ||
| 27 | + | ||
| 28 | +class OfflineSourceSeparationSpleeterModel::Impl { | ||
| 29 | + public: | ||
| 30 | + explicit Impl(const OfflineSourceSeparationModelConfig &config) | ||
| 31 | + : config_(config), | ||
| 32 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 33 | + sess_opts_(GetSessionOptions(config)), | ||
| 34 | + allocator_{} { | ||
| 35 | + { | ||
| 36 | + auto buf = ReadFile(config.spleeter.vocals); | ||
| 37 | + InitVocals(buf.data(), buf.size()); | ||
| 38 | + } | ||
| 39 | + | ||
| 40 | + { | ||
| 41 | + auto buf = ReadFile(config.spleeter.accompaniment); | ||
| 42 | + InitAccompaniment(buf.data(), buf.size()); | ||
| 43 | + } | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + template <typename Manager> | ||
| 47 | + Impl(Manager *mgr, const OfflineSourceSeparationModelConfig &config) | ||
| 48 | + : config_(config), | ||
| 49 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 50 | + sess_opts_(GetSessionOptions(config)), | ||
| 51 | + allocator_{} { | ||
| 52 | + { | ||
| 53 | + auto buf = ReadFile(mgr, config.spleeter.vocals); | ||
| 54 | + InitVocals(buf.data(), buf.size()); | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + { | ||
| 58 | + auto buf = ReadFile(mgr, config.spleeter.accompaniment); | ||
| 59 | + InitAccompaniment(buf.data(), buf.size()); | ||
| 60 | + } | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const { | ||
| 64 | + return meta_; | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + Ort::Value RunVocals(Ort::Value x) const { | ||
| 68 | + auto out = vocals_sess_->Run({}, vocals_input_names_ptr_.data(), &x, 1, | ||
| 69 | + vocals_output_names_ptr_.data(), | ||
| 70 | + vocals_output_names_ptr_.size()); | ||
| 71 | + return std::move(out[0]); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + Ort::Value RunAccompaniment(Ort::Value x) const { | ||
| 75 | + auto out = | ||
| 76 | + accompaniment_sess_->Run({}, accompaniment_input_names_ptr_.data(), &x, | ||
| 77 | + 1, accompaniment_output_names_ptr_.data(), | ||
| 78 | + accompaniment_output_names_ptr_.size()); | ||
| 79 | + return std::move(out[0]); | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + private: | ||
| 83 | + void InitVocals(void *model_data, size_t model_data_length) { | ||
| 84 | + vocals_sess_ = std::make_unique<Ort::Session>( | ||
| 85 | + env_, model_data, model_data_length, sess_opts_); | ||
| 86 | + | ||
| 87 | + GetInputNames(vocals_sess_.get(), &vocals_input_names_, | ||
| 88 | + &vocals_input_names_ptr_); | ||
| 89 | + | ||
| 90 | + GetOutputNames(vocals_sess_.get(), &vocals_output_names_, | ||
| 91 | + &vocals_output_names_ptr_); | ||
| 92 | + | ||
| 93 | + Ort::ModelMetadata meta_data = vocals_sess_->GetModelMetadata(); | ||
| 94 | + if (config_.debug) { | ||
| 95 | + std::ostringstream os; | ||
| 96 | + os << "---vocals model---\n"; | ||
| 97 | + PrintModelMetadata(os, meta_data); | ||
| 98 | + | ||
| 99 | + os << "----------input names----------\n"; | ||
| 100 | + int32_t i = 0; | ||
| 101 | + for (const auto &s : vocals_input_names_) { | ||
| 102 | + os << i << " " << s << "\n"; | ||
| 103 | + ++i; | ||
| 104 | + } | ||
| 105 | + os << "----------output names----------\n"; | ||
| 106 | + i = 0; | ||
| 107 | + for (const auto &s : vocals_output_names_) { | ||
| 108 | + os << i << " " << s << "\n"; | ||
| 109 | + ++i; | ||
| 110 | + } | ||
| 111 | + | ||
| 112 | +#if __OHOS__ | ||
| 113 | + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); | ||
| 114 | +#else | ||
| 115 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 116 | +#endif | ||
| 117 | + } | ||
| 118 | + | ||
| 119 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 120 | + | ||
| 121 | + std::string model_type; | ||
| 122 | + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type"); | ||
| 123 | + if (model_type != "spleeter") { | ||
| 124 | + SHERPA_ONNX_LOGE("Expect model type 'spleeter'. Given: '%s'", | ||
| 125 | + model_type.c_str()); | ||
| 126 | + SHERPA_ONNX_EXIT(-1); | ||
| 127 | + } | ||
| 128 | + | ||
| 129 | + SHERPA_ONNX_READ_META_DATA(meta_.num_stems, "stems"); | ||
| 130 | + if (meta_.num_stems != 2) { | ||
| 131 | + SHERPA_ONNX_LOGE("Only 2stems is supported. Given %d stems", | ||
| 132 | + meta_.num_stems); | ||
| 133 | + SHERPA_ONNX_EXIT(-1); | ||
| 134 | + } | ||
| 135 | + } | ||
| 136 | + | ||
| 137 | + void InitAccompaniment(void *model_data, size_t model_data_length) { | ||
| 138 | + accompaniment_sess_ = std::make_unique<Ort::Session>( | ||
| 139 | + env_, model_data, model_data_length, sess_opts_); | ||
| 140 | + | ||
| 141 | + GetInputNames(accompaniment_sess_.get(), &accompaniment_input_names_, | ||
| 142 | + &accompaniment_input_names_ptr_); | ||
| 143 | + | ||
| 144 | + GetOutputNames(accompaniment_sess_.get(), &accompaniment_output_names_, | ||
| 145 | + &accompaniment_output_names_ptr_); | ||
| 146 | + } | ||
| 147 | + | ||
| 148 | + private: | ||
| 149 | + OfflineSourceSeparationModelConfig config_; | ||
| 150 | + OfflineSourceSeparationSpleeterModelMetaData meta_; | ||
| 151 | + | ||
| 152 | + Ort::Env env_; | ||
| 153 | + Ort::SessionOptions sess_opts_; | ||
| 154 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 155 | + | ||
| 156 | + std::unique_ptr<Ort::Session> vocals_sess_; | ||
| 157 | + | ||
| 158 | + std::vector<std::string> vocals_input_names_; | ||
| 159 | + std::vector<const char *> vocals_input_names_ptr_; | ||
| 160 | + | ||
| 161 | + std::vector<std::string> vocals_output_names_; | ||
| 162 | + std::vector<const char *> vocals_output_names_ptr_; | ||
| 163 | + | ||
| 164 | + std::unique_ptr<Ort::Session> accompaniment_sess_; | ||
| 165 | + | ||
| 166 | + std::vector<std::string> accompaniment_input_names_; | ||
| 167 | + std::vector<const char *> accompaniment_input_names_ptr_; | ||
| 168 | + | ||
| 169 | + std::vector<std::string> accompaniment_output_names_; | ||
| 170 | + std::vector<const char *> accompaniment_output_names_ptr_; | ||
| 171 | +}; | ||
| 172 | + | ||
| 173 | +OfflineSourceSeparationSpleeterModel::~OfflineSourceSeparationSpleeterModel() = | ||
| 174 | + default; | ||
| 175 | + | ||
| 176 | +OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel( | ||
| 177 | + const OfflineSourceSeparationModelConfig &config) | ||
| 178 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 179 | + | ||
| 180 | +template <typename Manager> | ||
| 181 | +OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel( | ||
| 182 | + Manager *mgr, const OfflineSourceSeparationModelConfig &config) | ||
| 183 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 184 | + | ||
| 185 | +Ort::Value OfflineSourceSeparationSpleeterModel::RunVocals(Ort::Value x) const { | ||
| 186 | + return impl_->RunVocals(std::move(x)); | ||
| 187 | +} | ||
| 188 | + | ||
| 189 | +Ort::Value OfflineSourceSeparationSpleeterModel::RunAccompaniment( | ||
| 190 | + Ort::Value x) const { | ||
| 191 | + return impl_->RunAccompaniment(std::move(x)); | ||
| 192 | +} | ||
| 193 | + | ||
| 194 | +const OfflineSourceSeparationSpleeterModelMetaData & | ||
| 195 | +OfflineSourceSeparationSpleeterModel::GetMetaData() const { | ||
| 196 | + return impl_->GetMetaData(); | ||
| 197 | +} | ||
| 198 | + | ||
| 199 | +#if __ANDROID_API__ >= 9 | ||
| 200 | +template OfflineSourceSeparationSpleeterModel:: | ||
| 201 | + OfflineSourceSeparationSpleeterModel( | ||
| 202 | + AAssetManager *mgr, const OfflineSourceSeparationModelConfig &config); | ||
| 203 | +#endif | ||
| 204 | + | ||
| 205 | +#if __OHOS__ | ||
| 206 | +template OfflineSourceSeparationSpleeterModel:: | ||
| 207 | + OfflineSourceSeparationSpleeterModel( | ||
| 208 | + NativeResourceManager *mgr, | ||
| 209 | + const OfflineSourceSeparationModelConfig &config); | ||
| 210 | +#endif | ||
| 211 | + | ||
| 212 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-source-separation-spleeter-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_ | ||
| 6 | +#include <memory> | ||
| 7 | + | ||
| 8 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 9 | +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h" | ||
| 10 | +#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class OfflineSourceSeparationSpleeterModel { | ||
| 15 | + public: | ||
| 16 | + ~OfflineSourceSeparationSpleeterModel(); | ||
| 17 | + | ||
| 18 | + explicit OfflineSourceSeparationSpleeterModel( | ||
| 19 | + const OfflineSourceSeparationModelConfig &config); | ||
| 20 | + | ||
| 21 | + template <typename Manager> | ||
| 22 | + OfflineSourceSeparationSpleeterModel( | ||
| 23 | + Manager *mgr, const OfflineSourceSeparationModelConfig &config); | ||
| 24 | + | ||
| 25 | + Ort::Value RunVocals(Ort::Value x) const; | ||
| 26 | + Ort::Value RunAccompaniment(Ort::Value x) const; | ||
| 27 | + | ||
| 28 | + const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const; | ||
| 29 | + | ||
| 30 | + private: | ||
| 31 | + class Impl; | ||
| 32 | + std::unique_ptr<Impl> impl_; | ||
| 33 | +}; | ||
| 34 | + | ||
| 35 | +} // namespace sherpa_onnx | ||
| 36 | + | ||
| 37 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-source-separation.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-source-separation.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-source-separation-impl.h" | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#if __OHOS__ | ||
| 17 | +#include "rawfile/raw_file_manager.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +void OfflineSourceSeparationConfig::Register(ParseOptions *po) { | ||
| 23 | + model.Register(po); | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +bool OfflineSourceSeparationConfig::Validate() const { | ||
| 27 | + return model.Validate(); | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +std::string OfflineSourceSeparationConfig::ToString() const { | ||
| 31 | + std::ostringstream os; | ||
| 32 | + | ||
| 33 | + os << "OfflineSourceSeparationConfig("; | ||
| 34 | + os << "model=" << model.ToString() << ")"; | ||
| 35 | + | ||
| 36 | + return os.str(); | ||
| 37 | +} | ||
| 38 | + | ||
| 39 | +template <typename Manager> | ||
| 40 | +OfflineSourceSeparation::OfflineSourceSeparation( | ||
| 41 | + Manager *mgr, const OfflineSourceSeparationConfig &config) | ||
| 42 | + : impl_(OfflineSourceSeparationImpl::Create(mgr, config)) {} | ||
| 43 | + | ||
| 44 | +OfflineSourceSeparation::OfflineSourceSeparation( | ||
| 45 | + const OfflineSourceSeparationConfig &config) | ||
| 46 | + : impl_(OfflineSourceSeparationImpl::Create(config)) {} | ||
| 47 | + | ||
| 48 | +OfflineSourceSeparation::~OfflineSourceSeparation() = default; | ||
| 49 | + | ||
| 50 | +OfflineSourceSeparationOutput OfflineSourceSeparation::Process( | ||
| 51 | + const OfflineSourceSeparationInput &input) const { | ||
| 52 | + return impl_->Process(input); | ||
| 53 | +} | ||
| 54 | + | ||
| 55 | +int32_t OfflineSourceSeparation::GetOutputSampleRate() const { | ||
| 56 | + return impl_->GetOutputSampleRate(); | ||
| 57 | +} | ||
| 58 | + | ||
| 59 | +// e.g., it is 2 for 2stems from spleeter | ||
| 60 | +int32_t OfflineSourceSeparation::GetNumberOfStems() const { | ||
| 61 | + return impl_->GetNumberOfStems(); | ||
| 62 | +} | ||
| 63 | + | ||
| 64 | +#if __ANDROID_API__ >= 9 | ||
| 65 | +template OfflineSourceSeparation::OfflineSourceSeparation( | ||
| 66 | + AAssetManager *mgr, const OfflineSourceSeparationConfig &config); | ||
| 67 | +#endif | ||
| 68 | + | ||
| 69 | +#if __OHOS__ | ||
| 70 | +template OfflineSourceSeparation::OfflineSourceSeparation( | ||
| 71 | + NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config); | ||
| 72 | +#endif | ||
| 73 | + | ||
| 74 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-source-separation.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-source-separation.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/offline-source-separation-model-config.h" | ||
| 13 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +struct OfflineSourceSeparationConfig { | ||
| 18 | + OfflineSourceSeparationModelConfig model; | ||
| 19 | + | ||
| 20 | + OfflineSourceSeparationConfig() = default; | ||
| 21 | + | ||
| 22 | + OfflineSourceSeparationConfig(const OfflineSourceSeparationModelConfig &model) | ||
| 23 | + : model(model) {} | ||
| 24 | + | ||
| 25 | + void Register(ParseOptions *po); | ||
| 26 | + | ||
| 27 | + bool Validate() const; | ||
| 28 | + | ||
| 29 | + std::string ToString() const; | ||
| 30 | +}; | ||
| 31 | + | ||
| 32 | +struct MultiChannelSamples { | ||
| 33 | + // data[i] is for the i-th channel | ||
| 34 | + // | ||
| 35 | + // each sample is in the range [-1, 1] | ||
| 36 | + std::vector<std::vector<float>> data; | ||
| 37 | +}; | ||
| 38 | + | ||
| 39 | +struct OfflineSourceSeparationInput { | ||
| 40 | + MultiChannelSamples samples; | ||
| 41 | + | ||
| 42 | + int32_t sample_rate; | ||
| 43 | +}; | ||
| 44 | + | ||
| 45 | +struct OfflineSourceSeparationOutput { | ||
| 46 | + std::vector<MultiChannelSamples> stems; | ||
| 47 | + | ||
| 48 | + int32_t sample_rate; | ||
| 49 | +}; | ||
| 50 | + | ||
| 51 | +class OfflineSourceSeparationImpl; | ||
| 52 | + | ||
| 53 | +class OfflineSourceSeparation { | ||
| 54 | + public: | ||
| 55 | + ~OfflineSourceSeparation(); | ||
| 56 | + | ||
| 57 | + OfflineSourceSeparation(const OfflineSourceSeparationConfig &config); | ||
| 58 | + | ||
| 59 | + template <typename Manager> | ||
| 60 | + OfflineSourceSeparation(Manager *mgr, | ||
| 61 | + const OfflineSourceSeparationConfig &config); | ||
| 62 | + | ||
| 63 | + OfflineSourceSeparationOutput Process( | ||
| 64 | + const OfflineSourceSeparationInput &input) const; | ||
| 65 | + | ||
| 66 | + int32_t GetOutputSampleRate() const; | ||
| 67 | + | ||
| 68 | + // e.g., it is 2 for 2stems from spleeter | ||
| 69 | + int32_t GetNumberOfStems() const; | ||
| 70 | + | ||
| 71 | + private: | ||
| 72 | + std::unique_ptr<OfflineSourceSeparationImpl> impl_; | ||
| 73 | +}; | ||
| 74 | + | ||
| 75 | +} // namespace sherpa_onnx | ||
| 76 | + | ||
| 77 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_ |
| @@ -12,7 +12,7 @@ | @@ -12,7 +12,7 @@ | ||
| 12 | namespace sherpa_onnx { | 12 | namespace sherpa_onnx { |
| 13 | 13 | ||
| 14 | // please refer to | 14 | // please refer to |
| 15 | -// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py | 15 | +// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/gtcrn/add_meta_data.py |
| 16 | struct OfflineSpeechDenoiserGtcrnModelMetaData { | 16 | struct OfflineSpeechDenoiserGtcrnModelMetaData { |
| 17 | int32_t sample_rate = 0; | 17 | int32_t sample_rate = 0; |
| 18 | int32_t version = 1; | 18 | int32_t version = 1; |
| @@ -11,7 +11,7 @@ | @@ -11,7 +11,7 @@ | ||
| 11 | 11 | ||
| 12 | int main(int32_t argc, char *argv[]) { | 12 | int main(int32_t argc, char *argv[]) { |
| 13 | const char *kUsageMessage = R"usage( | 13 | const char *kUsageMessage = R"usage( |
| 14 | -Non-stremaing speech denoising with sherpa-onnx. | 14 | +Non-streaming speech denoising with sherpa-onnx. |
| 15 | 15 | ||
| 16 | Please visit | 16 | Please visit |
| 17 | https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models | 17 | https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models |
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#include <stdio.h> | ||
| 5 | + | ||
| 6 | +#include <chrono> // NOLINT | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-source-separation.h" | ||
| 10 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 11 | +#include "sherpa-onnx/csrc/wave-writer.h" | ||
| 12 | + | ||
| 13 | +int main(int32_t argc, char *argv[]) { | ||
| 14 | + const char *kUsageMessage = R"usage( | ||
| 15 | +Non-streaming source separation with sherpa-onnx. | ||
| 16 | + | ||
| 17 | +Please visit | ||
| 18 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models | ||
| 19 | +to download models. | ||
| 20 | + | ||
| 21 | +Usage: | ||
| 22 | + | ||
| 23 | +(1) Use spleeter models | ||
| 24 | + | ||
| 25 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2 | ||
| 26 | +tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2 | ||
| 27 | + | ||
| 28 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav | ||
| 29 | + | ||
| 30 | +./bin/sherpa-onnx-offline-source-separation \ | ||
| 31 | + --spleeter-vocals=sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx \ | ||
| 32 | + --spleeter-accompaniment=sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx \ | ||
| 33 | + --input-wav=audio_example.wav \ | ||
| 34 | + --output-vocals-wav=output_vocals.wav \ | ||
| 35 | + --output-accompaniment-wav=output_accompaniment.wav | ||
| 36 | +)usage"; | ||
| 37 | + | ||
| 38 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 39 | + sherpa_onnx::OfflineSourceSeparationConfig config; | ||
| 40 | + | ||
| 41 | + std::string input_wave; | ||
| 42 | + std::string output_vocals_wave; | ||
| 43 | + std::string output_accompaniment_wave; | ||
| 44 | + | ||
| 45 | + config.Register(&po); | ||
| 46 | + po.Register("input-wav", &input_wave, "Path to input wav."); | ||
| 47 | + po.Register("output-vocals-wav", &output_vocals_wave, | ||
| 48 | + "Path to output vocals wav"); | ||
| 49 | + po.Register("output-accompaniment-wav", &output_accompaniment_wave, | ||
| 50 | + "Path to output accompaniment wav"); | ||
| 51 | + | ||
| 52 | + po.Read(argc, argv); | ||
| 53 | + if (po.NumArgs() != 0) { | ||
| 54 | + fprintf(stderr, "Please don't give positional arguments\n"); | ||
| 55 | + po.PrintUsage(); | ||
| 56 | + exit(EXIT_FAILURE); | ||
| 57 | + } | ||
| 58 | + fprintf(stderr, "%s\n", config.ToString().c_str()); | ||
| 59 | + | ||
| 60 | + if (input_wave.empty()) { | ||
| 61 | + fprintf(stderr, "Please provide --input-wav\n"); | ||
| 62 | + po.PrintUsage(); | ||
| 63 | + exit(EXIT_FAILURE); | ||
| 64 | + } | ||
| 65 | + | ||
| 66 | + if (output_vocals_wave.empty()) { | ||
| 67 | + fprintf(stderr, "Please provide --output-vocals-wav\n"); | ||
| 68 | + po.PrintUsage(); | ||
| 69 | + exit(EXIT_FAILURE); | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + if (output_accompaniment_wave.empty()) { | ||
| 73 | + fprintf(stderr, "Please provide --output-accompaniment-wav\n"); | ||
| 74 | + po.PrintUsage(); | ||
| 75 | + exit(EXIT_FAILURE); | ||
| 76 | + } | ||
| 77 | + | ||
| 78 | + if (!config.Validate()) { | ||
| 79 | + fprintf(stderr, "Errors in config!\n"); | ||
| 80 | + exit(EXIT_FAILURE); | ||
| 81 | + } | ||
| 82 | + | ||
| 83 | + bool is_ok = false; | ||
| 84 | + sherpa_onnx::OfflineSourceSeparationInput input; | ||
| 85 | + input.samples.data = | ||
| 86 | + sherpa_onnx::ReadWaveMultiChannel(input_wave, &input.sample_rate, &is_ok); | ||
| 87 | + if (!is_ok) { | ||
| 88 | + fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str()); | ||
| 89 | + return -1; | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + fprintf(stderr, "Started\n"); | ||
| 93 | + | ||
| 94 | + sherpa_onnx::OfflineSourceSeparation sp(config); | ||
| 95 | + | ||
| 96 | + const auto begin = std::chrono::steady_clock::now(); | ||
| 97 | + auto output = sp.Process(input); | ||
| 98 | + const auto end = std::chrono::steady_clock::now(); | ||
| 99 | + | ||
| 100 | + float elapsed_seconds = | ||
| 101 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 102 | + .count() / | ||
| 103 | + 1000.; | ||
| 104 | + | ||
| 105 | + is_ok = sherpa_onnx::WriteWave( | ||
| 106 | + output_vocals_wave, output.sample_rate, output.stems[0].data[0].data(), | ||
| 107 | + output.stems[0].data[1].data(), output.stems[0].data[0].size()); | ||
| 108 | + | ||
| 109 | + if (!is_ok) { | ||
| 110 | + fprintf(stderr, "Failed to write to '%s'\n", output_vocals_wave.c_str()); | ||
| 111 | + exit(EXIT_FAILURE); | ||
| 112 | + } | ||
| 113 | + | ||
| 114 | + is_ok = sherpa_onnx::WriteWave(output_accompaniment_wave, output.sample_rate, | ||
| 115 | + output.stems[1].data[0].data(), | ||
| 116 | + output.stems[1].data[1].data(), | ||
| 117 | + output.stems[1].data[0].size()); | ||
| 118 | + | ||
| 119 | + if (!is_ok) { | ||
| 120 | + fprintf(stderr, "Failed to write to '%s'\n", | ||
| 121 | + output_accompaniment_wave.c_str()); | ||
| 122 | + exit(EXIT_FAILURE); | ||
| 123 | + } | ||
| 124 | + | ||
| 125 | + fprintf(stderr, "Done\n"); | ||
| 126 | + fprintf(stderr, "Saved to write to '%s' and '%s'\n", | ||
| 127 | + output_vocals_wave.c_str(), output_accompaniment_wave.c_str()); | ||
| 128 | + | ||
| 129 | + float duration = | ||
| 130 | + input.samples.data[0].size() / static_cast<float>(input.sample_rate); | ||
| 131 | + fprintf(stderr, "num threads: %d\n", config.model.num_threads); | ||
| 132 | + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 133 | + float rtf = elapsed_seconds / duration; | ||
| 134 | + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||
| 135 | + elapsed_seconds, duration, rtf); | ||
| 136 | + | ||
| 137 | + return 0; | ||
| 138 | +} |
| @@ -63,8 +63,9 @@ in sherpa-onnx. | @@ -63,8 +63,9 @@ in sherpa-onnx. | ||
| 63 | 63 | ||
| 64 | // Read a wave file of mono-channel. | 64 | // Read a wave file of mono-channel. |
| 65 | // Return its samples normalized to the range [-1, 1). | 65 | // Return its samples normalized to the range [-1, 1). |
| 66 | -std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | ||
| 67 | - bool *is_ok) { | 66 | +std::vector<std::vector<float>> ReadWaveImpl(std::istream &is, |
| 67 | + int32_t *sampling_rate, | ||
| 68 | + bool *is_ok) { | ||
| 68 | WaveHeader header{}; | 69 | WaveHeader header{}; |
| 69 | is.read(reinterpret_cast<char *>(&header.chunk_id), sizeof(header.chunk_id)); | 70 | is.read(reinterpret_cast<char *>(&header.chunk_id), sizeof(header.chunk_id)); |
| 70 | 71 | ||
| @@ -144,12 +145,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | @@ -144,12 +145,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | ||
| 144 | is.read(reinterpret_cast<char *>(&header.num_channels), | 145 | is.read(reinterpret_cast<char *>(&header.num_channels), |
| 145 | sizeof(header.num_channels)); | 146 | sizeof(header.num_channels)); |
| 146 | 147 | ||
| 147 | - if (header.num_channels != 1) { // we support only single channel for now | ||
| 148 | - SHERPA_ONNX_LOGE( | ||
| 149 | - "Warning: %d channels are found. We only use the first channel.\n", | ||
| 150 | - header.num_channels); | ||
| 151 | - } | ||
| 152 | - | ||
| 153 | is.read(reinterpret_cast<char *>(&header.sample_rate), | 148 | is.read(reinterpret_cast<char *>(&header.sample_rate), |
| 154 | sizeof(header.sample_rate)); | 149 | sizeof(header.sample_rate)); |
| 155 | 150 | ||
| @@ -219,7 +214,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | @@ -219,7 +214,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | ||
| 219 | 214 | ||
| 220 | *sampling_rate = header.sample_rate; | 215 | *sampling_rate = header.sample_rate; |
| 221 | 216 | ||
| 222 | - std::vector<float> ans; | 217 | + std::vector<std::vector<float>> ans(header.num_channels); |
| 223 | 218 | ||
| 224 | if (header.bits_per_sample == 16 && header.audio_format == 1) { | 219 | if (header.bits_per_sample == 16 && header.audio_format == 1) { |
| 225 | // header.subchunk2_size contains the number of bytes in the data. | 220 | // header.subchunk2_size contains the number of bytes in the data. |
| @@ -233,11 +228,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | @@ -233,11 +228,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | ||
| 233 | return {}; | 228 | return {}; |
| 234 | } | 229 | } |
| 235 | 230 | ||
| 236 | - ans.resize(samples.size() / header.num_channels); | 231 | + for (auto &v : ans) { |
| 232 | + v.resize(samples.size() / header.num_channels); | ||
| 233 | + } | ||
| 237 | 234 | ||
| 238 | // samples are interleaved | 235 | // samples are interleaved |
| 239 | - for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) { | ||
| 240 | - ans[i] = samples[i * header.num_channels] / 32768.; | 236 | + for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size()); |
| 237 | + i += header.num_channels, ++k) { | ||
| 238 | + for (int32_t c = 0; c != header.num_channels; ++c) { | ||
| 239 | + ans[c][k] = samples[i + c] / 32768.; | ||
| 240 | + } | ||
| 241 | } | 241 | } |
| 242 | } else if (header.bits_per_sample == 8 && header.audio_format == 1) { | 242 | } else if (header.bits_per_sample == 8 && header.audio_format == 1) { |
| 243 | // number of samples == number of bytes for 8-bit encoded samples | 243 | // number of samples == number of bytes for 8-bit encoded samples |
| @@ -252,14 +252,21 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | @@ -252,14 +252,21 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | ||
| 252 | return {}; | 252 | return {}; |
| 253 | } | 253 | } |
| 254 | 254 | ||
| 255 | - ans.resize(samples.size() / header.num_channels); | ||
| 256 | - for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) { | ||
| 257 | - // Note(fangjun): We want to normalize each sample into the range [-1, 1] | ||
| 258 | - // Since each original sample is in the range [0, 256], dividing | ||
| 259 | - // them by 128 converts them to the range [0, 2]; | ||
| 260 | - // so after subtracting 1, we get the range [-1, 1] | ||
| 261 | - // | ||
| 262 | - ans[i] = samples[i * header.num_channels] / 128. - 1; | 255 | + for (auto &v : ans) { |
| 256 | + v.resize(samples.size() / header.num_channels); | ||
| 257 | + } | ||
| 258 | + | ||
| 259 | + // samples are interleaved | ||
| 260 | + for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size()); | ||
| 261 | + i += header.num_channels, ++k) { | ||
| 262 | + for (int32_t c = 0; c != header.num_channels; ++c) { | ||
| 263 | + // Note(fangjun): We want to normalize each sample into the range [-1, | ||
| 264 | + // 1] Since each original sample is in the range [0, 256], dividing them | ||
| 265 | + // by 128 converts them to the range [0, 2]; so after subtracting 1, we | ||
| 266 | + // get the range [-1, 1] | ||
| 267 | + // | ||
| 268 | + ans[c][k] = samples[i + c] / 128. - 1; | ||
| 269 | + } | ||
| 263 | } | 270 | } |
| 264 | } else if (header.bits_per_sample == 32 && header.audio_format == 1) { | 271 | } else if (header.bits_per_sample == 32 && header.audio_format == 1) { |
| 265 | // 32 here is for int32 | 272 | // 32 here is for int32 |
| @@ -275,9 +282,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | @@ -275,9 +282,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | ||
| 275 | return {}; | 282 | return {}; |
| 276 | } | 283 | } |
| 277 | 284 | ||
| 278 | - ans.resize(samples.size() / header.num_channels); | ||
| 279 | - for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) { | ||
| 280 | - ans[i] = static_cast<float>(samples[i * header.num_channels]) / (1 << 31); | 285 | + for (auto &v : ans) { |
| 286 | + v.resize(samples.size() / header.num_channels); | ||
| 287 | + } | ||
| 288 | + | ||
| 289 | + // samples are interleaved | ||
| 290 | + for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size()); | ||
| 291 | + i += header.num_channels, ++k) { | ||
| 292 | + for (int32_t c = 0; c != header.num_channels; ++c) { | ||
| 293 | + ans[c][k] = static_cast<float>(samples[i + c]) / (1 << 31); | ||
| 294 | + } | ||
| 281 | } | 295 | } |
| 282 | } else if (header.bits_per_sample == 32 && header.audio_format == 3) { | 296 | } else if (header.bits_per_sample == 32 && header.audio_format == 3) { |
| 283 | // 32 here is for float32 | 297 | // 32 here is for float32 |
| @@ -293,9 +307,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | @@ -293,9 +307,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate, | ||
| 293 | return {}; | 307 | return {}; |
| 294 | } | 308 | } |
| 295 | 309 | ||
| 296 | - ans.resize(samples.size() / header.num_channels); | ||
| 297 | - for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) { | ||
| 298 | - ans[i] = samples[i * header.num_channels]; | 310 | + for (auto &v : ans) { |
| 311 | + v.resize(samples.size() / header.num_channels); | ||
| 312 | + } | ||
| 313 | + | ||
| 314 | + // samples are interleaved | ||
| 315 | + for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size()); | ||
| 316 | + i += header.num_channels, ++k) { | ||
| 317 | + for (int32_t c = 0; c != header.num_channels; ++c) { | ||
| 318 | + ans[c][k] = samples[i + c]; | ||
| 319 | + } | ||
| 299 | } | 320 | } |
| 300 | } else { | 321 | } else { |
| 301 | SHERPA_ONNX_LOGE( | 322 | SHERPA_ONNX_LOGE( |
| @@ -321,7 +342,27 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate, | @@ -321,7 +342,27 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate, | ||
| 321 | std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate, | 342 | std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate, |
| 322 | bool *is_ok) { | 343 | bool *is_ok) { |
| 323 | auto samples = ReadWaveImpl(is, sampling_rate, is_ok); | 344 | auto samples = ReadWaveImpl(is, sampling_rate, is_ok); |
| 345 | + | ||
| 346 | + if (samples.size() > 1) { | ||
| 347 | + SHERPA_ONNX_LOGE( | ||
| 348 | + "Warning: %d channels are found. We only use the first channel.\n", | ||
| 349 | + static_cast<int32_t>(samples.size())); | ||
| 350 | + } | ||
| 351 | + | ||
| 352 | + return samples[0]; | ||
| 353 | +} | ||
| 354 | + | ||
| 355 | +std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is, | ||
| 356 | + int32_t *sampling_rate, | ||
| 357 | + bool *is_ok) { | ||
| 358 | + auto samples = ReadWaveImpl(is, sampling_rate, is_ok); | ||
| 324 | return samples; | 359 | return samples; |
| 325 | } | 360 | } |
| 326 | 361 | ||
| 362 | +std::vector<std::vector<float>> ReadWaveMultiChannel( | ||
| 363 | + const std::string &filename, int32_t *sampling_rate, bool *is_ok) { | ||
| 364 | + std::ifstream is(filename, std::ifstream::binary); | ||
| 365 | + return ReadWaveMultiChannel(is, sampling_rate, is_ok); | ||
| 366 | +} | ||
| 367 | + | ||
| 327 | } // namespace sherpa_onnx | 368 | } // namespace sherpa_onnx |
| @@ -26,6 +26,13 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate, | @@ -26,6 +26,13 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate, | ||
| 26 | std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate, | 26 | std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate, |
| 27 | bool *is_ok); | 27 | bool *is_ok); |
| 28 | 28 | ||
| 29 | +std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is, | ||
| 30 | + int32_t *sampling_rate, | ||
| 31 | + bool *is_ok); | ||
| 32 | + | ||
| 33 | +std::vector<std::vector<float>> ReadWaveMultiChannel( | ||
| 34 | + const std::string &filename, int32_t *sampling_rate, bool *is_ok); | ||
| 35 | + | ||
| 29 | } // namespace sherpa_onnx | 36 | } // namespace sherpa_onnx |
| 30 | 37 | ||
| 31 | #endif // SHERPA_ONNX_CSRC_WAVE_READER_H_ | 38 | #endif // SHERPA_ONNX_CSRC_WAVE_READER_H_ |
| @@ -4,6 +4,7 @@ | @@ -4,6 +4,7 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/wave-writer.h" | 5 | #include "sherpa-onnx/csrc/wave-writer.h" |
| 6 | 6 | ||
| 7 | +#include <algorithm> | ||
| 7 | #include <cstring> | 8 | #include <cstring> |
| 8 | #include <fstream> | 9 | #include <fstream> |
| 9 | #include <string> | 10 | #include <string> |
| @@ -36,12 +37,44 @@ struct WaveHeader { | @@ -36,12 +37,44 @@ struct WaveHeader { | ||
| 36 | 37 | ||
| 37 | } // namespace | 38 | } // namespace |
| 38 | 39 | ||
| 39 | -int64_t WaveFileSize(int32_t n_samples) { | ||
| 40 | - return sizeof(WaveHeader) + n_samples * sizeof(int16_t); | 40 | +int64_t WaveFileSize(int32_t n_samples, int32_t num_channels /*= 1*/) { |
| 41 | + return sizeof(WaveHeader) + n_samples * sizeof(int16_t) * num_channels; | ||
| 41 | } | 42 | } |
| 42 | 43 | ||
| 43 | void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, | 44 | void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, |
| 44 | int32_t n) { | 45 | int32_t n) { |
| 46 | + WriteWave(buffer, sampling_rate, samples, nullptr, n); | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +bool WriteWave(const std::string &filename, int32_t sampling_rate, | ||
| 50 | + const float *samples, int32_t n) { | ||
| 51 | + return WriteWave(filename, sampling_rate, samples, nullptr, n); | ||
| 52 | +} | ||
| 53 | + | ||
| 54 | +bool WriteWave(const std::string &filename, int32_t sampling_rate, | ||
| 55 | + const float *samples_ch0, const float *samples_ch1, int32_t n) { | ||
| 56 | + std::string buffer; | ||
| 57 | + buffer.resize(WaveFileSize(n, samples_ch1 == nullptr ? 1 : 2)); | ||
| 58 | + | ||
| 59 | + WriteWave(buffer.data(), sampling_rate, samples_ch0, samples_ch1, n); | ||
| 60 | + | ||
| 61 | + std::ofstream os(filename, std::ios::binary); | ||
| 62 | + if (!os) { | ||
| 63 | + SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str()); | ||
| 64 | + return false; | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + os << buffer; | ||
| 68 | + if (!os) { | ||
| 69 | + SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str()); | ||
| 70 | + return false; | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + return true; | ||
| 74 | +} | ||
| 75 | + | ||
| 76 | +void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0, | ||
| 77 | + const float *samples_ch1, int32_t n) { | ||
| 45 | WaveHeader header{}; | 78 | WaveHeader header{}; |
| 46 | header.chunk_id = 0x46464952; // FFIR | 79 | header.chunk_id = 0x46464952; // FFIR |
| 47 | header.format = 0x45564157; // EVAW | 80 | header.format = 0x45564157; // EVAW |
| @@ -49,8 +82,9 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, | @@ -49,8 +82,9 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, | ||
| 49 | header.subchunk1_size = 16; // 16 for PCM | 82 | header.subchunk1_size = 16; // 16 for PCM |
| 50 | header.audio_format = 1; // PCM =1 | 83 | header.audio_format = 1; // PCM =1 |
| 51 | 84 | ||
| 52 | - int32_t num_channels = 1; | 85 | + int32_t num_channels = samples_ch1 == nullptr ? 1 : 2; |
| 53 | int32_t bits_per_sample = 16; // int16_t | 86 | int32_t bits_per_sample = 16; // int16_t |
| 87 | + | ||
| 54 | header.num_channels = num_channels; | 88 | header.num_channels = num_channels; |
| 55 | header.sample_rate = sampling_rate; | 89 | header.sample_rate = sampling_rate; |
| 56 | header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8; | 90 | header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8; |
| @@ -61,32 +95,32 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, | @@ -61,32 +95,32 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, | ||
| 61 | 95 | ||
| 62 | header.chunk_size = 36 + header.subchunk2_size; | 96 | header.chunk_size = 36 + header.subchunk2_size; |
| 63 | 97 | ||
| 64 | - std::vector<int16_t> samples_int16(n); | 98 | + std::vector<int16_t> samples_int16_ch0(n); |
| 65 | for (int32_t i = 0; i != n; ++i) { | 99 | for (int32_t i = 0; i != n; ++i) { |
| 66 | - samples_int16[i] = samples[i] * 32767; | 100 | + samples_int16_ch0[i] = std::min<int32_t>(samples_ch0[i] * 32767, 32767); |
| 101 | + } | ||
| 102 | + | ||
| 103 | + std::vector<int16_t> samples_int16_ch1; | ||
| 104 | + if (samples_ch1) { | ||
| 105 | + samples_int16_ch1.resize(n); | ||
| 106 | + for (int32_t i = 0; i != n; ++i) { | ||
| 107 | + samples_int16_ch1[i] = std::min<int32_t>(samples_ch1[i] * 32767, 32767); | ||
| 108 | + } | ||
| 67 | } | 109 | } |
| 68 | 110 | ||
| 69 | memcpy(buffer, &header, sizeof(WaveHeader)); | 111 | memcpy(buffer, &header, sizeof(WaveHeader)); |
| 70 | - memcpy(buffer + sizeof(WaveHeader), samples_int16.data(), | ||
| 71 | - n * sizeof(int16_t)); | ||
| 72 | -} | ||
| 73 | 112 | ||
| 74 | -bool WriteWave(const std::string &filename, int32_t sampling_rate, | ||
| 75 | - const float *samples, int32_t n) { | ||
| 76 | - std::string buffer; | ||
| 77 | - buffer.resize(WaveFileSize(n)); | ||
| 78 | - WriteWave(buffer.data(), sampling_rate, samples, n); | ||
| 79 | - std::ofstream os(filename, std::ios::binary); | ||
| 80 | - if (!os) { | ||
| 81 | - SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str()); | ||
| 82 | - return false; | ||
| 83 | - } | ||
| 84 | - os << buffer; | ||
| 85 | - if (!os) { | ||
| 86 | - SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str()); | ||
| 87 | - return false; | 113 | + if (samples_ch1 == nullptr) { |
| 114 | + memcpy(buffer + sizeof(WaveHeader), samples_int16_ch0.data(), | ||
| 115 | + n * sizeof(int16_t)); | ||
| 116 | + } else { | ||
| 117 | + auto p = reinterpret_cast<int16_t *>(buffer + sizeof(WaveHeader)); | ||
| 118 | + | ||
| 119 | + for (int32_t i = 0; i != n; ++i) { | ||
| 120 | + p[2 * i] = samples_int16_ch0[i]; | ||
| 121 | + p[2 * i + 1] = samples_int16_ch1[i]; | ||
| 122 | + } | ||
| 88 | } | 123 | } |
| 89 | - return true; | ||
| 90 | } | 124 | } |
| 91 | 125 | ||
| 92 | } // namespace sherpa_onnx | 126 | } // namespace sherpa_onnx |
| @@ -25,7 +25,13 @@ bool WriteWave(const std::string &filename, int32_t sampling_rate, | @@ -25,7 +25,13 @@ bool WriteWave(const std::string &filename, int32_t sampling_rate, | ||
| 25 | void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, | 25 | void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, |
| 26 | int32_t n); | 26 | int32_t n); |
| 27 | 27 | ||
| 28 | -int64_t WaveFileSize(int32_t n_samples); | 28 | +bool WriteWave(const std::string &filename, int32_t sampling_rate, |
| 29 | + const float *samples_ch0, const float *samples_ch1, int32_t n); | ||
| 30 | + | ||
| 31 | +void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0, | ||
| 32 | + const float *samples_ch1, int32_t n); | ||
| 33 | + | ||
| 34 | +int64_t WaveFileSize(int32_t n_samples, int32_t num_channels = 1); | ||
| 29 | 35 | ||
| 30 | } // namespace sherpa_onnx | 36 | } // namespace sherpa_onnx |
| 31 | 37 |
| @@ -77,7 +77,7 @@ fileInput.addEventListener('change', function(event) { | @@ -77,7 +77,7 @@ fileInput.addEventListener('change', function(event) { | ||
| 77 | console.log('ArrayBuffer length:', arrayBuffer.byteLength); | 77 | console.log('ArrayBuffer length:', arrayBuffer.byteLength); |
| 78 | 78 | ||
| 79 | const uint8Array = new Uint8Array(arrayBuffer); | 79 | const uint8Array = new Uint8Array(arrayBuffer); |
| 80 | - const wave = readWaveFromBinaryData(uint8Array); | 80 | + const wave = readWaveFromBinaryData(uint8Array, Module); |
| 81 | if (wave == null) { | 81 | if (wave == null) { |
| 82 | alert( | 82 | alert( |
| 83 | `${file.name} is not a valid .wav file. Please select a *.wav file`); | 83 | `${file.name} is not a valid .wav file. Please select a *.wav file`); |
-
请 注册 或 登录 后发表评论