Fangjun Kuang
Committed by GitHub

Add C++ runtime for spleeter-based source separation (#2242)

... ... @@ -3,7 +3,7 @@ name: export-spleeter-to-onnx
on:
push:
branches:
- spleeter-2
- spleeter-cpp-2
workflow_dispatch:
concurrency:
... ...
... ... @@ -56,6 +56,7 @@ def get_binaries():
"sherpa-onnx-offline-denoiser",
"sherpa-onnx-offline-language-identification",
"sherpa-onnx-offline-punctuation",
"sherpa-onnx-offline-source-separation",
"sherpa-onnx-offline-speaker-diarization",
"sherpa-onnx-offline-tts",
"sherpa-onnx-offline-tts-play",
... ...
... ... @@ -217,8 +217,8 @@ def main(name):
# for the batchnormalization in torch,
# default input shape is NCHW
# NHWC to NCHW
torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))
torch_y1_out = unet(torch.from_numpy(y0_out).permute(3, 0, 1, 2))
torch_y1_out = torch_y1_out.permute(1, 0, 2, 3)
# print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
assert torch.allclose(
... ...
... ... @@ -46,7 +46,7 @@ def add_meta_data(filename, prefix):
def export(model, prefix):
num_splits = 1
x = torch.rand(num_splits, 2, 512, 1024, dtype=torch.float32)
x = torch.rand(2, num_splits, 512, 1024, dtype=torch.float32)
filename = f"./2stems/{prefix}.onnx"
torch.onnx.export(
... ... @@ -56,7 +56,7 @@ def export(model, prefix):
input_names=["x"],
output_names=["y"],
dynamic_axes={
"x": {0: "num_splits"},
"x": {1: "num_splits"},
},
opset_version=13,
)
... ...
... ... @@ -101,13 +101,17 @@ def main():
print("y2", y.shape, y.dtype)
y = y.abs()
y = y.permute(0, 3, 1, 2)
# (1, 2, 512, 1024)
y = y.permute(3, 0, 1, 2)
# (2, 1, 512, 1024)
print("y3", y.shape, y.dtype)
vocals_spec = vocals(y)
accompaniment_spec = accompaniment(y)
vocals_spec = vocals_spec.permute(1, 0, 2, 3)
accompaniment_spec = accompaniment_spec.permute(1, 0, 2, 3)
sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
print(
"vocals_spec",
... ...
... ... @@ -12,15 +12,14 @@ from separate import load_audio
"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=['num_splits', 2, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[2, 'num_splits', 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
"""
... ... @@ -123,16 +122,16 @@ def main():
if padding > 0:
stft0 = torch.nn.functional.pad(stft0, (0, 0, 0, padding))
stft1 = torch.nn.functional.pad(stft1, (0, 0, 0, padding))
stft0 = stft0.reshape(-1, 1, 512, 1024)
stft1 = stft1.reshape(-1, 1, 512, 1024)
stft0 = stft0.reshape(1, -1, 512, 1024)
stft1 = stft1.reshape(1, -1, 512, 1024)
stft_01 = torch.cat([stft0, stft1], axis=1)
stft_01 = torch.cat([stft0, stft1], axis=0)
print("stft_01", stft_01.shape, stft_01.dtype)
vocals_spec = vocals(stft_01)
accompaniment_spec = accompaniment(stft_01)
# (num_splits, num_channels, 512, 1024)
# (num_channels, num_splits, 512, 1024)
sum_spec = (vocals_spec.square() + accompaniment_spec.square()) + 1e-10
... ... @@ -142,8 +141,8 @@ def main():
for name, spec in zip(
["vocals", "accompaniment"], [vocals_spec, accompaniment_spec]
):
spec_c0 = spec[:, 0, :, :]
spec_c1 = spec[:, 1, :, :]
spec_c0 = spec[0]
spec_c1 = spec[1]
spec_c0 = spec_c0.reshape(-1, 1024)
spec_c1 = spec_c1.reshape(-1, 1024)
... ...
... ... @@ -67,6 +67,14 @@ class UNet(torch.nn.Module):
self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
def forward(self, x):
"""
Args:
x: (num_audio_channels, num_splits, 512, 1024)
Returns:
y: (num_audio_channels, num_splits, 512, 1024)
"""
x = x.permute(1, 0, 2, 3)
in_x = x
# in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
... ... @@ -147,4 +155,5 @@ class UNet(torch.nn.Module):
up7 = self.up7(batch12)
up7 = torch.sigmoid(up7) # (3, 2, 512, 1024)
return up7 * in_x
ans = up7 * in_x
return ans.permute(1, 0, 2, 3)
... ...
... ... @@ -50,6 +50,13 @@ set(sources
offline-rnn-lm.cc
offline-sense-voice-model-config.cc
offline-sense-voice-model.cc
offline-source-separation-impl.cc
offline-source-separation-model-config.cc
offline-source-separation-spleeter-model-config.cc
offline-source-separation-spleeter-model.cc
offline-source-separation.cc
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
... ... @@ -326,6 +333,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
add_executable(sherpa-onnx-offline-source-separation sherpa-onnx-offline-source-separation.cc)
add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
add_executable(sherpa-onnx-vad sherpa-onnx-vad.cc)
... ... @@ -346,6 +354,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline-language-identification
sherpa-onnx-offline-parallel
sherpa-onnx-offline-punctuation
sherpa-onnx-offline-source-separation
sherpa-onnx-online-punctuation
sherpa-onnx-vad
)
... ...
// sherpa-onnx/csrc/offline-source-separation-impl.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation-impl.h"

#include <memory>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h"

namespace sherpa_onnx {

// Factory used on desktop/server platforms where model files are read
// directly from the filesystem.
std::unique_ptr<OfflineSourceSeparationImpl>
OfflineSourceSeparationImpl::Create(
    const OfflineSourceSeparationConfig &config) {
  // TODO(fangjun): Support other models
  return std::make_unique<OfflineSourceSeparationSpleeterImpl>(config);
}

// Factory used on Android (AAssetManager) and HarmonyOS
// (NativeResourceManager), where model files are read via a resource
// manager instead of the filesystem.
template <typename Manager>
std::unique_ptr<OfflineSourceSeparationImpl>
OfflineSourceSeparationImpl::Create(
    Manager *mgr, const OfflineSourceSeparationConfig &config) {
  // TODO(fangjun): Support other models
  return std::make_unique<OfflineSourceSeparationSpleeterImpl>(mgr, config);
}

// Explicit instantiations. The guarded includes above provide the
// declarations of AAssetManager / NativeResourceManager; previously this
// file relied on them being pulled in transitively.
#if __ANDROID_API__ >= 9
template std::unique_ptr<OfflineSourceSeparationImpl>
OfflineSourceSeparationImpl::Create(
    AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
#endif

#if __OHOS__
template std::unique_ptr<OfflineSourceSeparationImpl>
OfflineSourceSeparationImpl::Create(
    NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
#endif

}  // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-source-separation-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_

#include <memory>  // for std::unique_ptr, previously only available transitively
#include <vector>

#include "sherpa-onnx/csrc/offline-source-separation.h"

namespace sherpa_onnx {

// Abstract base class for offline (non-streaming) source-separation
// implementations. Concrete models (e.g., spleeter) derive from it and are
// selected by the Create() factories.
class OfflineSourceSeparationImpl {
 public:
  // Create an implementation that loads model files from the filesystem.
  static std::unique_ptr<OfflineSourceSeparationImpl> Create(
      const OfflineSourceSeparationConfig &config);

  // Create an implementation that loads model files via a platform resource
  // manager (Android AAssetManager or HarmonyOS NativeResourceManager).
  template <typename Manager>
  static std::unique_ptr<OfflineSourceSeparationImpl> Create(
      Manager *mgr, const OfflineSourceSeparationConfig &config);

  virtual ~OfflineSourceSeparationImpl() = default;

  // Separate the given multi-channel audio into stems.
  virtual OfflineSourceSeparationOutput Process(
      const OfflineSourceSeparationInput &input) const = 0;

  // Sample rate of the separated output (and the model's expected input).
  virtual int32_t GetOutputSampleRate() const = 0;

  // Number of output stems, e.g., 2 for spleeter 2stems.
  virtual int32_t GetNumberOfStems() const = 0;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
... ...
// sherpa-onnx/csrc/offline-source-separation-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"

#include <sstream>  // for std::ostringstream, previously only available transitively
#include <string>

namespace sherpa_onnx {

// Register command-line options for the model-level config and for every
// supported model type (currently only spleeter).
void OfflineSourceSeparationModelConfig::Register(ParseOptions *po) {
  spleeter.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}

// Delegates to the active model's validation.
bool OfflineSourceSeparationModelConfig::Validate() const {
  return spleeter.Validate();
}

// Human-readable dump used for logging the effective configuration.
std::string OfflineSourceSeparationModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineSourceSeparationModelConfig(";
  os << "spleeter=" << spleeter.ToString() << ", ";
  os << "num_threads=" << num_threads << ", ";
  os << "debug=" << (debug ? "True" : "False") << ", ";
  os << "provider=\"" << provider << "\")";

  return os.str();
}

}  // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-source-separation-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// Model-level configuration shared by all source-separation models.
// Exactly one of the per-model configs (currently only `spleeter`) is
// expected to be filled in.
struct OfflineSourceSeparationModelConfig {
  // Config for the spleeter 2stems model (vocals + accompaniment).
  OfflineSourceSeparationSpleeterModelConfig spleeter;

  // Number of threads used by onnxruntime for inference.
  int32_t num_threads = 1;

  // If true, print model information while loading.
  bool debug = false;

  // onnxruntime execution provider: cpu, cuda, coreml, etc.
  std::string provider = "cpu";

  OfflineSourceSeparationModelConfig() = default;

  OfflineSourceSeparationModelConfig(
      const OfflineSourceSeparationSpleeterModelConfig &spleeter,
      int32_t num_threads, bool debug, const std::string &provider)
      : spleeter(spleeter),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  void Register(ParseOptions *po);
  bool Validate() const;
  std::string ToString() const;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_

#include <algorithm>
#include <array>
#include <memory>
#include <utility>
#include <vector>

#include "Eigen/Dense"
#include "kaldi-native-fbank/csrc/istft.h"
#include "kaldi-native-fbank/csrc/stft.h"
#include "sherpa-onnx/csrc/macros.h"
// Include the base class explicitly; previously this header relied on
// offline-source-separation-impl.h being included first by the .cc file.
#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"
#include "sherpa-onnx/csrc/offline-source-separation.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/resample.h"

namespace sherpa_onnx {

// Spleeter 2stems implementation: separates a (possibly stereo) recording
// into vocals and accompaniment.
//
// Pipeline: resample to the model rate if needed -> STFT per channel ->
// pad the frame count to a multiple of 512 -> build a magnitude tensor of
// shape (2, num_splits, 512, 1024) -> run the vocals and accompaniment
// UNets -> Wiener-style soft masking -> inverse STFT per stem/channel.
class OfflineSourceSeparationSpleeterImpl : public OfflineSourceSeparationImpl {
 public:
  explicit OfflineSourceSeparationSpleeterImpl(
      const OfflineSourceSeparationConfig &config)
      : config_(config), model_(config_.model) {}

  template <typename Manager>
  OfflineSourceSeparationSpleeterImpl(
      Manager *mgr, const OfflineSourceSeparationConfig &config)
      : config_(config), model_(mgr, config_.model) {}

  OfflineSourceSeparationOutput Process(
      const OfflineSourceSeparationInput &input) const override {
    const OfflineSourceSeparationInput *p_input = &input;
    OfflineSourceSeparationInput tmp_input;

    // Resample every channel to the model's expected sample rate if needed.
    int32_t output_sample_rate = GetOutputSampleRate();
    if (input.sample_rate != output_sample_rate) {
      SHERPA_ONNX_LOGE(
          "Creating a resampler:\n"
          "   in_sample_rate: %d\n"
          "   output_sample_rate: %d\n",
          input.sample_rate, output_sample_rate);

      float min_freq = std::min<int32_t>(input.sample_rate, output_sample_rate);
      float lowpass_cutoff = 0.99 * 0.5 * min_freq;
      int32_t lowpass_filter_width = 6;
      auto resampler = std::make_unique<LinearResample>(
          input.sample_rate, output_sample_rate, lowpass_cutoff,
          lowpass_filter_width);

      std::vector<float> s;
      for (const auto &samples : input.samples.data) {
        resampler->Reset();
        resampler->Resample(samples.data(), samples.size(), true, &s);
        tmp_input.samples.data.push_back(std::move(s));
      }

      tmp_input.sample_rate = output_sample_rate;
      p_input = &tmp_input;
    }

    // For stereo input both channels must have the same number of samples.
    if (p_input->samples.data.size() > 1) {
      if (config_.model.debug) {
        SHERPA_ONNX_LOGE("input ch1 samples size: %d",
                         static_cast<int32_t>(p_input->samples.data[1].size()));
      }

      if (p_input->samples.data[0].size() != p_input->samples.data[1].size()) {
        SHERPA_ONNX_LOGE("ch0 samples size %d vs ch1 samples size %d",
                         static_cast<int32_t>(p_input->samples.data[0].size()),
                         static_cast<int32_t>(p_input->samples.data[1].size()));
        SHERPA_ONNX_EXIT(-1);
      }
    }

    auto stft_ch0 = ComputeStft(*p_input, 0);
    auto stft_ch1 = ComputeStft(*p_input, 1);

    // Mono input: reuse channel 0's STFT for the (missing) channel 1.
    knf::StftResult *p_stft_ch1 = stft_ch1.real.empty() ? &stft_ch0 : &stft_ch1;

    int32_t num_frames = stft_ch0.num_frames;
    if (num_frames == 0) {
      // Guard against division by zero below when the input is too short
      // (or empty) to produce any STFT frame.
      SHERPA_ONNX_LOGE("No STFT frames computed; input is empty or too short");
      SHERPA_ONNX_EXIT(-1);
    }

    int32_t fft_bins = stft_ch0.real.size() / num_frames;

    // The model consumes chunks of 512 frames; zero-pad up to a multiple.
    int32_t pad = 512 - (stft_ch0.num_frames % 512);
    if (pad < 512) {
      num_frames += pad;
    }

    if (num_frames % 512) {
      SHERPA_ONNX_LOGE("num_frames should be multiple of 512, actual: %d. %d",
                       num_frames, num_frames % 512);
      SHERPA_ONNX_EXIT(-1);
    }

    // Layout: (channel, frame, bin) with only the first 1024 bins kept,
    // zero-initialized so the padded tail contributes nothing.
    Eigen::VectorXf real(2 * num_frames * 1024);
    Eigen::VectorXf imag(2 * num_frames * 1024);
    real.setZero();
    imag.setZero();

    float *p_real = &real[0];
    float *p_imag = &imag[0];

    // copy stft result of channel 0
    for (int32_t i = 0; i != stft_ch0.num_frames; ++i) {
      std::copy(stft_ch0.real.data() + i * fft_bins,
                stft_ch0.real.data() + i * fft_bins + 1024, p_real + 1024 * i);

      std::copy(stft_ch0.imag.data() + i * fft_bins,
                stft_ch0.imag.data() + i * fft_bins + 1024, p_imag + 1024 * i);
    }

    p_real += num_frames * 1024;
    p_imag += num_frames * 1024;

    // copy stft result of channel 1
    for (int32_t i = 0; i != stft_ch1.num_frames; ++i) {
      std::copy(p_stft_ch1->real.data() + i * fft_bins,
                p_stft_ch1->real.data() + i * fft_bins + 1024,
                p_real + 1024 * i);

      std::copy(p_stft_ch1->imag.data() + i * fft_bins,
                p_stft_ch1->imag.data() + i * fft_bins + 1024,
                p_imag + 1024 * i);
    }

    // Magnitude spectrogram |X| fed to both UNets.
    Eigen::VectorXf x = (real.array().square() + imag.array().square()).sqrt();

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // Model input shape: (num_audio_channels, num_splits, 512, 1024).
    std::array<int64_t, 4> x_shape{2, num_frames / 512, 512, 1024};

    Ort::Value x_tensor = Ort::Value::CreateTensor(
        memory_info, &x[0], x.size(), x_shape.data(), x_shape.size());

    // RunVocals gets a view; the tensor itself is moved into the second run.
    Ort::Value vocals_spec_tensor = model_.RunVocals(View(&x_tensor));
    Ort::Value accompaniment_spec_tensor =
        model_.RunAccompaniment(std::move(x_tensor));

    Eigen::VectorXf vocals_spec = Eigen::Map<Eigen::VectorXf>(
        vocals_spec_tensor.GetTensorMutableData<float>(), x.size());

    Eigen::VectorXf accompaniment_spec = Eigen::Map<Eigen::VectorXf>(
        accompaniment_spec_tensor.GetTensorMutableData<float>(), x.size());

    // Wiener-style soft masks; epsilon keeps the ratio well-defined when
    // both model outputs are ~0 (matches the python export script).
    Eigen::VectorXf sum_spec = vocals_spec.array().square() +
                               accompaniment_spec.array().square() + 1e-10;

    vocals_spec = (vocals_spec.array().square() + 1e-10 / 2) / sum_spec.array();
    accompaniment_spec =
        (accompaniment_spec.array().square() + 1e-10 / 2) / sum_spec.array();

    auto vocals_samples_ch0 = ProcessSpec(vocals_spec, stft_ch0, 0);
    auto vocals_samples_ch1 = ProcessSpec(vocals_spec, *p_stft_ch1, 1);

    auto accompaniment_samples_ch0 =
        ProcessSpec(accompaniment_spec, stft_ch0, 0);
    auto accompaniment_samples_ch1 =
        ProcessSpec(accompaniment_spec, *p_stft_ch1, 1);

    OfflineSourceSeparationOutput ans;
    ans.sample_rate = GetOutputSampleRate();

    // stems[0] = vocals, stems[1] = accompaniment; each has 2 channels.
    ans.stems.resize(2);
    ans.stems[0].data.reserve(2);
    ans.stems[1].data.reserve(2);

    ans.stems[0].data.push_back(std::move(vocals_samples_ch0));
    ans.stems[0].data.push_back(std::move(vocals_samples_ch1));

    ans.stems[1].data.push_back(std::move(accompaniment_samples_ch0));
    ans.stems[1].data.push_back(std::move(accompaniment_samples_ch1));

    return ans;
  }

  int32_t GetOutputSampleRate() const override {
    return model_.GetMetaData().sample_rate;
  }

  int32_t GetNumberOfStems() const override {
    return model_.GetMetaData().num_stems;
  }

 private:
  // Apply the soft mask `spec` (shape (2, num_chunks, 512, 1024), flattened)
  // to the original STFT of the given channel and run the inverse STFT.
  std::vector<float> ProcessSpec(const Eigen::VectorXf &spec,
                                 const knf::StftResult &stft,
                                 int32_t channel) const {
    int32_t fft_bins = stft.real.size() / stft.num_frames;

    // The mask covers only the first 1024 bins; the rest stay zero.
    Eigen::VectorXf mask(stft.real.size());
    mask.setZero();
    float *p_mask = &mask[0];

    // assume there are 2 channels
    const float *p_spec = &spec[0] + (spec.size() / 2) * channel;

    for (int32_t i = 0; i != stft.num_frames; ++i) {
      std::copy(p_spec + i * 1024, p_spec + (i + 1) * 1024,
                p_mask + i * fft_bins);
    }

    knf::StftResult masked_stft;
    masked_stft.num_frames = stft.num_frames;
    masked_stft.real.resize(stft.real.size());
    masked_stft.imag.resize(stft.imag.size());

    // Element-wise mask * STFT for real and imaginary parts.
    Eigen::Map<Eigen::VectorXf>(masked_stft.real.data(),
                                masked_stft.real.size()) =
        mask.array() *
        Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.real.data()),
                                    stft.real.size())
            .array();

    Eigen::Map<Eigen::VectorXf>(masked_stft.imag.data(),
                                masked_stft.imag.size()) =
        mask.array() *
        Eigen::Map<Eigen::VectorXf>(const_cast<float *>(stft.imag.data()),
                                    stft.imag.size())
            .array();

    auto stft_config = GetStftConfig();
    knf::IStft istft(stft_config);

    return istft.Compute(masked_stft);
  }

  // STFT of channel `ch`; returns an empty result if the channel exists but
  // has no samples. Exits if the channel index is out of range.
  knf::StftResult ComputeStft(const OfflineSourceSeparationInput &input,
                              int32_t ch) const {
    if (ch >= static_cast<int32_t>(input.samples.data.size())) {
      SHERPA_ONNX_LOGE("Invalid channel %d. Max %d", ch,
                       static_cast<int32_t>(input.samples.data.size()));
      SHERPA_ONNX_EXIT(-1);
    }

    if (input.samples.data[ch].empty()) {
      return {};
    }

    return ComputeStft(input.samples.data[ch]);
  }

  knf::StftResult ComputeStft(const std::vector<float> &samples) const {
    auto stft_config = GetStftConfig();
    knf::Stft stft(stft_config);
    return stft.Compute(samples.data(), samples.size());
  }

  knf::StftConfig GetStftConfig() const {
    const auto &meta = model_.GetMetaData();

    knf::StftConfig stft_config;
    stft_config.n_fft = meta.n_fft;
    stft_config.hop_length = meta.hop_length;
    stft_config.win_length = meta.window_length;
    stft_config.window_type = meta.window_type;
    // Use the metadata value (false for spleeter). A hard-coded
    // `stft_config.center = false;` used to follow this line and silently
    // discarded the metadata value; it has been removed.
    stft_config.center = meta.center;

    return stft_config;
  }

 private:
  OfflineSourceSeparationConfig config_;
  OfflineSourceSeparationSpleeterModel model_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
... ...
// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"

#include <sstream>  // for std::ostringstream, previously only available transitively
#include <string>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

// Register the command-line options for the spleeter 2stems model paths.
void OfflineSourceSeparationSpleeterModelConfig::Register(ParseOptions *po) {
  po->Register("spleeter-vocals", &vocals, "Path to the spleeter vocals model");

  po->Register("spleeter-accompaniment", &accompaniment,
               "Path to the spleeter accompaniment model");
}

// Both model paths must be provided and point to existing files.
bool OfflineSourceSeparationSpleeterModelConfig::Validate() const {
  if (vocals.empty()) {
    SHERPA_ONNX_LOGE("Please provide --spleeter-vocals");
    return false;
  }

  if (!FileExists(vocals)) {
    SHERPA_ONNX_LOGE("spleeter vocals '%s' does not exist. ", vocals.c_str());
    return false;
  }

  if (accompaniment.empty()) {
    SHERPA_ONNX_LOGE("Please provide --spleeter-accompaniment");
    return false;
  }

  if (!FileExists(accompaniment)) {
    SHERPA_ONNX_LOGE("spleeter accompaniment '%s' does not exist. ",
                     accompaniment.c_str());
    return false;
  }

  return true;
}

// Human-readable dump used for logging the effective configuration.
std::string OfflineSourceSeparationSpleeterModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineSourceSeparationSpleeterModelConfig(";
  os << "vocals=\"" << vocals << "\", ";
  os << "accompaniment=\"" << accompaniment << "\")";

  return os.str();
}

}  // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_

#include <string>

// Note: the redundant self-include of this very header (it shares this
// file's include guard) has been removed.
#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// Paths to the two ONNX models that make up the spleeter 2stems pipeline.
struct OfflineSourceSeparationSpleeterModelConfig {
  // Path to the vocals model, e.g., vocals.fp16.onnx
  std::string vocals;

  // Path to the accompaniment model, e.g., accompaniment.fp16.onnx
  std::string accompaniment;

  OfflineSourceSeparationSpleeterModelConfig() = default;

  OfflineSourceSeparationSpleeterModelConfig(const std::string &vocals,
                                             const std::string &accompaniment)
      : vocals(vocals), accompaniment(accompaniment) {}

  void Register(ParseOptions *po);
  bool Validate() const;
  std::string ToString() const;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_

#include <string>
#include <unordered_map>
#include <vector>

namespace sherpa_onnx {

// Metadata describing the spleeter 2stems model and its STFT parameters.
// Defaults match the official spleeter configuration.
//
// See also
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/spleeter/separate_onnx.py
struct OfflineSourceSeparationSpleeterModelMetaData {
  // Sample rate the model was trained on; input is resampled to this rate.
  int32_t sample_rate = 44100;

  // Number of output stems (2: vocals + accompaniment).
  int32_t num_stems = 2;

  // STFT parameters used both for analysis and for the inverse STFT.
  int32_t n_fft = 4096;
  int32_t hop_length = 1024;
  int32_t window_length = 4096;

  // Whether the STFT pads so each frame is centered on its sample.
  bool center = false;

  std::string window_type = "hann";
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
... ...
// sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

// Pimpl holding two onnxruntime sessions: one for the vocals model and one
// for the accompaniment model. Model-level metadata (stems, model type) is
// read from the vocals model only.
class OfflineSourceSeparationSpleeterModel::Impl {
 public:
  // Load both models from the filesystem.
  explicit Impl(const OfflineSourceSeparationModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.spleeter.vocals);
      InitVocals(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.spleeter.accompaniment);
      InitAccompaniment(buf.data(), buf.size());
    }
  }

  // Load both models via a platform resource manager (Android/HarmonyOS).
  template <typename Manager>
  Impl(Manager *mgr, const OfflineSourceSeparationModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.spleeter.vocals);
      InitVocals(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(mgr, config.spleeter.accompaniment);
      InitAccompaniment(buf.data(), buf.size());
    }
  }

  const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const {
    return meta_;
  }

  // Run the vocals model on tensor x; returns its single output tensor.
  Ort::Value RunVocals(Ort::Value x) const {
    auto out = vocals_sess_->Run({}, vocals_input_names_ptr_.data(), &x, 1,
                                 vocals_output_names_ptr_.data(),
                                 vocals_output_names_ptr_.size());

    return std::move(out[0]);
  }

  // Run the accompaniment model on tensor x; returns its single output
  // tensor.
  Ort::Value RunAccompaniment(Ort::Value x) const {
    auto out =
        accompaniment_sess_->Run({}, accompaniment_input_names_ptr_.data(), &x,
                                 1, accompaniment_output_names_ptr_.data(),
                                 accompaniment_output_names_ptr_.size());

    return std::move(out[0]);
  }

 private:
  // Create the vocals session, optionally dump its metadata, and validate
  // the model type and number of stems stored in the model's metadata.
  void InitVocals(void *model_data, size_t model_data_length) {
    vocals_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(vocals_sess_.get(), &vocals_input_names_,
                  &vocals_input_names_ptr_);

    GetOutputNames(vocals_sess_.get(), &vocals_output_names_,
                   &vocals_output_names_ptr_);

    Ort::ModelMetadata meta_data = vocals_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---vocals model---\n";
      PrintModelMetadata(os, meta_data);

      os << "----------input names----------\n";
      int32_t i = 0;
      for (const auto &s : vocals_input_names_) {
        os << i << " " << s << "\n";
        ++i;
      }
      os << "----------output names----------\n";
      i = 0;
      for (const auto &s : vocals_output_names_) {
        os << i << " " << s << "\n";
        ++i;
      }

#if __OHOS__
      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    std::string model_type;
    SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
    if (model_type != "spleeter") {
      SHERPA_ONNX_LOGE("Expect model type 'spleeter'. Given: '%s'",
                       model_type.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    SHERPA_ONNX_READ_META_DATA(meta_.num_stems, "stems");
    if (meta_.num_stems != 2) {
      SHERPA_ONNX_LOGE("Only 2stems is supported. Given %d stems",
                       meta_.num_stems);
      SHERPA_ONNX_EXIT(-1);
    }
  }

  // Create the accompaniment session. No metadata validation here; the
  // vocals model is treated as the source of truth.
  void InitAccompaniment(void *model_data, size_t model_data_length) {
    accompaniment_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(accompaniment_sess_.get(), &accompaniment_input_names_,
                  &accompaniment_input_names_ptr_);

    GetOutputNames(accompaniment_sess_.get(), &accompaniment_output_names_,
                   &accompaniment_output_names_ptr_);
  }

 private:
  OfflineSourceSeparationModelConfig config_;
  OfflineSourceSeparationSpleeterModelMetaData meta_;

  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> vocals_sess_;
  // *_names_ptr_ vectors hold raw pointers into the corresponding *_names_
  // strings, in the layout Ort::Session::Run() expects.
  std::vector<std::string> vocals_input_names_;
  std::vector<const char *> vocals_input_names_ptr_;
  std::vector<std::string> vocals_output_names_;
  std::vector<const char *> vocals_output_names_ptr_;

  std::unique_ptr<Ort::Session> accompaniment_sess_;
  std::vector<std::string> accompaniment_input_names_;
  std::vector<const char *> accompaniment_input_names_ptr_;
  std::vector<std::string> accompaniment_output_names_;
  std::vector<const char *> accompaniment_output_names_ptr_;
};

OfflineSourceSeparationSpleeterModel::~OfflineSourceSeparationSpleeterModel() =
    default;

OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel(
    const OfflineSourceSeparationModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

template <typename Manager>
OfflineSourceSeparationSpleeterModel::OfflineSourceSeparationSpleeterModel(
    Manager *mgr, const OfflineSourceSeparationModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}

Ort::Value OfflineSourceSeparationSpleeterModel::RunVocals(Ort::Value x) const {
  return impl_->RunVocals(std::move(x));
}

Ort::Value OfflineSourceSeparationSpleeterModel::RunAccompaniment(
    Ort::Value x) const {
  return impl_->RunAccompaniment(std::move(x));
}

const OfflineSourceSeparationSpleeterModelMetaData &
OfflineSourceSeparationSpleeterModel::GetMetaData() const {
  return impl_->GetMetaData();
}

// Explicit instantiations of the resource-manager constructor.
#if __ANDROID_API__ >= 9
template OfflineSourceSeparationSpleeterModel::
    OfflineSourceSeparationSpleeterModel(
        AAssetManager *mgr, const OfflineSourceSeparationModelConfig &config);
#endif

#if __OHOS__
template OfflineSourceSeparationSpleeterModel::
    OfflineSourceSeparationSpleeterModel(
        NativeResourceManager *mgr,
        const OfflineSourceSeparationModelConfig &config);
#endif

}  // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_

#include <memory>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h"

namespace sherpa_onnx {

// Wrapper around the two ONNX sessions of the spleeter 2stems model
// (vocals + accompaniment). Implementation detail lives in the pimpl.
class OfflineSourceSeparationSpleeterModel {
 public:
  ~OfflineSourceSeparationSpleeterModel();

  // Load model files from the filesystem.
  explicit OfflineSourceSeparationSpleeterModel(
      const OfflineSourceSeparationModelConfig &config);

  // Load model files via a platform resource manager
  // (Android AAssetManager or HarmonyOS NativeResourceManager).
  template <typename Manager>
  OfflineSourceSeparationSpleeterModel(
      Manager *mgr, const OfflineSourceSeparationModelConfig &config);

  // Run the vocals / accompaniment model on the input tensor and return
  // the single output tensor.
  Ort::Value RunVocals(Ort::Value x) const;

  Ort::Value RunAccompaniment(Ort::Value x) const;

  // Metadata (sample rate, stems, STFT parameters) read from the model.
  const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
... ...
// sherpa-onnx/csrc/offline-source-separation.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation.h"

#include <memory>

#include "sherpa-onnx/csrc/offline-source-separation-impl.h"

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

namespace sherpa_onnx {

void OfflineSourceSeparationConfig::Register(ParseOptions *po) {
  model.Register(po);
}

bool OfflineSourceSeparationConfig::Validate() const {
  return model.Validate();
}

// Human-readable dump used for logging the effective configuration.
std::string OfflineSourceSeparationConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineSourceSeparationConfig(";
  os << "model=" << model.ToString() << ")";

  return os.str();
}

// Resource-manager constructor (Android/HarmonyOS). Defined before the
// explicit instantiations at the bottom of this file.
template <typename Manager>
OfflineSourceSeparation::OfflineSourceSeparation(
    Manager *mgr, const OfflineSourceSeparationConfig &config)
    : impl_(OfflineSourceSeparationImpl::Create(mgr, config)) {}

// Filesystem constructor.
OfflineSourceSeparation::OfflineSourceSeparation(
    const OfflineSourceSeparationConfig &config)
    : impl_(OfflineSourceSeparationImpl::Create(config)) {}

OfflineSourceSeparation::~OfflineSourceSeparation() = default;

// Forwarders to the selected implementation.
OfflineSourceSeparationOutput OfflineSourceSeparation::Process(
    const OfflineSourceSeparationInput &input) const {
  return impl_->Process(input);
}

int32_t OfflineSourceSeparation::GetOutputSampleRate() const {
  return impl_->GetOutputSampleRate();
}

// e.g., it is 2 for 2stems from spleeter
int32_t OfflineSourceSeparation::GetNumberOfStems() const {
  return impl_->GetNumberOfStems();
}

// Explicit instantiations of the resource-manager constructor.
#if __ANDROID_API__ >= 9
template OfflineSourceSeparation::OfflineSourceSeparation(
    AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
#endif

#if __OHOS__
template OfflineSourceSeparation::OfflineSourceSeparation(
    NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
#endif

}  // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-source-separation.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_

#include <memory>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

// Top-level configuration for offline source separation.
struct OfflineSourceSeparationConfig {
  OfflineSourceSeparationModelConfig model;

  OfflineSourceSeparationConfig() = default;

  // Kept non-explicit intentionally so existing callers relying on implicit
  // conversion keep compiling.
  OfflineSourceSeparationConfig(const OfflineSourceSeparationModelConfig &model)
      : model(model) {}

  void Register(ParseOptions *po);
  bool Validate() const;
  std::string ToString() const;
};

struct MultiChannelSamples {
  // data[i] is for the i-th channel
  //
  // each sample is in the range [-1, 1]
  std::vector<std::vector<float>> data;
};

struct OfflineSourceSeparationInput {
  MultiChannelSamples samples;

  // Sample rate of `samples`. Initialized to 0 so a caller that forgets to
  // set it reads a defined value instead of an indeterminate one.
  int32_t sample_rate = 0;
};

struct OfflineSourceSeparationOutput {
  // stems[i] is the i-th separated stem, e.g., vocals, accompaniment.
  std::vector<MultiChannelSamples> stems;

  // Sample rate of every stem. Initialized to 0 until set by Process().
  int32_t sample_rate = 0;
};

class OfflineSourceSeparationImpl;

// Public API for offline (non-streaming) source separation.
class OfflineSourceSeparation {
 public:
  ~OfflineSourceSeparation();

  OfflineSourceSeparation(const OfflineSourceSeparationConfig &config);

  // Construct using a platform resource manager
  // (Android AAssetManager or HarmonyOS NativeResourceManager).
  template <typename Manager>
  OfflineSourceSeparation(Manager *mgr,
                          const OfflineSourceSeparationConfig &config);

  // Separate the input audio into stems; resamples internally if the input
  // sample rate differs from the model's.
  OfflineSourceSeparationOutput Process(
      const OfflineSourceSeparationInput &input) const;

  int32_t GetOutputSampleRate() const;

  // e.g., it is 2 for 2stems from spleeter
  int32_t GetNumberOfStems() const;

 private:
  std::unique_ptr<OfflineSourceSeparationImpl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
... ...
... ... @@ -12,7 +12,7 @@
namespace sherpa_onnx {
// please refer to
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/gtcrn/add_meta_data.py
struct OfflineSpeechDenoiserGtcrnModelMetaData {
int32_t sample_rate = 0;
int32_t version = 1;
... ...
... ... @@ -11,7 +11,7 @@
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Non-stremaing speech denoising with sherpa-onnx.
Non-streaming speech denoising with sherpa-onnx.
Please visit
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
... ...
// sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include <stdio.h>
#include <chrono> // NOLINT
#include <string>
#include "sherpa-onnx/csrc/offline-source-separation.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"
// Command-line driver: reads a (possibly stereo) wave file, runs source
// separation, and writes the vocals and accompaniment stems as wave files.
int main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
Non-streaming source separation with sherpa-onnx.
Please visit
https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models
to download models.
Usage:
(1) Use spleeter models
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav
./bin/sherpa-onnx-offline-source-separation \
--spleeter-vocals=sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx \
--spleeter-accompaniment=sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx \
--input-wav=audio_example.wav \
--output-vocals-wav=output_vocals.wav \
--output-accompaniment-wav=output_accompaniment.wav
)usage";

  sherpa_onnx::ParseOptions po(kUsageMessage);
  sherpa_onnx::OfflineSourceSeparationConfig config;

  std::string input_wave;
  std::string output_vocals_wave;
  std::string output_accompaniment_wave;

  config.Register(&po);
  po.Register("input-wav", &input_wave, "Path to input wav.");
  po.Register("output-vocals-wav", &output_vocals_wave,
              "Path to output vocals wav");
  po.Register("output-accompaniment-wav", &output_accompaniment_wave,
              "Path to output accompaniment wav");

  po.Read(argc, argv);
  if (po.NumArgs() != 0) {
    fprintf(stderr, "Please don't give positional arguments\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "%s\n", config.ToString().c_str());

  if (input_wave.empty()) {
    fprintf(stderr, "Please provide --input-wav\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  if (output_vocals_wave.empty()) {
    fprintf(stderr, "Please provide --output-vocals-wav\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  if (output_accompaniment_wave.empty()) {
    fprintf(stderr, "Please provide --output-accompaniment-wav\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  if (!config.Validate()) {
    fprintf(stderr, "Errors in config!\n");
    exit(EXIT_FAILURE);
  }

  bool is_ok = false;
  sherpa_onnx::OfflineSourceSeparationInput input;
  input.samples.data =
      sherpa_onnx::ReadWaveMultiChannel(input_wave, &input.sample_rate, &is_ok);
  // Also reject files that decoded to zero channels; duration is computed
  // from channel 0 below.
  if (!is_ok || input.samples.data.empty()) {
    fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str());
    return -1;
  }

  fprintf(stderr, "Started\n");

  sherpa_onnx::OfflineSourceSeparation sp(config);

  const auto begin = std::chrono::steady_clock::now();
  auto output = sp.Process(input);
  const auto end = std::chrono::steady_clock::now();

  float elapsed_seconds =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
          .count() /
      1000.;

  // Guard the hard-coded stem/channel indexing below: writing stereo wave
  // files requires at least 2 stems, each with at least 2 channels.
  if (output.stems.size() < 2 || output.stems[0].data.size() < 2 ||
      output.stems[1].data.size() < 2) {
    fprintf(stderr, "Expected at least 2 stems with 2 channels each\n");
    exit(EXIT_FAILURE);
  }

  is_ok = sherpa_onnx::WriteWave(
      output_vocals_wave, output.sample_rate, output.stems[0].data[0].data(),
      output.stems[0].data[1].data(), output.stems[0].data[0].size());
  if (!is_ok) {
    fprintf(stderr, "Failed to write to '%s'\n", output_vocals_wave.c_str());
    exit(EXIT_FAILURE);
  }

  is_ok = sherpa_onnx::WriteWave(output_accompaniment_wave, output.sample_rate,
                                 output.stems[1].data[0].data(),
                                 output.stems[1].data[1].data(),
                                 output.stems[1].data[0].size());
  if (!is_ok) {
    fprintf(stderr, "Failed to write to '%s'\n",
            output_accompaniment_wave.c_str());
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "Done\n");
  fprintf(stderr, "Saved to '%s' and '%s'\n", output_vocals_wave.c_str(),
          output_accompaniment_wave.c_str());

  float duration =
      input.samples.data[0].size() / static_cast<float>(input.sample_rate);

  fprintf(stderr, "num threads: %d\n", config.model.num_threads);
  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  float rtf = elapsed_seconds / duration;
  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
          elapsed_seconds, duration, rtf);

  return 0;
}
... ...
... ... @@ -63,8 +63,9 @@ in sherpa-onnx.
// Read a wave file of mono-channel.
// Return its samples normalized to the range [-1, 1).
std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
bool *is_ok) {
std::vector<std::vector<float>> ReadWaveImpl(std::istream &is,
int32_t *sampling_rate,
bool *is_ok) {
WaveHeader header{};
is.read(reinterpret_cast<char *>(&header.chunk_id), sizeof(header.chunk_id));
... ... @@ -144,12 +145,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
is.read(reinterpret_cast<char *>(&header.num_channels),
sizeof(header.num_channels));
if (header.num_channels != 1) { // we support only single channel for now
SHERPA_ONNX_LOGE(
"Warning: %d channels are found. We only use the first channel.\n",
header.num_channels);
}
is.read(reinterpret_cast<char *>(&header.sample_rate),
sizeof(header.sample_rate));
... ... @@ -219,7 +214,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
*sampling_rate = header.sample_rate;
std::vector<float> ans;
std::vector<std::vector<float>> ans(header.num_channels);
if (header.bits_per_sample == 16 && header.audio_format == 1) {
// header.subchunk2_size contains the number of bytes in the data.
... ... @@ -233,11 +228,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return {};
}
ans.resize(samples.size() / header.num_channels);
for (auto &v : ans) {
v.resize(samples.size() / header.num_channels);
}
// samples are interleaved
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
ans[i] = samples[i * header.num_channels] / 32768.;
for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
i += header.num_channels, ++k) {
for (int32_t c = 0; c != header.num_channels; ++c) {
ans[c][k] = samples[i + c] / 32768.;
}
}
} else if (header.bits_per_sample == 8 && header.audio_format == 1) {
// number of samples == number of bytes for 8-bit encoded samples
... ... @@ -252,14 +252,21 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return {};
}
ans.resize(samples.size() / header.num_channels);
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
// Note(fangjun): We want to normalize each sample into the range [-1, 1]
// Since each original sample is in the range [0, 256], dividing
// them by 128 converts them to the range [0, 2];
// so after subtracting 1, we get the range [-1, 1]
//
ans[i] = samples[i * header.num_channels] / 128. - 1;
for (auto &v : ans) {
v.resize(samples.size() / header.num_channels);
}
// samples are interleaved
for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
i += header.num_channels, ++k) {
for (int32_t c = 0; c != header.num_channels; ++c) {
// Note(fangjun): We want to normalize each sample into the range [-1,
// 1] Since each original sample is in the range [0, 256], dividing them
// by 128 converts them to the range [0, 2]; so after subtracting 1, we
// get the range [-1, 1]
//
ans[c][k] = samples[i + c] / 128. - 1;
}
}
} else if (header.bits_per_sample == 32 && header.audio_format == 1) {
// 32 here is for int32
... ... @@ -275,9 +282,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return {};
}
ans.resize(samples.size() / header.num_channels);
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
ans[i] = static_cast<float>(samples[i * header.num_channels]) / (1 << 31);
for (auto &v : ans) {
v.resize(samples.size() / header.num_channels);
}
// samples are interleaved
for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
i += header.num_channels, ++k) {
for (int32_t c = 0; c != header.num_channels; ++c) {
ans[c][k] = static_cast<float>(samples[i + c]) / (1 << 31);
}
}
} else if (header.bits_per_sample == 32 && header.audio_format == 3) {
// 32 here is for float32
... ... @@ -293,9 +307,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return {};
}
ans.resize(samples.size() / header.num_channels);
for (int32_t i = 0; i != static_cast<int32_t>(ans.size()); ++i) {
ans[i] = samples[i * header.num_channels];
for (auto &v : ans) {
v.resize(samples.size() / header.num_channels);
}
// samples are interleaved
for (int32_t i = 0, k = 0; i < static_cast<int32_t>(samples.size());
i += header.num_channels, ++k) {
for (int32_t c = 0; c != header.num_channels; ++c) {
ans[c][k] = samples[i + c];
}
}
} else {
SHERPA_ONNX_LOGE(
... ... @@ -321,7 +342,27 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
// Reads a wave file from the stream and returns the first channel, with
// samples normalized to [-1, 1]. Multi-channel input produces a warning and
// only channel 0 is kept. On failure an empty vector is returned and *is_ok
// is set to false.
std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
                            bool *is_ok) {
  auto samples = ReadWaveImpl(is, sampling_rate, is_ok);

  // ReadWaveImpl returns an empty vector on failure; indexing samples[0]
  // in that case would be undefined behavior.
  if (samples.empty()) {
    return {};
  }

  if (samples.size() > 1) {
    SHERPA_ONNX_LOGE(
        "Warning: %d channels are found. We only use the first channel.\n",
        static_cast<int32_t>(samples.size()));
  }

  // Steal channel 0 instead of copying it.
  std::vector<float> ans;
  ans.swap(samples[0]);
  return ans;
}
// Reads a wave file from the stream and returns all channels; entry i holds
// the samples of channel i, normalized to [-1, 1].
std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is,
                                                     int32_t *sampling_rate,
                                                     bool *is_ok) {
  // Thin wrapper: ReadWaveImpl already produces one vector per channel.
  return ReadWaveImpl(is, sampling_rate, is_ok);
}
// Convenience overload: opens the file and delegates to the stream version.
std::vector<std::vector<float>> ReadWaveMultiChannel(
    const std::string &filename, int32_t *sampling_rate, bool *is_ok) {
  std::ifstream stream(filename, std::ifstream::binary);
  return ReadWaveMultiChannel(stream, sampling_rate, is_ok);
}
} // namespace sherpa_onnx
... ...
... ... @@ -26,6 +26,13 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
bool *is_ok);
// Reads a (possibly multi-channel) wave file.
// The returned vector has one entry per channel; entry i holds the samples
// of channel i, normalized to the range [-1, 1].
// On failure *is_ok is set to false.
std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is,
                                                     int32_t *sampling_rate,
                                                     bool *is_ok);

// Same as above, but reads from the given file.
std::vector<std::vector<float>> ReadWaveMultiChannel(
    const std::string &filename, int32_t *sampling_rate, bool *is_ok);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_
... ...
... ... @@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/wave-writer.h"
#include <algorithm>
#include <cstring>
#include <fstream>
#include <string>
... ... @@ -36,12 +37,44 @@ struct WaveHeader {
} // namespace
// Returns the number of bytes of a 16-bit PCM wave file holding `n_samples`
// samples per channel.
//
// Note: the stale pre-patch single-channel signature/body left above this
// function has been removed; it made the translation unit ill-formed.
int64_t WaveFileSize(int32_t n_samples, int32_t num_channels /*= 1*/) {
  return sizeof(WaveHeader) + n_samples * sizeof(int16_t) * num_channels;
}
// Mono convenience overload: forwards to the two-channel buffer writer with
// a null second channel.
void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
               int32_t n) {
  WriteWave(buffer, sampling_rate, samples, nullptr, n);
}
// Mono convenience overload: forwards to the two-channel file writer with
// a null second channel. Returns true on success.
bool WriteWave(const std::string &filename, int32_t sampling_rate,
               const float *samples, int32_t n) {
  return WriteWave(filename, sampling_rate, samples, nullptr, n);
}
// Writes a 16-bit PCM wave file with `n` samples per channel.
// Pass samples_ch1 == nullptr to write a mono file.
// Returns true on success, false otherwise.
bool WriteWave(const std::string &filename, int32_t sampling_rate,
               const float *samples_ch0, const float *samples_ch1, int32_t n) {
  const int32_t num_channels = (samples_ch1 == nullptr) ? 1 : 2;

  // Serialize header + interleaved samples into an in-memory buffer first.
  std::string buffer;
  buffer.resize(WaveFileSize(n, num_channels));
  WriteWave(buffer.data(), sampling_rate, samples_ch0, samples_ch1, n);

  std::ofstream os(filename, std::ios::binary);
  if (!os) {
    SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str());
    return false;
  }

  os << buffer;
  if (!os) {
    SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str());
    return false;
  }

  return true;
}
void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0,
const float *samples_ch1, int32_t n) {
WaveHeader header{};
header.chunk_id = 0x46464952; // FFIR
header.format = 0x45564157; // EVAW
... ... @@ -49,8 +82,9 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
header.subchunk1_size = 16; // 16 for PCM
header.audio_format = 1; // PCM =1
int32_t num_channels = 1;
int32_t num_channels = samples_ch1 == nullptr ? 1 : 2;
int32_t bits_per_sample = 16; // int16_t
header.num_channels = num_channels;
header.sample_rate = sampling_rate;
header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8;
... ... @@ -61,32 +95,32 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
header.chunk_size = 36 + header.subchunk2_size;
std::vector<int16_t> samples_int16(n);
std::vector<int16_t> samples_int16_ch0(n);
for (int32_t i = 0; i != n; ++i) {
samples_int16[i] = samples[i] * 32767;
samples_int16_ch0[i] = std::min<int32_t>(samples_ch0[i] * 32767, 32767);
}
std::vector<int16_t> samples_int16_ch1;
if (samples_ch1) {
samples_int16_ch1.resize(n);
for (int32_t i = 0; i != n; ++i) {
samples_int16_ch1[i] = std::min<int32_t>(samples_ch1[i] * 32767, 32767);
}
}
memcpy(buffer, &header, sizeof(WaveHeader));
memcpy(buffer + sizeof(WaveHeader), samples_int16.data(),
n * sizeof(int16_t));
}
bool WriteWave(const std::string &filename, int32_t sampling_rate,
const float *samples, int32_t n) {
std::string buffer;
buffer.resize(WaveFileSize(n));
WriteWave(buffer.data(), sampling_rate, samples, n);
std::ofstream os(filename, std::ios::binary);
if (!os) {
SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str());
return false;
}
os << buffer;
if (!os) {
SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str());
return false;
if (samples_ch1 == nullptr) {
memcpy(buffer + sizeof(WaveHeader), samples_int16_ch0.data(),
n * sizeof(int16_t));
} else {
auto p = reinterpret_cast<int16_t *>(buffer + sizeof(WaveHeader));
for (int32_t i = 0; i != n; ++i) {
p[2 * i] = samples_int16_ch0[i];
p[2 * i + 1] = samples_int16_ch1[i];
}
}
return true;
}
} // namespace sherpa_onnx
... ...
... ... @@ -25,7 +25,13 @@ bool WriteWave(const std::string &filename, int32_t sampling_rate,
// Mono overload: serializes `n` samples into `buffer` as a single-channel
// wave file image. `buffer` must hold at least WaveFileSize(n) bytes.
void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
               int32_t n);

// Stereo-capable overloads. Pass samples_ch1 == nullptr for mono output.
bool WriteWave(const std::string &filename, int32_t sampling_rate,
               const float *samples_ch0, const float *samples_ch1, int32_t n);

void WriteWave(char *buffer, int32_t sampling_rate, const float *samples_ch0,
               const float *samples_ch1, int32_t n);

// Number of bytes of a 16-bit PCM wave file with `n_samples` samples per
// channel.
//
// Note: the stale pre-patch declaration `WaveFileSize(int32_t)` has been
// removed; together with the defaulted parameter below it made a
// single-argument call ambiguous.
int64_t WaveFileSize(int32_t n_samples, int32_t num_channels = 1);
} // namespace sherpa_onnx
... ...
... ... @@ -77,7 +77,7 @@ fileInput.addEventListener('change', function(event) {
console.log('ArrayBuffer length:', arrayBuffer.byteLength);
const uint8Array = new Uint8Array(arrayBuffer);
const wave = readWaveFromBinaryData(uint8Array);
const wave = readWaveFromBinaryData(uint8Array, Module);
if (wave == null) {
alert(
`${file.name} is not a valid .wav file. Please select a *.wav file`);
... ...