// sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include <stdio.h>

#include <chrono>  // NOLINT
#include <string>

#include "sherpa-onnx/csrc/offline-source-separation.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"

int main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
Non-streaming source separation with sherpa-onnx.

Please visit
https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models
to download models.

Usage:

(1) Use spleeter models

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav

./bin/sherpa-onnx-offline-source-separation \
  --spleeter-vocals=sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx \
  --spleeter-accompaniment=sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx \
  --input-wav=audio_example.wav \
  --output-vocals-wav=output_vocals.wav \
  --output-accompaniment-wav=output_accompaniment.wav
)usage";
  sherpa_onnx::ParseOptions po(kUsageMessage);
  sherpa_onnx::OfflineSourceSeparationConfig config;

  std::string input_wave;
  std::string output_vocals_wave;
  std::string output_accompaniment_wave;

  config.Register(&po);

  po.Register("input-wav", &input_wave, "Path to input wav.");

  po.Register("output-vocals-wav", &output_vocals_wave,
              "Path to output vocals wav");

  po.Register("output-accompaniment-wav", &output_accompaniment_wave,
              "Path to output accompaniment wav");

  po.Read(argc, argv);
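
  // This tool takes no positional arguments; everything is passed via
  // --key=value options.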
  if (po.NumArgs() != 0) {
    fprintf(stderr, "Please don't give positional arguments\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "%s\n", config.ToString().c_str());

  if (input_wave.empty()) {
    fprintf(stderr, "Please provide --input-wav\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  if (output_vocals_wave.empty()) {
    fprintf(stderr, "Please provide --output-vocals-wav\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  if (output_accompaniment_wave.empty()) {
    fprintf(stderr, "Please provide --output-accompaniment-wav\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  if (!config.Validate()) {
    fprintf(stderr, "Errors in config!\n");
    exit(EXIT_FAILURE);
  }
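
  // Read all channels of the input wave file; input.samples.data holds one
  // vector of float samples per channel.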
  bool is_ok = false;

  sherpa_onnx::OfflineSourceSeparationInput input;

  input.samples.data = sherpa_onnx::ReadWaveMultiChannel(
      input_wave, &input.sample_rate, &is_ok);

  if (!is_ok) {
    fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str());
    return -1;
  }
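
  // Run source separation; this is the compute-heavy step and is timed below.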
fprintf(stderr, "Started\n");
sherpa_onnx::OfflineSourceSeparation sp(config);
const auto begin = std::chrono::steady_clock::now();
auto output = sp.Process(input);
const auto end = std::chrono::steady_clock::now();
float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
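
  // stems[0] is the vocals stem and stems[1] the accompaniment stem. Each
  // stem stores one sample vector per channel; the two channels are passed
  // to WriteWave() to produce a stereo output file.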
  is_ok = sherpa_onnx::WriteWave(
      output_vocals_wave, output.sample_rate, output.stems[0].data[0].data(),
      output.stems[0].data[1].data(), output.stems[0].data[0].size());

  if (!is_ok) {
    fprintf(stderr, "Failed to write to '%s'\n", output_vocals_wave.c_str());
    exit(EXIT_FAILURE);
  }

  is_ok = sherpa_onnx::WriteWave(output_accompaniment_wave, output.sample_rate,
                                 output.stems[1].data[0].data(),
                                 output.stems[1].data[1].data(),
                                 output.stems[1].data[0].size());

  if (!is_ok) {
    fprintf(stderr, "Failed to write to '%s'\n",
            output_accompaniment_wave.c_str());
    exit(EXIT_FAILURE);
  }
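
  // Report timing: the real-time factor (RTF) is processing time divided by
  // the duration of the input audio, so values below 1 mean faster than
  // real time.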
fprintf(stderr, "Done\n");
fprintf(stderr, "Saved to write to '%s' and '%s'\n",
output_vocals_wave.c_str(), output_accompaniment_wave.c_str());
float duration =
input.samples.data[0].size() / static_cast<float>(input.sample_rate);
fprintf(stderr, "num threads: %d\n", config.model.num_threads);
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
elapsed_seconds, duration, rtf);
return 0;
}