offline-source-separation.cc
4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// sherpa-onnx/python/csrc/offline-source-separation-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation.h"
#include <algorithm>
#include <string>
#include "sherpa-onnx/python/csrc/offline-source-separation-model-config.h"
#include "sherpa-onnx/python/csrc/offline-source-separation.h"
#define C_CONTIGUOUS py::detail::npy_api::constants::NPY_ARRAY_C_CONTIGUOUS_
namespace sherpa_onnx {
static void PybindOfflineSourceSeparationConfig(py::module *m) {
PybindOfflineSourceSeparationModelConfig(m);
using PyClass = OfflineSourceSeparationConfig;
py::class_<PyClass>(*m, "OfflineSourceSeparationConfig")
.def(py::init<const OfflineSourceSeparationModelConfig &>(),
py::arg("model") = OfflineSourceSeparationModelConfig{})
.def_readwrite("model", &PyClass::model)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
static void PybindMultiChannelSamples(py::module *m) {
using PyClass = MultiChannelSamples;
py::class_<PyClass>(*m, "MultiChannelSamples")
.def_property_readonly("data", [](PyClass &self) -> py::object {
// if data is not empty, return a float array of
// shape (num_channels, num_samples)
int32_t num_channels = self.data.size();
if (num_channels == 0) {
return py::none();
}
int32_t num_samples = self.data[0].size();
if (num_samples == 0) {
return py::none();
}
py::array_t<float> ans({num_channels, num_samples});
py::buffer_info buf = ans.request();
auto p = static_cast<float *>(buf.ptr);
for (int32_t i = 0; i != num_channels; ++i) {
std::copy(self.data[i].begin(), self.data[i].end(),
p + i * num_samples);
}
return ans;
});
}
static void PybindOfflineSourceSeparationOutput(py::module *m) {
using PyClass = OfflineSourceSeparationOutput;
py::class_<PyClass>(*m, "OfflineSourceSeparationOutput")
.def_property_readonly(
"sample_rate", [](const PyClass &self) { return self.sample_rate; })
.def_property_readonly("stems",
[](const PyClass &self) { return self.stems; });
}
void PybindOfflineSourceSeparation(py::module *m) {
PybindOfflineSourceSeparationConfig(m);
PybindOfflineSourceSeparationOutput(m);
PybindMultiChannelSamples(m);
using PyClass = OfflineSourceSeparation;
py::class_<PyClass>(*m, "OfflineSourceSeparation")
.def(py::init<const OfflineSourceSeparationConfig &>(),
py::arg("config") = OfflineSourceSeparationConfig{})
.def(
"process",
[](const PyClass &self, int32_t sample_rate,
const py::array_t<float> &samples) {
if (!(samples.flags() & py::array::c_style)) {
throw py::value_error(
"input samples should be contiguous. Please use "
"np.ascontiguousarray(samples)");
}
int num_dim = samples.ndim();
if (samples.ndim() != 2) {
std::ostringstream os;
os << "Expect an array of 2 dimensions [num_channels x "
"num_samples]. "
"Given dim: "
<< num_dim << "\n";
throw py::value_error(os.str());
}
// if num_samples is less than 10, it is very likely the user
// has swapped num_channels and num_samples.
if (samples.shape(1) < 10) {
std::ostringstream os;
os << "Expect an array of 2 dimensions [num_channels x "
"num_samples]. "
"Given ["
<< samples.shape(0) << " x " << samples.shape(1) << "]"
<< "\n";
throw py::value_error(os.str());
}
int32_t num_channels = samples.shape(0);
int32_t num_samples = samples.shape(1);
const float *p = samples.data();
OfflineSourceSeparationInput input;
input.samples.data.resize(num_channels);
input.sample_rate = sample_rate;
for (int32_t i = 0; i != num_channels; ++i) {
input.samples.data[i] = {p + i * num_samples,
p + (i + 1) * num_samples};
}
pybind11::gil_scoped_release release;
return self.Process(input);
},
py::arg("sample_rate"), py::arg("samples"),
"samples is of shape (num_channels, num-samples) with dtype "
"np.float32");
}
} // namespace sherpa_onnx