Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2025-05-23 22:30:57 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2025-05-23 22:30:57 +0800
Commit
716ba8317bb36cbcce9ac6c130396e8ec71b3c9c
716ba831
1 parent
55a44793
Add C++ runtime for spleeter about source separation (#2242)
隐藏空白字符变更
内嵌
并排对比
正在显示
28 个修改的文件
包含
1267 行增加
和
72 行删除
.github/workflows/export-spleeter-to-onnx.yaml
cmake/cmake_extension.py
scripts/spleeter/convert_to_torch.py
scripts/spleeter/export_onnx.py
scripts/spleeter/separate.py
scripts/spleeter/separate_onnx.py
scripts/spleeter/unet.py
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/offline-source-separation-impl.cc
sherpa-onnx/csrc/offline-source-separation-impl.h
sherpa-onnx/csrc/offline-source-separation-model-config.cc
sherpa-onnx/csrc/offline-source-separation-model-config.h
sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc
sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h
sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h
sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
sherpa-onnx/csrc/offline-source-separation.cc
sherpa-onnx/csrc/offline-source-separation.h
sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h
sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
sherpa-onnx/csrc/wave-reader.cc
sherpa-onnx/csrc/wave-reader.h
sherpa-onnx/csrc/wave-writer.cc
sherpa-onnx/csrc/wave-writer.h
wasm/speech-enhancement/app-speech-enhancement.js
.github/workflows/export-spleeter-to-onnx.yaml
查看文件 @
716ba83
...
...
@@ -3,7 +3,7 @@ name: export-spleeter-to-onnx
on
:
push
:
branches
:
-
spleeter-2
-
spleeter-
cpp-
2
workflow_dispatch
:
concurrency
:
...
...
cmake/cmake_extension.py
查看文件 @
716ba83
...
...
@@ -56,6 +56,7 @@ def get_binaries():
"sherpa-onnx-offline-denoiser"
,
"sherpa-onnx-offline-language-identification"
,
"sherpa-onnx-offline-punctuation"
,
"sherpa-onnx-offline-source-separation"
,
"sherpa-onnx-offline-speaker-diarization"
,
"sherpa-onnx-offline-tts"
,
"sherpa-onnx-offline-tts-play"
,
...
...
scripts/spleeter/convert_to_torch.py
查看文件 @
716ba83
...
...
@@ -217,8 +217,8 @@ def main(name):
# for the batchnormalization in torch,
# default input shape is NCHW
# NHWC to NCHW
torch_y1_out
=
unet
(
torch
.
from_numpy
(
y0_out
)
.
permute
(
0
,
3
,
1
,
2
))
torch_y1_out
=
unet
(
torch
.
from_numpy
(
y0_out
)
.
permute
(
3
,
0
,
1
,
2
))
torch_y1_out
=
torch_y1_out
.
permute
(
1
,
0
,
2
,
3
)
# print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
assert
torch
.
allclose
(
...
...
scripts/spleeter/export_onnx.py
查看文件 @
716ba83
...
...
@@ -46,7 +46,7 @@ def add_meta_data(filename, prefix):
def
export
(
model
,
prefix
):
num_splits
=
1
x
=
torch
.
rand
(
num_splits
,
2
,
512
,
1024
,
dtype
=
torch
.
float32
)
x
=
torch
.
rand
(
2
,
num_splits
,
512
,
1024
,
dtype
=
torch
.
float32
)
filename
=
f
"./2stems/{prefix}.onnx"
torch
.
onnx
.
export
(
...
...
@@ -56,7 +56,7 @@ def export(model, prefix):
input_names
=
[
"x"
],
output_names
=
[
"y"
],
dynamic_axes
=
{
"x"
:
{
0
:
"num_splits"
},
"x"
:
{
1
:
"num_splits"
},
},
opset_version
=
13
,
)
...
...
scripts/spleeter/separate.py
查看文件 @
716ba83
...
...
@@ -101,13 +101,17 @@ def main():
print
(
"y2"
,
y
.
shape
,
y
.
dtype
)
y
=
y
.
abs
()
y
=
y
.
permute
(
0
,
3
,
1
,
2
)
# (1, 2, 512, 1024)
y
=
y
.
permute
(
3
,
0
,
1
,
2
)
# (2, 1, 512, 1024)
print
(
"y3"
,
y
.
shape
,
y
.
dtype
)
vocals_spec
=
vocals
(
y
)
accompaniment_spec
=
accompaniment
(
y
)
vocals_spec
=
vocals_spec
.
permute
(
1
,
0
,
2
,
3
)
accompaniment_spec
=
accompaniment_spec
.
permute
(
1
,
0
,
2
,
3
)
sum_spec
=
(
vocals_spec
**
2
+
accompaniment_spec
**
2
)
+
1e-10
print
(
"vocals_spec"
,
...
...
scripts/spleeter/separate_onnx.py
查看文件 @
716ba83
...
...
@@ -12,15 +12,14 @@ from separate import load_audio
"""
----------inputs for ./2stems/vocals.onnx----------
NodeArg(name='x', type='tensor(float)', shape=[
'num_splits', 2
, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[
2, 'num_splits'
, 512, 1024])
----------outputs for ./2stems/vocals.onnx----------
NodeArg(name='y', type='tensor(float)', shape=[
'Muly_dim_0', 2
, 512, 1024])
NodeArg(name='y', type='tensor(float)', shape=[
2, 'Transposey_dim_1'
, 512, 1024])
----------inputs for ./2stems/accompaniment.onnx----------
NodeArg(name='x', type='tensor(float)', shape=[
'num_splits', 2
, 512, 1024])
NodeArg(name='x', type='tensor(float)', shape=[
2, 'num_splits'
, 512, 1024])
----------outputs for ./2stems/accompaniment.onnx----------
NodeArg(name='y', type='tensor(float)', shape=['Muly_dim_0', 2, 512, 1024])
NodeArg(name='y', type='tensor(float)', shape=[2, 'Transposey_dim_1', 512, 1024])
"""
...
...
@@ -123,16 +122,16 @@ def main():
if
padding
>
0
:
stft0
=
torch
.
nn
.
functional
.
pad
(
stft0
,
(
0
,
0
,
0
,
padding
))
stft1
=
torch
.
nn
.
functional
.
pad
(
stft1
,
(
0
,
0
,
0
,
padding
))
stft0
=
stft0
.
reshape
(
-
1
,
1
,
512
,
1024
)
stft1
=
stft1
.
reshape
(
-
1
,
1
,
512
,
1024
)
stft0
=
stft0
.
reshape
(
1
,
-
1
,
512
,
1024
)
stft1
=
stft1
.
reshape
(
1
,
-
1
,
512
,
1024
)
stft_01
=
torch
.
cat
([
stft0
,
stft1
],
axis
=
1
)
stft_01
=
torch
.
cat
([
stft0
,
stft1
],
axis
=
0
)
print
(
"stft_01"
,
stft_01
.
shape
,
stft_01
.
dtype
)
vocals_spec
=
vocals
(
stft_01
)
accompaniment_spec
=
accompaniment
(
stft_01
)
# (num_
splits, num_channel
s, 512, 1024)
# (num_
channels, num_split
s, 512, 1024)
sum_spec
=
(
vocals_spec
.
square
()
+
accompaniment_spec
.
square
())
+
1e-10
...
...
@@ -142,8 +141,8 @@ def main():
for
name
,
spec
in
zip
(
[
"vocals"
,
"accompaniment"
],
[
vocals_spec
,
accompaniment_spec
]
):
spec_c0
=
spec
[:,
0
,
:,
:]
spec_c1
=
spec
[:,
1
,
:,
:]
spec_c0
=
spec
[
0
]
spec_c1
=
spec
[
1
]
spec_c0
=
spec_c0
.
reshape
(
-
1
,
1024
)
spec_c1
=
spec_c1
.
reshape
(
-
1
,
1024
)
...
...
scripts/spleeter/unet.py
查看文件 @
716ba83
...
...
@@ -67,6 +67,14 @@ class UNet(torch.nn.Module):
self
.
up7
=
torch
.
nn
.
Conv2d
(
1
,
2
,
kernel_size
=
4
,
dilation
=
2
,
padding
=
3
)
def
forward
(
self
,
x
):
"""
Args:
x: (num_audio_channels, num_splits, 512, 1024)
Returns:
y: (num_audio_channels, num_splits, 512, 1024)
"""
x
=
x
.
permute
(
1
,
0
,
2
,
3
)
in_x
=
x
# in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
(
1
,
2
,
1
,
2
),
"constant"
,
0
)
...
...
@@ -147,4 +155,5 @@ class UNet(torch.nn.Module):
up7
=
self
.
up7
(
batch12
)
up7
=
torch
.
sigmoid
(
up7
)
# (3, 2, 512, 1024)
return
up7
*
in_x
ans
=
up7
*
in_x
return
ans
.
permute
(
1
,
0
,
2
,
3
)
...
...
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
716ba83
...
...
@@ -50,6 +50,13 @@ set(sources
offline-rnn-lm.cc
offline-sense-voice-model-config.cc
offline-sense-voice-model.cc
offline-source-separation-impl.cc
offline-source-separation-model-config.cc
offline-source-separation-spleeter-model-config.cc
offline-source-separation-spleeter-model.cc
offline-source-separation.cc
offline-stream.cc
offline-tdnn-ctc-model.cc
offline-tdnn-model-config.cc
...
...
@@ -326,6 +333,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable
(
sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc
)
add_executable
(
sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc
)
add_executable
(
sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc
)
add_executable
(
sherpa-onnx-offline-source-separation sherpa-onnx-offline-source-separation.cc
)
add_executable
(
sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc
)
add_executable
(
sherpa-onnx-vad sherpa-onnx-vad.cc
)
...
...
@@ -346,6 +354,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline-language-identification
sherpa-onnx-offline-parallel
sherpa-onnx-offline-punctuation
sherpa-onnx-offline-source-separation
sherpa-onnx-online-punctuation
sherpa-onnx-vad
)
...
...
sherpa-onnx/csrc/offline-source-separation-impl.cc
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-impl.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
#include <memory>
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h"
namespace
sherpa_onnx
{
std
::
unique_ptr
<
OfflineSourceSeparationImpl
>
OfflineSourceSeparationImpl
::
Create
(
const
OfflineSourceSeparationConfig
&
config
)
{
// TODO(fangjun): Support other models
return
std
::
make_unique
<
OfflineSourceSeparationSpleeterImpl
>
(
config
);
}
template
<
typename
Manager
>
std
::
unique_ptr
<
OfflineSourceSeparationImpl
>
OfflineSourceSeparationImpl
::
Create
(
Manager
*
mgr
,
const
OfflineSourceSeparationConfig
&
config
)
{
// TODO(fangjun): Support other models
return
std
::
make_unique
<
OfflineSourceSeparationSpleeterImpl
>
(
mgr
,
config
);
}
#if __ANDROID_API__ >= 9
template
std
::
unique_ptr
<
OfflineSourceSeparationImpl
>
OfflineSourceSeparationImpl
::
Create
(
AAssetManager
*
mgr
,
const
OfflineSourceSeparationConfig
&
config
);
#endif
#if __OHOS__
template
std
::
unique_ptr
<
OfflineSourceSeparationImpl
>
OfflineSourceSeparationImpl
::
Create
(
NativeResourceManager
*
mgr
,
const
OfflineSourceSeparationConfig
&
config
);
#endif
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-source-separation-impl.h
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-source-separation.h"
namespace
sherpa_onnx
{
// Abstract interface implemented by every offline source-separation backend.
// Instances are obtained through the static Create() factories.
class OfflineSourceSeparationImpl {
 public:
  virtual ~OfflineSourceSeparationImpl() = default;

  // Creates a backend from `config`.
  static std::unique_ptr<OfflineSourceSeparationImpl> Create(
      const OfflineSourceSeparationConfig &config);

  // Same as above, but additionally hands the resource manager `mgr` to the
  // backend.
  template <typename Manager>
  static std::unique_ptr<OfflineSourceSeparationImpl> Create(
      Manager *mgr, const OfflineSourceSeparationConfig &config);

  // Separates `input` into stems and returns them.
  virtual OfflineSourceSeparationOutput Process(
      const OfflineSourceSeparationInput &input) const = 0;

  // Sample rate of the audio returned by Process().
  virtual int32_t GetOutputSampleRate() const = 0;

  // Number of stems the backend reports for its model.
  virtual int32_t GetNumberOfStems() const = 0;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_IMPL_H_
...
...
sherpa-onnx/csrc/offline-source-separation-model-config.cc
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
namespace
sherpa_onnx
{
// Registers the command-line options of this config with `po`.
// Model-specific options (spleeter) are registered first, followed by the
// options shared by all models.
void OfflineSourceSeparationModelConfig::Register(ParseOptions *po) {
  spleeter.Register(po);

  ParseOptions &options = *po;

  options.Register("num-threads", &num_threads,
                   "Number of threads to run the neural network");

  options.Register("debug", &debug,
                   "true to print model information while loading it.");

  options.Register("provider", &provider,
                   "Specify a provider to use: cpu, cuda, coreml");
}
bool
OfflineSourceSeparationModelConfig
::
Validate
()
const
{
return
spleeter
.
Validate
();
}
// Renders the config as a single-line, constructor-like string, e.g.
//   OfflineSourceSeparationModelConfig(spleeter=..., num_threads=1,
//   debug=False, provider="cpu")
std::string OfflineSourceSeparationModelConfig::ToString() const {
  const char *debug_text = debug ? "True" : "False";

  std::ostringstream out;
  out << "OfflineSourceSeparationModelConfig("
      << "spleeter=" << spleeter.ToString() << ", "
      << "num_threads=" << num_threads << ", "
      << "debug=" << debug_text << ", "
      << "provider=\"" << provider << "\")";

  return out.str();
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-source-separation-model-config.h
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace
sherpa_onnx
{
// Options shared by all source-separation models, plus the per-model
// sub-configs.
struct OfflineSourceSeparationModelConfig {
  // Paths of the spleeter models.
  OfflineSourceSeparationSpleeterModelConfig spleeter;

  // Number of threads to run the neural network.
  int32_t num_threads = 1;

  // When true, model information is printed while loading.
  bool debug = false;

  // Execution provider name, e.g. cpu, cuda, coreml.
  std::string provider = "cpu";

  OfflineSourceSeparationModelConfig() = default;

  OfflineSourceSeparationModelConfig(
      const OfflineSourceSeparationSpleeterModelConfig &spleeter,
      int32_t num_threads, bool debug, const std::string &provider)
      : spleeter(spleeter),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  // Registers the command-line options with `po`.
  void Register(ParseOptions *po);

  // Returns true if the configuration is usable.
  bool Validate() const;

  // Human-readable dump of all fields.
  std::string ToString() const;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_MODEL_CONFIG_H_
...
...
sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-spleeter-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
#include "Eigen/Dense"
#include "kaldi-native-fbank/csrc/istft.h"
#include "kaldi-native-fbank/csrc/stft.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"
#include "sherpa-onnx/csrc/offline-source-separation.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/resample.h"
namespace
sherpa_onnx
{
class
OfflineSourceSeparationSpleeterImpl
:
public
OfflineSourceSeparationImpl
{
public
:
OfflineSourceSeparationSpleeterImpl
(
const
OfflineSourceSeparationConfig
&
config
)
:
config_
(
config
),
model_
(
config_
.
model
)
{}
// Same as the single-argument constructor, but the model files are read
// through the resource manager `mgr`, which is forwarded to the model.
template <typename Manager>
OfflineSourceSeparationSpleeterImpl(Manager *mgr,
                                    const OfflineSourceSeparationConfig &config)
    : config_(config),
      model_(mgr, config_.model) {
}
// Separates `input` into two stems: stems[0] is vocals, stems[1] is
// accompaniment. Each stem carries two channels.
//
// Steps:
//   1. If the input sample rate differs from the model sample rate, every
//      channel is resampled.
//   2. The STFT of channel 0 and channel 1 is computed.
//   3. For each frame the lowest 1024 bins are kept; the number of frames is
//      zero-padded up to a multiple of 512; the magnitude is taken. The
//      result is the network input of shape (2, num_chunks, 512, 1024).
//   4. Both networks are run and their outputs are converted to ratio masks.
//   5. ProcessSpec() applies each mask to the complex STFT and inverts it.
//
// The process exits (SHERPA_ONNX_EXIT) when the two channels have different
// lengths, when a requested channel does not exist (see ComputeStft), or when
// the STFT has fewer than 1024 bins per frame.
//
// If the STFT yields no frame at all (empty or too-short input), an output
// with empty channels is returned instead of dividing by zero.
OfflineSourceSeparationOutput Process(
    const OfflineSourceSeparationInput &input) const override {
  // The network consumes chunks of 512 frames x 1024 frequency bins.
  constexpr int32_t kChunkFrames = 512;
  constexpr int32_t kModelBins = 1024;

  const OfflineSourceSeparationInput *p_input = &input;
  OfflineSourceSeparationInput tmp_input;

  const int32_t output_sample_rate = GetOutputSampleRate();

  if (input.sample_rate != output_sample_rate) {
    SHERPA_ONNX_LOGE(
        "Creating a resampler:\n"
        " in_sample_rate: %d\n"
        " output_sample_rate: %d\n",
        input.sample_rate, output_sample_rate);

    float min_freq =
        std::min<int32_t>(input.sample_rate, output_sample_rate);
    float lowpass_cutoff = 0.99 * 0.5 * min_freq;
    int32_t lowpass_filter_width = 6;

    auto resampler = std::make_unique<LinearResample>(
        input.sample_rate, output_sample_rate, lowpass_cutoff,
        lowpass_filter_width);

    for (const auto &samples : input.samples.data) {
      // A fresh buffer per channel: the previous one has been moved away.
      std::vector<float> resampled;

      resampler->Reset();
      resampler->Resample(samples.data(), samples.size(), true, &resampled);
      tmp_input.samples.data.push_back(std::move(resampled));
    }

    tmp_input.sample_rate = output_sample_rate;
    p_input = &tmp_input;
  }

  if (p_input->samples.data.size() > 1) {
    if (config_.model.debug) {
      SHERPA_ONNX_LOGE("input ch1 samples size: %d",
                       static_cast<int32_t>(p_input->samples.data[1].size()));
    }

    if (p_input->samples.data[0].size() != p_input->samples.data[1].size()) {
      SHERPA_ONNX_LOGE("ch0 samples size %d vs ch1 samples size %d",
                       static_cast<int32_t>(p_input->samples.data[0].size()),
                       static_cast<int32_t>(p_input->samples.data[1].size()));
      SHERPA_ONNX_EXIT(-1);
    }
  }

  auto stft_ch0 = ComputeStft(*p_input, 0);
  auto stft_ch1 = ComputeStft(*p_input, 1);

  // Fall back to channel 0 when channel 1 produced no STFT data.
  knf::StftResult *p_stft_ch1 =
      stft_ch1.real.empty() ? &stft_ch0 : &stft_ch1;

  // Nothing to separate. Bail out before num_frames is used as a divisor.
  if (stft_ch0.num_frames <= 0 || stft_ch0.real.empty()) {
    SHERPA_ONNX_LOGE(
        "No STFT frame could be computed from the input. Returning empty "
        "stems.");

    OfflineSourceSeparationOutput empty;
    empty.sample_rate = output_sample_rate;
    empty.stems.resize(2);
    for (auto &stem : empty.stems) {
      stem.data.push_back(std::vector<float>());
      stem.data.push_back(std::vector<float>());
    }

    return empty;
  }

  const int32_t fft_bins =
      static_cast<int32_t>(stft_ch0.real.size()) / stft_ch0.num_frames;

  // Each frame contributes its first kModelBins bins; fewer bins would make
  // the copies below read past the end of a frame.
  if (fft_bins < kModelBins) {
    SHERPA_ONNX_LOGE("Expect at least %d bins per frame. Given: %d",
                     kModelBins, fft_bins);
    SHERPA_ONNX_EXIT(-1);
  }

  // Round the number of frames up to a multiple of kChunkFrames. The extra
  // frames stay zero.
  const int32_t num_frames =
      (stft_ch0.num_frames + kChunkFrames - 1) / kChunkFrames * kChunkFrames;

  const Eigen::Index channel_size =
      static_cast<Eigen::Index>(num_frames) * kModelBins;

  Eigen::VectorXf real = Eigen::VectorXf::Zero(2 * channel_size);
  Eigen::VectorXf imag = Eigen::VectorXf::Zero(2 * channel_size);

  // Copies the first kModelBins bins of every frame of `src` to the given
  // destinations, which are laid out as (num_frames, kModelBins).
  auto pack_channel = [&](const knf::StftResult &src, float *dst_real,
                          float *dst_imag) {
    for (int32_t i = 0; i != src.num_frames; ++i) {
      const float *frame_real = src.real.data() + i * fft_bins;
      const float *frame_imag = src.imag.data() + i * fft_bins;

      std::copy(frame_real, frame_real + kModelBins,
                dst_real + kModelBins * i);
      std::copy(frame_imag, frame_imag + kModelBins,
                dst_imag + kModelBins * i);
    }
  };

  // channel 0 occupies the first half, channel 1 the second half
  pack_channel(stft_ch0, &real[0], &imag[0]);
  pack_channel(*p_stft_ch1, &real[0] + channel_size, &imag[0] + channel_size);

  // magnitude spectrogram
  Eigen::VectorXf x = (real.array().square() + imag.array().square()).sqrt();

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  std::array<int64_t, 4> x_shape{2, num_frames / kChunkFrames, kChunkFrames,
                                 kModelBins};

  Ort::Value x_tensor = Ort::Value::CreateTensor(
      memory_info, &x[0], x.size(), x_shape.data(), x_shape.size());

  // The first call gets a view so that x_tensor can still be moved into the
  // second call.
  Ort::Value vocals_spec_tensor = model_.RunVocals(View(&x_tensor));
  Ort::Value accompaniment_spec_tensor =
      model_.RunAccompaniment(std::move(x_tensor));

  Eigen::VectorXf vocals_spec = Eigen::Map<Eigen::VectorXf>(
      vocals_spec_tensor.GetTensorMutableData<float>(), x.size());

  Eigen::VectorXf accompaniment_spec = Eigen::Map<Eigen::VectorXf>(
      accompaniment_spec_tensor.GetTensorMutableData<float>(), x.size());

  // Ratio masks: each stem's squared output divided by the sum of both.
  Eigen::VectorXf sum_spec = vocals_spec.array().square() +
                             accompaniment_spec.array().square() + 1e-10;

  vocals_spec =
      (vocals_spec.array().square() + 1e-10 / 2) / sum_spec.array();

  accompaniment_spec =
      (accompaniment_spec.array().square() + 1e-10 / 2) / sum_spec.array();

  auto vocals_samples_ch0 = ProcessSpec(vocals_spec, stft_ch0, 0);
  auto vocals_samples_ch1 = ProcessSpec(vocals_spec, *p_stft_ch1, 1);

  auto accompaniment_samples_ch0 =
      ProcessSpec(accompaniment_spec, stft_ch0, 0);
  auto accompaniment_samples_ch1 =
      ProcessSpec(accompaniment_spec, *p_stft_ch1, 1);

  OfflineSourceSeparationOutput ans;
  ans.sample_rate = output_sample_rate;

  ans.stems.resize(2);
  ans.stems[0].data.reserve(2);
  ans.stems[1].data.reserve(2);

  ans.stems[0].data.push_back(std::move(vocals_samples_ch0));
  ans.stems[0].data.push_back(std::move(vocals_samples_ch1));

  ans.stems[1].data.push_back(std::move(accompaniment_samples_ch0));
  ans.stems[1].data.push_back(std::move(accompaniment_samples_ch1));

  return ans;
}
// Sample rate of the produced audio, as recorded in the model meta data.
int32_t GetOutputSampleRate() const override {
  const auto &meta = model_.GetMetaData();
  return meta.sample_rate;
}
// Number of stems, as recorded in the model meta data.
int32_t GetNumberOfStems() const override {
  const auto &meta = model_.GetMetaData();
  return meta.num_stems;
}
private
:
// Applies the mask of one channel to `stft` and returns the inverse STFT.
//
// @param spec    Mask values of shape (2, num_chunks, 512, 1024). The first
//                half belongs to channel 0, the second half to channel 1.
// @param stft    Complex STFT the mask is applied to.
// @param channel 0 or 1; selects which half of `spec` is used.
// @return        Samples produced by the inverse STFT of the masked input.
//
// Only the first 1024 bins of each frame receive mask values; all remaining
// bins are multiplied by 0.
std::vector<float> ProcessSpec(const Eigen::VectorXf &spec,
                               const knf::StftResult &stft,
                               int32_t channel) const {
  const int32_t bins_per_frame = stft.real.size() / stft.num_frames;

  // Per-bin gain with the same layout as stft.real / stft.imag.
  Eigen::VectorXf gain = Eigen::VectorXf::Zero(stft.real.size());
  float *gain_begin = &gain[0];

  // assume there are 2 channels
  const float *channel_spec = &spec[0] + (spec.size() / 2) * channel;

  for (int32_t frame = 0; frame != stft.num_frames; ++frame) {
    const float *row_begin = channel_spec + frame * 1024;
    const float *row_end = channel_spec + (frame + 1) * 1024;
    std::copy(row_begin, row_end, gain_begin + frame * bins_per_frame);
  }

  knf::StftResult masked_stft;
  masked_stft.num_frames = stft.num_frames;
  masked_stft.real.resize(stft.real.size());
  masked_stft.imag.resize(stft.imag.size());

  // Read-only views of the input, writable views of the output.
  Eigen::Map<const Eigen::VectorXf> in_real(stft.real.data(),
                                            stft.real.size());
  Eigen::Map<const Eigen::VectorXf> in_imag(stft.imag.data(),
                                            stft.imag.size());

  Eigen::Map<Eigen::VectorXf> out_real(masked_stft.real.data(),
                                       masked_stft.real.size());
  Eigen::Map<Eigen::VectorXf> out_imag(masked_stft.imag.data(),
                                       masked_stft.imag.size());

  out_real = gain.array() * in_real.array();
  out_imag = gain.array() * in_imag.array();

  auto stft_config = GetStftConfig();
  knf::IStft inverse(stft_config);

  return inverse.Compute(masked_stft);
}
// Computes the STFT of channel `ch` of `input`.
//
// Exits the process when `ch` does not name an existing channel. Returns a
// value-initialized result when the channel contains no samples.
knf::StftResult ComputeStft(const OfflineSourceSeparationInput &input,
                            int32_t ch) const {
  const auto &channels = input.samples.data;

  // Reject negative indexes explicitly instead of relying on the implicit
  // signed -> unsigned conversion of a mixed comparison.
  if (ch < 0 || static_cast<size_t>(ch) >= channels.size()) {
    SHERPA_ONNX_LOGE("Invalid channel %d. Max %d", ch,
                     static_cast<int32_t>(channels.size()));
    SHERPA_ONNX_EXIT(-1);
  }

  if (channels[ch].empty()) {
    return {};
  }

  return ComputeStft(channels[ch]);
}
// Computes the STFT of `samples` using the configuration returned by
// GetStftConfig().
knf::StftResult ComputeStft(const std::vector<float> &samples) const {
  auto stft_config = GetStftConfig();
  knf::Stft forward(stft_config);

  const float *first = samples.data();
  return forward.Compute(first, samples.size());
}
// Builds the STFT configuration from the model meta data.
//
// n_fft, hop_length, win_length and window_type are taken from the meta
// data. `center` is always false: the original code assigned meta.center and
// then immediately overwrote it with false, so only the effective value is
// kept here.
knf::StftConfig GetStftConfig() const {
  const auto &meta = model_.GetMetaData();

  knf::StftConfig stft_config;
  stft_config.n_fft = meta.n_fft;
  stft_config.hop_length = meta.hop_length;
  stft_config.win_length = meta.window_length;
  stft_config.window_type = meta.window_type;
  stft_config.center = false;

  return stft_config;
}
private
:
OfflineSourceSeparationConfig
config_
;
OfflineSourceSeparationSpleeterModel
model_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_IMPL_H_
...
...
sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace
sherpa_onnx
{
// Registers the two model-path options of the spleeter config with `po`.
void OfflineSourceSeparationSpleeterModelConfig::Register(ParseOptions *po) {
  ParseOptions &options = *po;

  options.Register("spleeter-vocals", &vocals,
                   "Path to the spleeter vocals model");

  options.Register("spleeter-accompaniment", &accompaniment,
                   "Path to the spleeter accompaniment model");
}
// Returns true only when both model paths are non-empty and both files
// exist. The first failing check is logged and later checks are skipped.
bool OfflineSourceSeparationSpleeterModelConfig::Validate() const {
  bool ok = false;

  if (vocals.empty()) {
    SHERPA_ONNX_LOGE("Please provide --spleeter-vocals");
  } else if (!FileExists(vocals)) {
    SHERPA_ONNX_LOGE("spleeter vocals '%s' does not exist. ", vocals.c_str());
  } else if (accompaniment.empty()) {
    SHERPA_ONNX_LOGE("Please provide --spleeter-accompaniment");
  } else if (!FileExists(accompaniment)) {
    SHERPA_ONNX_LOGE("spleeter accompaniment '%s' does not exist. ",
                     accompaniment.c_str());
  } else {
    ok = true;
  }

  return ok;
}
// Renders the config as a single-line, constructor-like string with both
// paths enclosed in double quotes.
std::string OfflineSourceSeparationSpleeterModelConfig::ToString() const {
  std::ostringstream out;

  out << "OfflineSourceSeparationSpleeterModelConfig("
      << "vocals=\"" << vocals << "\", "
      << "accompaniment=\"" << accompaniment << "\")";

  return out.str();
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace
sherpa_onnx
{
// Paths of the two ONNX models that make up a 2-stem spleeter model.
struct OfflineSourceSeparationSpleeterModelConfig {
  // Path to the spleeter vocals model.
  std::string vocals;

  // Path to the spleeter accompaniment model.
  std::string accompaniment;

  OfflineSourceSeparationSpleeterModelConfig() = default;

  OfflineSourceSeparationSpleeterModelConfig(const std::string &vocals,
                                             const std::string &accompaniment)
      : vocals(vocals),
        accompaniment(accompaniment) {
  }

  // Registers --spleeter-vocals and --spleeter-accompaniment with `po`.
  void Register(ParseOptions *po);

  // Returns true if both paths are given and both files exist.
  bool Validate() const;

  // Human-readable dump of both paths.
  std::string ToString() const;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_CONFIG_H_
...
...
sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
#include <string>
#include <unordered_map>
#include <vector>
namespace
sherpa_onnx
{
// Parameters describing a spleeter model and the STFT used around it.
// The values below are the defaults.
//
// See also
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/spleeter/separate_onnx.py
struct OfflineSourceSeparationSpleeterModelMetaData {
  // Sample rate, in Hz, the model works at.
  int32_t sample_rate{44100};

  // Number of stems produced by the model.
  int32_t num_stems{2};

  // STFT parameters.
  int32_t n_fft{4096};
  int32_t hop_length{1024};
  int32_t window_length{4096};
  bool center{false};
  std::string window_type{"hann"};
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_META_DATA_H_
...
...
sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-spleeter-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace
sherpa_onnx
{
class
OfflineSourceSeparationSpleeterModel
::
Impl
{
public
:
// Reads both model files from disk and creates one session per model.
// Each file buffer lives in its own scope, so the vocals buffer is released
// before the accompaniment file is read.
explicit Impl(const OfflineSourceSeparationModelConfig &config)
    : config_(config),
      env_(ORT_LOGGING_LEVEL_ERROR),
      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
  {
    auto vocals_buf = ReadFile(config.spleeter.vocals);
    InitVocals(vocals_buf.data(), vocals_buf.size());
  }

  {
    auto accompaniment_buf = ReadFile(config.spleeter.accompaniment);
    InitAccompaniment(accompaniment_buf.data(), accompaniment_buf.size());
  }
}
// Same as the single-argument constructor, but both model files are read
// through the resource manager `mgr`.
template <typename Manager>
Impl(Manager *mgr, const OfflineSourceSeparationModelConfig &config)
    : config_(config),
      env_(ORT_LOGGING_LEVEL_ERROR),
      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
  {
    auto vocals_buf = ReadFile(mgr, config.spleeter.vocals);
    InitVocals(vocals_buf.data(), vocals_buf.size());
  }

  {
    auto accompaniment_buf = ReadFile(mgr, config.spleeter.accompaniment);
    InitAccompaniment(accompaniment_buf.data(), accompaniment_buf.size());
  }
}
// Read-only access to the stored meta data; num_stems is read from the
// vocals model in InitVocals().
const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const {
  const auto &meta = meta_;
  return meta;
}
// Runs the vocals session on the single input `x` and returns the first
// output of the session.
Ort::Value RunVocals(Ort::Value x) const {
  const char *const *input_names = vocals_input_names_ptr_.data();
  const char *const *output_names = vocals_output_names_ptr_.data();
  const size_t num_outputs = vocals_output_names_ptr_.size();

  auto outputs =
      vocals_sess_->Run({}, input_names, &x, 1, output_names, num_outputs);

  return std::move(outputs[0]);
}
// Runs the accompaniment session on the single input `x` and returns the
// first output of the session.
Ort::Value RunAccompaniment(Ort::Value x) const {
  const char *const *input_names = accompaniment_input_names_ptr_.data();
  const char *const *output_names = accompaniment_output_names_ptr_.data();
  const size_t num_outputs = accompaniment_output_names_ptr_.size();

  auto outputs = accompaniment_sess_->Run({}, input_names, &x, 1,
                                          output_names, num_outputs);

  return std::move(outputs[0]);
}
private
:
void
InitVocals
(
void
*
model_data
,
size_t
model_data_length
)
{
vocals_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
vocals_sess_
.
get
(),
&
vocals_input_names_
,
&
vocals_input_names_ptr_
);
GetOutputNames
(
vocals_sess_
.
get
(),
&
vocals_output_names_
,
&
vocals_output_names_ptr_
);
Ort
::
ModelMetadata
meta_data
=
vocals_sess_
->
GetModelMetadata
();
if
(
config_
.
debug
)
{
std
::
ostringstream
os
;
os
<<
"---vocals model---
\n
"
;
PrintModelMetadata
(
os
,
meta_data
);
os
<<
"----------input names----------
\n
"
;
int32_t
i
=
0
;
for
(
const
auto
&
s
:
vocals_input_names_
)
{
os
<<
i
<<
" "
<<
s
<<
"
\n
"
;
++
i
;
}
os
<<
"----------output names----------
\n
"
;
i
=
0
;
for
(
const
auto
&
s
:
vocals_output_names_
)
{
os
<<
i
<<
" "
<<
s
<<
"
\n
"
;
++
i
;
}
#if __OHOS__
SHERPA_ONNX_LOGE
(
"%{public}s
\n
"
,
os
.
str
().
c_str
());
#else
SHERPA_ONNX_LOGE
(
"%s
\n
"
,
os
.
str
().
c_str
());
#endif
}
Ort
::
AllocatorWithDefaultOptions
allocator
;
// used in the macro below
std
::
string
model_type
;
SHERPA_ONNX_READ_META_DATA_STR
(
model_type
,
"model_type"
);
if
(
model_type
!=
"spleeter"
)
{
SHERPA_ONNX_LOGE
(
"Expect model type 'spleeter'. Given: '%s'"
,
model_type
.
c_str
());
SHERPA_ONNX_EXIT
(
-
1
);
}
SHERPA_ONNX_READ_META_DATA
(
meta_
.
num_stems
,
"stems"
);
if
(
meta_
.
num_stems
!=
2
)
{
SHERPA_ONNX_LOGE
(
"Only 2stems is supported. Given %d stems"
,
meta_
.
num_stems
);
SHERPA_ONNX_EXIT
(
-
1
);
}
}
void
InitAccompaniment
(
void
*
model_data
,
size_t
model_data_length
)
{
accompaniment_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
accompaniment_sess_
.
get
(),
&
accompaniment_input_names_
,
&
accompaniment_input_names_ptr_
);
GetOutputNames
(
accompaniment_sess_
.
get
(),
&
accompaniment_output_names_
,
&
accompaniment_output_names_ptr_
);
}
private
:
OfflineSourceSeparationModelConfig
config_
;
OfflineSourceSeparationSpleeterModelMetaData
meta_
;
Ort
::
Env
env_
;
Ort
::
SessionOptions
sess_opts_
;
Ort
::
AllocatorWithDefaultOptions
allocator_
;
std
::
unique_ptr
<
Ort
::
Session
>
vocals_sess_
;
std
::
vector
<
std
::
string
>
vocals_input_names_
;
std
::
vector
<
const
char
*>
vocals_input_names_ptr_
;
std
::
vector
<
std
::
string
>
vocals_output_names_
;
std
::
vector
<
const
char
*>
vocals_output_names_ptr_
;
std
::
unique_ptr
<
Ort
::
Session
>
accompaniment_sess_
;
std
::
vector
<
std
::
string
>
accompaniment_input_names_
;
std
::
vector
<
const
char
*>
accompaniment_input_names_ptr_
;
std
::
vector
<
std
::
string
>
accompaniment_output_names_
;
std
::
vector
<
const
char
*>
accompaniment_output_names_ptr_
;
};
// Defined in this translation unit, after Impl is complete, because the
// header only forward-declares Impl while holding a std::unique_ptr<Impl>.
OfflineSourceSeparationSpleeterModel::~OfflineSourceSeparationSpleeterModel() =
    default;
OfflineSourceSeparationSpleeterModel
::
OfflineSourceSeparationSpleeterModel
(
const
OfflineSourceSeparationModelConfig
&
config
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
))
{}
template
<
typename
Manager
>
OfflineSourceSeparationSpleeterModel
::
OfflineSourceSeparationSpleeterModel
(
Manager
*
mgr
,
const
OfflineSourceSeparationModelConfig
&
config
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
mgr
,
config
))
{}
// Forwards `x` to the implementation's vocals session.
Ort::Value OfflineSourceSeparationSpleeterModel::RunVocals(Ort::Value x) const {
  Ort::Value y = impl_->RunVocals(std::move(x));
  return y;
}
// Forwards `x` to the implementation's accompaniment session.
Ort::Value OfflineSourceSeparationSpleeterModel::RunAccompaniment(
    Ort::Value x) const {
  Ort::Value y = impl_->RunAccompaniment(std::move(x));
  return y;
}
// Returns the meta data held by the implementation.
const OfflineSourceSeparationSpleeterModelMetaData &
OfflineSourceSeparationSpleeterModel::GetMetaData() const {
  const auto &meta = impl_->GetMetaData();
  return meta;
}
#if __ANDROID_API__ >= 9
template
OfflineSourceSeparationSpleeterModel
::
OfflineSourceSeparationSpleeterModel
(
AAssetManager
*
mgr
,
const
OfflineSourceSeparationModelConfig
&
config
);
#endif
#if __OHOS__
template
OfflineSourceSeparationSpleeterModel
::
OfflineSourceSeparationSpleeterModel
(
NativeResourceManager
*
mgr
,
const
OfflineSourceSeparationModelConfig
&
config
);
#endif
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation-spleeter-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
#include <memory>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
#include "sherpa-onnx/csrc/offline-source-separation-spleeter-model-meta-data.h"
namespace
sherpa_onnx
{
// Spleeter model used for source separation.
//
// All work is delegated to a private Impl (pimpl idiom); this class only
// exposes one entry point per model (vocals and accompaniment) and the
// model meta data.
class OfflineSourceSeparationSpleeterModel {
 public:
  // Creates the model from `config`.
  explicit OfflineSourceSeparationSpleeterModel(
      const OfflineSourceSeparationModelConfig &config);

  // Creates the model from `config`, passing the platform resource manager
  // `mgr` to the implementation. Defined and explicitly instantiated in the
  // corresponding .cc file.
  template <typename Manager>
  OfflineSourceSeparationSpleeterModel(
      Manager *mgr, const OfflineSourceSeparationModelConfig &config);

  // Declared here and defined in the .cc file, where Impl is a complete type.
  ~OfflineSourceSeparationSpleeterModel();

  // Runs the vocals model on `x` and returns its output.
  Ort::Value RunVocals(Ort::Value x) const;

  // Runs the accompaniment model on `x` and returns its output.
  Ort::Value RunAccompaniment(Ort::Value x) const;

  // Returns the meta data of the model.
  const OfflineSourceSeparationSpleeterModelMetaData &GetMetaData() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_SPLEETER_MODEL_H_
...
...
sherpa-onnx/csrc/offline-source-separation.cc
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-source-separation.h"
#include <memory>
#include "sherpa-onnx/csrc/offline-source-separation-impl.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
namespace
sherpa_onnx
{
// Registers the command-line options of this config with `po`.
//
// The only member is `model`, so this simply forwards to it.
void OfflineSourceSeparationConfig::Register(ParseOptions *po) {
  model.Register(po);
}
bool
OfflineSourceSeparationConfig
::
Validate
()
const
{
return
model
.
Validate
();
}
// Returns a human-readable description of the config in the form
// "OfflineSourceSeparationConfig(model=<model.ToString()>)".
std::string OfflineSourceSeparationConfig::ToString() const {
  std::string description = "OfflineSourceSeparationConfig(";
  description += "model=";
  description += model.ToString();
  description += ")";
  return description;
}
template
<
typename
Manager
>
OfflineSourceSeparation
::
OfflineSourceSeparation
(
Manager
*
mgr
,
const
OfflineSourceSeparationConfig
&
config
)
:
impl_
(
OfflineSourceSeparationImpl
::
Create
(
mgr
,
config
))
{}
OfflineSourceSeparation
::
OfflineSourceSeparation
(
const
OfflineSourceSeparationConfig
&
config
)
:
impl_
(
OfflineSourceSeparationImpl
::
Create
(
config
))
{}
OfflineSourceSeparation
::~
OfflineSourceSeparation
()
=
default
;
// Separates `input` into stems by delegating to the implementation and
// returns whatever the implementation produces.
OfflineSourceSeparationOutput OfflineSourceSeparation::Process(
    const OfflineSourceSeparationInput &input) const {
  OfflineSourceSeparationOutput separated = impl_->Process(input);
  return separated;
}
// Returns the output sample rate reported by the implementation.
int32_t OfflineSourceSeparation::GetOutputSampleRate() const {
  const int32_t sample_rate = impl_->GetOutputSampleRate();
  return sample_rate;
}
// Returns the number of stems reported by the implementation,
// e.g., it is 2 for 2stems from spleeter
int32_t OfflineSourceSeparation::GetNumberOfStems() const {
  const int32_t num_stems = impl_->GetNumberOfStems();
  return num_stems;
}
// The templated constructor is defined in this file, so every Manager type
// that callers may use has to be instantiated explicitly here.

// Android: resources are loaded through an AAssetManager.
#if __ANDROID_API__ >= 9
template OfflineSourceSeparation::OfflineSourceSeparation(
    AAssetManager *mgr, const OfflineSourceSeparationConfig &config);
#endif

// OHOS: resources are loaded through a NativeResourceManager.
#if __OHOS__
template OfflineSourceSeparation::OfflineSourceSeparation(
    NativeResourceManager *mgr, const OfflineSourceSeparationConfig &config);
#endif
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-source-separation.h
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/offline-source-separation.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-source-separation-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace
sherpa_onnx
{
// Top-level configuration for offline source separation.
struct OfflineSourceSeparationConfig {
  // Configuration of the underlying model.
  OfflineSourceSeparationModelConfig model;

  OfflineSourceSeparationConfig() = default;

  // Copies `model` into this config.
  OfflineSourceSeparationConfig(const OfflineSourceSeparationModelConfig &model)
      : model(model) {}

  // Registers the command-line options of `model` with `po`.
  void Register(ParseOptions *po);

  // Returns true if `model` is valid.
  bool Validate() const;

  // Returns a human-readable description of this config.
  std::string ToString() const;
};
// Audio samples for one or more channels.
struct MultiChannelSamples {
  // data[i] is for the i-th channel
  //
  // each sample is in the range [-1, 1]
  std::vector<std::vector<float>> data;
};

// Input of OfflineSourceSeparation::Process().
struct OfflineSourceSeparationInput {
  // The audio to separate, one vector per channel.
  MultiChannelSamples samples;

  // Sample rate of `samples`. Initialized to 0 so that a default-constructed
  // object never exposes an indeterminate value.
  int32_t sample_rate = 0;
};

// Output of OfflineSourceSeparation::Process().
struct OfflineSourceSeparationOutput {
  // One entry per separated stem. The command-line tool writes stems[0] to
  // the vocals file and stems[1] to the accompaniment file.
  std::vector<MultiChannelSamples> stems;

  // Sample rate of every stem. Initialized to 0 so that a default-constructed
  // object never exposes an indeterminate value.
  int32_t sample_rate = 0;
};
class OfflineSourceSeparationImpl;

// Public entry point for offline source separation.
//
// All work is delegated to an OfflineSourceSeparationImpl owned through
// impl_; the concrete implementation is chosen by
// OfflineSourceSeparationImpl::Create() in the .cc file.
class OfflineSourceSeparation {
 public:
  // Creates the separator from `config`.
  OfflineSourceSeparation(const OfflineSourceSeparationConfig &config);

  // Creates the separator from `config`, passing the platform resource
  // manager `mgr` to the implementation. Defined and explicitly instantiated
  // in the .cc file.
  template <typename Manager>
  OfflineSourceSeparation(Manager *mgr,
                          const OfflineSourceSeparationConfig &config);

  // Declared here and defined in the .cc file, where the implementation
  // type is complete.
  ~OfflineSourceSeparation();

  // Separates `input` into stems.
  OfflineSourceSeparationOutput Process(
      const OfflineSourceSeparationInput &input) const;

  // Returns the sample rate of the separated output.
  int32_t GetOutputSampleRate() const;

  // e.g., it is 2 for 2stems from spleeter
  int32_t GetNumberOfStems() const;

 private:
  std::unique_ptr<OfflineSourceSeparationImpl> impl_;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SOURCE_SEPARATION_H_
...
...
sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h
查看文件 @
716ba83
...
...
@@ -12,7 +12,7 @@
namespace
sherpa_onnx
{
// please refer to
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/
kokoro/add-meta-
data.py
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/
gtcrn/add_meta_
data.py
struct
OfflineSpeechDenoiserGtcrnModelMetaData
{
int32_t
sample_rate
=
0
;
int32_t
version
=
1
;
...
...
sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
查看文件 @
716ba83
...
...
@@ -11,7 +11,7 @@
int
main
(
int32_t
argc
,
char
*
argv
[])
{
const
char
*
kUsageMessage
=
R"usage(
Non-stre
ma
ing speech denoising with sherpa-onnx.
Non-stre
am
ing speech denoising with sherpa-onnx.
Please visit
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
...
...
sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
0 → 100644
查看文件 @
716ba83
// sherpa-onnx/csrc/sherpa-onnx-offline-source-separation.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include <stdio.h>
#include <chrono> // NOLINT
#include <string>
#include "sherpa-onnx/csrc/offline-source-separation.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"
int
main
(
int32_t
argc
,
char
*
argv
[])
{
const
char
*
kUsageMessage
=
R"usage(
Non-streaming source separation with sherpa-onnx.
Please visit
https://github.com/k2-fsa/sherpa-onnx/releases/tag/source-separation-models
to download models.
Usage:
(1) Use spleeter models
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/sherpa-onnx-spleeter-2stems-fp16.tar.bz2
tar xvf sherpa-onnx-spleeter-2stems-fp16.tar.bz2
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav
./bin/sherpa-onnx-offline-source-separation \
--spleeter-vocals=sherpa-onnx-spleeter-2stems-fp16/vocals.fp16.onnx \
--spleeter-accompaniment=sherpa-onnx-spleeter-2stems-fp16/accompaniment.fp16.onnx \
--input-wav=audio_example.wav \
--output-vocals-wav=output_vocals.wav \
--output-accompaniment-wav=output_accompaniment.wav
)usage"
;
sherpa_onnx
::
ParseOptions
po
(
kUsageMessage
);
sherpa_onnx
::
OfflineSourceSeparationConfig
config
;
std
::
string
input_wave
;
std
::
string
output_vocals_wave
;
std
::
string
output_accompaniment_wave
;
config
.
Register
(
&
po
);
po
.
Register
(
"input-wav"
,
&
input_wave
,
"Path to input wav."
);
po
.
Register
(
"output-vocals-wav"
,
&
output_vocals_wave
,
"Path to output vocals wav"
);
po
.
Register
(
"output-accompaniment-wav"
,
&
output_accompaniment_wave
,
"Path to output accompaniment wav"
);
po
.
Read
(
argc
,
argv
);
if
(
po
.
NumArgs
()
!=
0
)
{
fprintf
(
stderr
,
"Please don't give positional arguments
\n
"
);
po
.
PrintUsage
();
exit
(
EXIT_FAILURE
);
}
fprintf
(
stderr
,
"%s
\n
"
,
config
.
ToString
().
c_str
());
if
(
input_wave
.
empty
())
{
fprintf
(
stderr
,
"Please provide --input-wav
\n
"
);
po
.
PrintUsage
();
exit
(
EXIT_FAILURE
);
}
if
(
output_vocals_wave
.
empty
())
{
fprintf
(
stderr
,
"Please provide --output-vocals-wav
\n
"
);
po
.
PrintUsage
();
exit
(
EXIT_FAILURE
);
}
if
(
output_accompaniment_wave
.
empty
())
{
fprintf
(
stderr
,
"Please provide --output-accompaniment-wav
\n
"
);
po
.
PrintUsage
();
exit
(
EXIT_FAILURE
);
}
if
(
!
config
.
Validate
())
{
fprintf
(
stderr
,
"Errors in config!
\n
"
);
exit
(
EXIT_FAILURE
);
}
bool
is_ok
=
false
;
sherpa_onnx
::
OfflineSourceSeparationInput
input
;
input
.
samples
.
data
=
sherpa_onnx
::
ReadWaveMultiChannel
(
input_wave
,
&
input
.
sample_rate
,
&
is_ok
);
if
(
!
is_ok
)
{
fprintf
(
stderr
,
"Failed to read '%s'
\n
"
,
input_wave
.
c_str
());
return
-
1
;
}
fprintf
(
stderr
,
"Started
\n
"
);
sherpa_onnx
::
OfflineSourceSeparation
sp
(
config
);
const
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
auto
output
=
sp
.
Process
(
input
);
const
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
float
elapsed_seconds
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
milliseconds
>
(
end
-
begin
)
.
count
()
/
1000.
;
is_ok
=
sherpa_onnx
::
WriteWave
(
output_vocals_wave
,
output
.
sample_rate
,
output
.
stems
[
0
].
data
[
0
].
data
(),
output
.
stems
[
0
].
data
[
1
].
data
(),
output
.
stems
[
0
].
data
[
0
].
size
());
if
(
!
is_ok
)
{
fprintf
(
stderr
,
"Failed to write to '%s'
\n
"
,
output_vocals_wave
.
c_str
());
exit
(
EXIT_FAILURE
);
}
is_ok
=
sherpa_onnx
::
WriteWave
(
output_accompaniment_wave
,
output
.
sample_rate
,
output
.
stems
[
1
].
data
[
0
].
data
(),
output
.
stems
[
1
].
data
[
1
].
data
(),
output
.
stems
[
1
].
data
[
0
].
size
());
if
(
!
is_ok
)
{
fprintf
(
stderr
,
"Failed to write to '%s'
\n
"
,
output_accompaniment_wave
.
c_str
());
exit
(
EXIT_FAILURE
);
}
fprintf
(
stderr
,
"Done
\n
"
);
fprintf
(
stderr
,
"Saved to write to '%s' and '%s'
\n
"
,
output_vocals_wave
.
c_str
(),
output_accompaniment_wave
.
c_str
());
float
duration
=
input
.
samples
.
data
[
0
].
size
()
/
static_cast
<
float
>
(
input
.
sample_rate
);
fprintf
(
stderr
,
"num threads: %d
\n
"
,
config
.
model
.
num_threads
);
fprintf
(
stderr
,
"Elapsed seconds: %.3f s
\n
"
,
elapsed_seconds
);
float
rtf
=
elapsed_seconds
/
duration
;
fprintf
(
stderr
,
"Real time factor (RTF): %.3f / %.3f = %.3f
\n
"
,
elapsed_seconds
,
duration
,
rtf
);
return
0
;
}
...
...
sherpa-onnx/csrc/wave-reader.cc
查看文件 @
716ba83
...
...
@@ -63,8 +63,9 @@ in sherpa-onnx.
// Read a wave file with one or more channels.
// Return its samples, one vector per channel, normalized to the range [-1, 1).
std
::
vector
<
float
>
ReadWaveImpl
(
std
::
istream
&
is
,
int32_t
*
sampling_rate
,
bool
*
is_ok
)
{
std
::
vector
<
std
::
vector
<
float
>>
ReadWaveImpl
(
std
::
istream
&
is
,
int32_t
*
sampling_rate
,
bool
*
is_ok
)
{
WaveHeader
header
{};
is
.
read
(
reinterpret_cast
<
char
*>
(
&
header
.
chunk_id
),
sizeof
(
header
.
chunk_id
));
...
...
@@ -144,12 +145,6 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
is
.
read
(
reinterpret_cast
<
char
*>
(
&
header
.
num_channels
),
sizeof
(
header
.
num_channels
));
if
(
header
.
num_channels
!=
1
)
{
// we support only single channel for now
SHERPA_ONNX_LOGE
(
"Warning: %d channels are found. We only use the first channel.
\n
"
,
header
.
num_channels
);
}
is
.
read
(
reinterpret_cast
<
char
*>
(
&
header
.
sample_rate
),
sizeof
(
header
.
sample_rate
));
...
...
@@ -219,7 +214,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
*
sampling_rate
=
header
.
sample_rate
;
std
::
vector
<
float
>
ans
;
std
::
vector
<
std
::
vector
<
float
>>
ans
(
header
.
num_channels
)
;
if
(
header
.
bits_per_sample
==
16
&&
header
.
audio_format
==
1
)
{
// header.subchunk2_size contains the number of bytes in the data.
...
...
@@ -233,11 +228,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return
{};
}
ans
.
resize
(
samples
.
size
()
/
header
.
num_channels
);
for
(
auto
&
v
:
ans
)
{
v
.
resize
(
samples
.
size
()
/
header
.
num_channels
);
}
// samples are interleaved
for
(
int32_t
i
=
0
;
i
!=
static_cast
<
int32_t
>
(
ans
.
size
());
++
i
)
{
ans
[
i
]
=
samples
[
i
*
header
.
num_channels
]
/
32768.
;
for
(
int32_t
i
=
0
,
k
=
0
;
i
<
static_cast
<
int32_t
>
(
samples
.
size
());
i
+=
header
.
num_channels
,
++
k
)
{
for
(
int32_t
c
=
0
;
c
!=
header
.
num_channels
;
++
c
)
{
ans
[
c
][
k
]
=
samples
[
i
+
c
]
/
32768.
;
}
}
}
else
if
(
header
.
bits_per_sample
==
8
&&
header
.
audio_format
==
1
)
{
// number of samples == number of bytes for 8-bit encoded samples
...
...
@@ -252,14 +252,21 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return
{};
}
ans
.
resize
(
samples
.
size
()
/
header
.
num_channels
);
for
(
int32_t
i
=
0
;
i
!=
static_cast
<
int32_t
>
(
ans
.
size
());
++
i
)
{
// Note(fangjun): We want to normalize each sample into the range [-1, 1]
// Since each original sample is in the range [0, 256], dividing
// them by 128 converts them to the range [0, 2];
// so after subtracting 1, we get the range [-1, 1]
//
ans
[
i
]
=
samples
[
i
*
header
.
num_channels
]
/
128.
-
1
;
for
(
auto
&
v
:
ans
)
{
v
.
resize
(
samples
.
size
()
/
header
.
num_channels
);
}
// samples are interleaved
for
(
int32_t
i
=
0
,
k
=
0
;
i
<
static_cast
<
int32_t
>
(
samples
.
size
());
i
+=
header
.
num_channels
,
++
k
)
{
for
(
int32_t
c
=
0
;
c
!=
header
.
num_channels
;
++
c
)
{
// Note(fangjun): We want to normalize each sample into the range [-1, 1).
// Since each original 8-bit sample is in the range [0, 255], dividing it
// by 128 converts it to the range [0, 2); so after subtracting 1, we get
// the range [-1, 1).
//
ans
[
c
][
k
]
=
samples
[
i
+
c
]
/
128.
-
1
;
}
}
}
else
if
(
header
.
bits_per_sample
==
32
&&
header
.
audio_format
==
1
)
{
// 32 here is for int32
...
...
@@ -275,9 +282,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return
{};
}
ans
.
resize
(
samples
.
size
()
/
header
.
num_channels
);
for
(
int32_t
i
=
0
;
i
!=
static_cast
<
int32_t
>
(
ans
.
size
());
++
i
)
{
ans
[
i
]
=
static_cast
<
float
>
(
samples
[
i
*
header
.
num_channels
])
/
(
1
<<
31
);
for
(
auto
&
v
:
ans
)
{
v
.
resize
(
samples
.
size
()
/
header
.
num_channels
);
}
// samples are interleaved
for
(
int32_t
i
=
0
,
k
=
0
;
i
<
static_cast
<
int32_t
>
(
samples
.
size
());
i
+=
header
.
num_channels
,
++
k
)
{
for
(
int32_t
c
=
0
;
c
!=
header
.
num_channels
;
++
c
)
{
ans
[
c
][
k
]
=
static_cast
<
float
>
(
samples
[
i
+
c
])
/
(
1
<<
31
);
}
}
}
else
if
(
header
.
bits_per_sample
==
32
&&
header
.
audio_format
==
3
)
{
// 32 here is for float32
...
...
@@ -293,9 +307,16 @@ std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
return
{};
}
ans
.
resize
(
samples
.
size
()
/
header
.
num_channels
);
for
(
int32_t
i
=
0
;
i
!=
static_cast
<
int32_t
>
(
ans
.
size
());
++
i
)
{
ans
[
i
]
=
samples
[
i
*
header
.
num_channels
];
for
(
auto
&
v
:
ans
)
{
v
.
resize
(
samples
.
size
()
/
header
.
num_channels
);
}
// samples are interleaved
for
(
int32_t
i
=
0
,
k
=
0
;
i
<
static_cast
<
int32_t
>
(
samples
.
size
());
i
+=
header
.
num_channels
,
++
k
)
{
for
(
int32_t
c
=
0
;
c
!=
header
.
num_channels
;
++
c
)
{
ans
[
c
][
k
]
=
samples
[
i
+
c
];
}
}
}
else
{
SHERPA_ONNX_LOGE
(
...
...
@@ -321,7 +342,27 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
// Reads a wave file from `is` and returns the samples of its first channel.
//
// @param is             Stream positioned at the start of the wave data.
// @param sampling_rate  Passed through to ReadWaveImpl(), which stores the
//                       sample rate of the file in it.
// @param is_ok          Passed through to ReadWaveImpl(); this function does
//                       not modify the value ReadWaveImpl() stored in it.
// @return The first channel, or an empty vector if ReadWaveImpl() returned
//         no channel at all (which is what it does on failure).
//
// A warning is logged when the file has more than one channel, because the
// remaining channels are discarded.
std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
                            bool *is_ok) {
  auto channels = ReadWaveImpl(is, sampling_rate, is_ok);

  // ReadWaveImpl() returns an empty container when it cannot parse the
  // stream; indexing channel 0 in that case would be out of bounds.
  if (channels.empty()) {
    return {};
  }

  if (channels.size() > 1) {
    SHERPA_ONNX_LOGE(
        "Warning: %d channels are found. We only use the first channel.\n",
        static_cast<int32_t>(channels.size()));
  }

  // Steal the buffer of the first channel instead of copying it.
  std::vector<float> first_channel;
  first_channel.swap(channels[0]);
  return first_channel;
}
// Reads a wave file from `is` and returns all of its channels, one vector
// per channel. `sampling_rate` and `is_ok` are passed through to
// ReadWaveImpl(), which does the actual work.
std::vector<std::vector<float>> ReadWaveMultiChannel(std::istream &is,
                                                     int32_t *sampling_rate,
                                                     bool *is_ok) {
  return ReadWaveImpl(is, sampling_rate, is_ok);
}
// Opens `filename` in binary mode and reads all channels of the wave file
// through the stream overload of ReadWaveMultiChannel().
std::vector<std::vector<float>> ReadWaveMultiChannel(
    const std::string &filename, int32_t *sampling_rate, bool *is_ok) {
  std::ifstream stream(filename, std::ifstream::binary);
  auto channels = ReadWaveMultiChannel(stream, sampling_rate, is_ok);
  return channels;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/wave-reader.h
查看文件 @
716ba83
...
...
@@ -26,6 +26,13 @@ std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
std
::
vector
<
float
>
ReadWave
(
std
::
istream
&
is
,
int32_t
*
sampling_rate
,
bool
*
is_ok
);
std
::
vector
<
std
::
vector
<
float
>>
ReadWaveMultiChannel
(
std
::
istream
&
is
,
int32_t
*
sampling_rate
,
bool
*
is_ok
);
std
::
vector
<
std
::
vector
<
float
>>
ReadWaveMultiChannel
(
const
std
::
string
&
filename
,
int32_t
*
sampling_rate
,
bool
*
is_ok
);
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_
...
...
sherpa-onnx/csrc/wave-writer.cc
查看文件 @
716ba83
...
...
@@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/wave-writer.h"
#include <algorithm>
#include <cstring>
#include <fstream>
#include <string>
...
...
@@ -36,12 +37,44 @@ struct WaveHeader {
}
// namespace
int64_t
WaveFileSize
(
int32_t
n_samples
)
{
return
sizeof
(
WaveHeader
)
+
n_samples
*
sizeof
(
int16_t
);
int64_t
WaveFileSize
(
int32_t
n_samples
,
int32_t
num_channels
/*= 1*/
)
{
return
sizeof
(
WaveHeader
)
+
n_samples
*
sizeof
(
int16_t
)
*
num_channels
;
}
void
WriteWave
(
char
*
buffer
,
int32_t
sampling_rate
,
const
float
*
samples
,
int32_t
n
)
{
WriteWave
(
buffer
,
sampling_rate
,
samples
,
nullptr
,
n
);
}
bool
WriteWave
(
const
std
::
string
&
filename
,
int32_t
sampling_rate
,
const
float
*
samples
,
int32_t
n
)
{
return
WriteWave
(
filename
,
sampling_rate
,
samples
,
nullptr
,
n
);
}
// Writes a wave file with one or two channels to `filename`.
//
// @param filename       Path of the file to create.
// @param sampling_rate  Sample rate stored in the wave header.
// @param samples_ch0    Samples of the first channel (`n` values).
// @param samples_ch1    Samples of the second channel (`n` values), or
//                       nullptr to write a single-channel file.
// @param n              Number of samples per channel.
// @return true on success; false if the file cannot be created or written,
//         in which case an error is logged.
bool WriteWave(const std::string &filename, int32_t sampling_rate,
               const float *samples_ch0, const float *samples_ch1, int32_t n) {
  // The whole file is assembled in memory first and written in one go.
  const int32_t num_channels = (samples_ch1 == nullptr) ? 1 : 2;
  std::string wave_bytes;
  wave_bytes.resize(WaveFileSize(n, num_channels));
  WriteWave(wave_bytes.data(), sampling_rate, samples_ch0, samples_ch1, n);

  std::ofstream out(filename, std::ios::binary);
  if (!out) {
    SHERPA_ONNX_LOGE("Failed to create '%s'", filename.c_str());
    return false;
  }

  out << wave_bytes;
  if (!out) {
    SHERPA_ONNX_LOGE("Write '%s' failed", filename.c_str());
    return false;
  }

  return true;
}
void
WriteWave
(
char
*
buffer
,
int32_t
sampling_rate
,
const
float
*
samples_ch0
,
const
float
*
samples_ch1
,
int32_t
n
)
{
WaveHeader
header
{};
header
.
chunk_id
=
0x46464952
;
// FFIR
header
.
format
=
0x45564157
;
// EVAW
...
...
@@ -49,8 +82,9 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
header
.
subchunk1_size
=
16
;
// 16 for PCM
header
.
audio_format
=
1
;
// PCM =1
int32_t
num_channels
=
1
;
int32_t
num_channels
=
samples_ch1
==
nullptr
?
1
:
2
;
int32_t
bits_per_sample
=
16
;
// int16_t
header
.
num_channels
=
num_channels
;
header
.
sample_rate
=
sampling_rate
;
header
.
byte_rate
=
sampling_rate
*
num_channels
*
bits_per_sample
/
8
;
...
...
@@ -61,32 +95,32 @@ void WriteWave(char *buffer, int32_t sampling_rate, const float *samples,
header
.
chunk_size
=
36
+
header
.
subchunk2_size
;
std
::
vector
<
int16_t
>
samples_int16
(
n
);
std
::
vector
<
int16_t
>
samples_int16
_ch0
(
n
);
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
samples_int16
[
i
]
=
samples
[
i
]
*
32767
;
samples_int16_ch0
[
i
]
=
std
::
min
<
int32_t
>
(
samples_ch0
[
i
]
*
32767
,
32767
);
}
std
::
vector
<
int16_t
>
samples_int16_ch1
;
if
(
samples_ch1
)
{
samples_int16_ch1
.
resize
(
n
);
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
samples_int16_ch1
[
i
]
=
std
::
min
<
int32_t
>
(
samples_ch1
[
i
]
*
32767
,
32767
);
}
}
memcpy
(
buffer
,
&
header
,
sizeof
(
WaveHeader
));
memcpy
(
buffer
+
sizeof
(
WaveHeader
),
samples_int16
.
data
(),
n
*
sizeof
(
int16_t
));
}
bool
WriteWave
(
const
std
::
string
&
filename
,
int32_t
sampling_rate
,
const
float
*
samples
,
int32_t
n
)
{
std
::
string
buffer
;
buffer
.
resize
(
WaveFileSize
(
n
));
WriteWave
(
buffer
.
data
(),
sampling_rate
,
samples
,
n
);
std
::
ofstream
os
(
filename
,
std
::
ios
::
binary
);
if
(
!
os
)
{
SHERPA_ONNX_LOGE
(
"Failed to create '%s'"
,
filename
.
c_str
());
return
false
;
}
os
<<
buffer
;
if
(
!
os
)
{
SHERPA_ONNX_LOGE
(
"Write '%s' failed"
,
filename
.
c_str
());
return
false
;
if
(
samples_ch1
==
nullptr
)
{
memcpy
(
buffer
+
sizeof
(
WaveHeader
),
samples_int16_ch0
.
data
(),
n
*
sizeof
(
int16_t
));
}
else
{
auto
p
=
reinterpret_cast
<
int16_t
*>
(
buffer
+
sizeof
(
WaveHeader
));
for
(
int32_t
i
=
0
;
i
!=
n
;
++
i
)
{
p
[
2
*
i
]
=
samples_int16_ch0
[
i
];
p
[
2
*
i
+
1
]
=
samples_int16_ch1
[
i
];
}
}
return
true
;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/wave-writer.h
查看文件 @
716ba83
...
...
@@ -25,7 +25,13 @@ bool WriteWave(const std::string &filename, int32_t sampling_rate,
void
WriteWave
(
char
*
buffer
,
int32_t
sampling_rate
,
const
float
*
samples
,
int32_t
n
);
int64_t
WaveFileSize
(
int32_t
n_samples
);
bool
WriteWave
(
const
std
::
string
&
filename
,
int32_t
sampling_rate
,
const
float
*
samples_ch0
,
const
float
*
samples_ch1
,
int32_t
n
);
void
WriteWave
(
char
*
buffer
,
int32_t
sampling_rate
,
const
float
*
samples_ch0
,
const
float
*
samples_ch1
,
int32_t
n
);
int64_t
WaveFileSize
(
int32_t
n_samples
,
int32_t
num_channels
=
1
);
}
// namespace sherpa_onnx
...
...
wasm/speech-enhancement/app-speech-enhancement.js
查看文件 @
716ba83
...
...
@@ -77,7 +77,7 @@ fileInput.addEventListener('change', function(event) {
console
.
log
(
'ArrayBuffer length:'
,
arrayBuffer
.
byteLength
);
const
uint8Array
=
new
Uint8Array
(
arrayBuffer
);
const
wave
=
readWaveFromBinaryData
(
uint8Array
);
const
wave
=
readWaveFromBinaryData
(
uint8Array
,
Module
);
if
(
wave
==
null
)
{
alert
(
`
$
{
file
.
name
}
is
not
a
valid
.
wav
file
.
Please
select
a
*
.
wav
file
`
);
...
...
请
注册
或
登录
后发表评论