Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-01-13 21:42:09 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-01-13 21:42:09 +0800
Commit
2024e966399bbf68171ae006f9b940ec43667c16
2024e966
1 parent
68a525a0
Add C++ runtime for speaker verification models from NeMo (#527)
显示空白字符变更
内嵌
并排对比
正在显示
20 个修改的文件
包含
405 行增加
和
24 行删除
.github/scripts/test-speaker-recognition-python.sh
cmake/kaldi-native-fbank.cmake
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/features.cc
sherpa-onnx/csrc/features.h
sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
sherpa-onnx/csrc/speaker-embedding-extractor-model.h
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h
sherpa-onnx/csrc/speaker-embedding-extractor.cc
sherpa-onnx/csrc/speaker-embedding-extractor.h
sherpa-onnx/csrc/speaker-embedding-manager-test.cc
sherpa-onnx/csrc/speaker-embedding-manager.cc
sherpa-onnx/csrc/speaker-embedding-manager.h
sherpa-onnx/python/tests/test_speaker_recognition.py
.github/scripts/test-speaker-recognition-python.sh
查看文件 @
2024e96
...
...
@@ -57,5 +57,19 @@ done
ls -lh
popd
log
"Download NeMo models"
model_dir
=
$d
/nemo
mkdir -p
$model_dir
pushd
$model_dir
models
=(
nemo_en_titanet_large.onnx
nemo_en_titanet_small.onnx
nemo_en_speakerverification_speakernet.onnx
)
for
m
in
${
models
[@]
}
;
do
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/
$m
done
ls -lh
popd
python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
...
...
cmake/kaldi-native-fbank.cmake
查看文件 @
2024e96
function
(
download_kaldi_native_fbank
)
include
(
FetchContent
)
set
(
kaldi_native_fbank_URL
"https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz"
)
set
(
kaldi_native_fbank_URL2
"https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz"
)
set
(
kaldi_native_fbank_HASH
"SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283"
)
set
(
kaldi_native_fbank_URL
"https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.6.tar.gz"
)
set
(
kaldi_native_fbank_URL2
"https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.6.tar.gz"
)
set
(
kaldi_native_fbank_HASH
"SHA256=6202a00cd06ba8ff89beb7b6f85cda34e073e94f25fc29e37c519bff0706bf19"
)
set
(
KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL
""
FORCE
)
set
(
KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL
""
FORCE
)
...
...
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set
(
possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz
${
PROJECT_SOURCE_DIR
}
/kaldi-native-fbank-1.18.5.tar.gz
${
PROJECT_BINARY_DIR
}
/kaldi-native-fbank-1.18.5.tar.gz
/tmp/kaldi-native-fbank-1.18.5.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.6.tar.gz
${
PROJECT_SOURCE_DIR
}
/kaldi-native-fbank-1.18.6.tar.gz
${
PROJECT_BINARY_DIR
}
/kaldi-native-fbank-1.18.6.tar.gz
/tmp/kaldi-native-fbank-1.18.6.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.6.tar.gz
)
foreach
(
f IN LISTS possible_file_locations
)
...
...
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
2024e96
...
...
@@ -100,6 +100,7 @@ set(sources
list
(
APPEND sources
speaker-embedding-extractor-impl.cc
speaker-embedding-extractor-model.cc
speaker-embedding-extractor-nemo-model.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
)
...
...
sherpa-onnx/csrc/features.cc
查看文件 @
2024e96
...
...
@@ -41,8 +41,12 @@ class FeatureExtractor::Impl {
public
:
explicit
Impl
(
const
FeatureExtractorConfig
&
config
)
:
config_
(
config
)
{
opts_
.
frame_opts
.
dither
=
0
;
opts_
.
frame_opts
.
snip_edges
=
false
;
opts_
.
frame_opts
.
snip_edges
=
config
.
snip_edges
;
opts_
.
frame_opts
.
samp_freq
=
config
.
sampling_rate
;
opts_
.
frame_opts
.
frame_shift_ms
=
config
.
frame_shift_ms
;
opts_
.
frame_opts
.
frame_length_ms
=
config
.
frame_length_ms
;
opts_
.
frame_opts
.
remove_dc_offset
=
config
.
remove_dc_offset
;
opts_
.
frame_opts
.
window_type
=
config
.
window_type
;
opts_
.
mel_opts
.
num_bins
=
config
.
feature_dim
;
...
...
@@ -52,6 +56,9 @@ class FeatureExtractor::Impl {
// https://github.com/k2-fsa/sherpa-onnx/issues/514
opts_
.
mel_opts
.
high_freq
=
-
400
;
opts_
.
mel_opts
.
low_freq
=
config
.
low_freq
;
opts_
.
mel_opts
.
is_librosa
=
config
.
is_librosa
;
fbank_
=
std
::
make_unique
<
knf
::
OnlineFbank
>
(
opts_
);
}
...
...
sherpa-onnx/csrc/features.h
查看文件 @
2024e96
...
...
@@ -28,6 +28,14 @@ struct FeatureExtractorConfig {
// If false, we will multiply the inputs by 32768
bool
normalize_samples
=
true
;
bool
snip_edges
=
false
;
float
frame_shift_ms
=
10
.
0
f
;
// in milliseconds.
float
frame_length_ms
=
25
.
0
f
;
// in milliseconds.
int32_t
low_freq
=
20
;
bool
is_librosa
=
false
;
bool
remove_dc_offset
=
true
;
// Subtract mean of wave before FFT.
std
::
string
window_type
=
"povey"
;
// e.g. Hamming window
std
::
string
ToString
()
const
;
void
Register
(
ParseOptions
*
po
);
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
//
// Copyright (c) 202
3
Xiaomi Corporation
// Copyright (c) 202
4
Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
//
// Copyright (c) 202
3
Xiaomi Corporation
// Copyright (c) 202
4
Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h"
namespace
sherpa_onnx
{
...
...
@@ -14,6 +15,7 @@ namespace {
enum
class
ModelType
{
kWeSpeaker
,
k3dSpeaker
,
kNeMo
,
kUnkown
,
};
...
...
@@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return
ModelType
::
kWeSpeaker
;
}
else
if
(
model_type
.
get
()
==
std
::
string
(
"3d-speaker"
))
{
return
ModelType
::
k3dSpeaker
;
}
else
if
(
model_type
.
get
()
==
std
::
string
(
"nemo"
))
{
return
ModelType
::
kNeMo
;
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported model_type: %s"
,
model_type
.
get
());
return
ModelType
::
kUnkown
;
...
...
@@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create(
// fall through
case
ModelType
:
:
k3dSpeaker
:
return
std
::
make_unique
<
SpeakerEmbeddingExtractorGeneralImpl
>
(
config
);
case
ModelType
:
:
kNeMo
:
return
std
::
make_unique
<
SpeakerEmbeddingExtractorNeMoImpl
>
(
config
);
case
ModelType
:
:
kUnkown
:
SHERPA_ONNX_LOGE
(
"Unknown model type in for speaker embedding extractor!"
);
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
//
// Copyright (c) 202
3
Xiaomi Corporation
// Copyright (c) 202
4
Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
//
// Copyright (c) 202
3-202
4 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-model.h
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor-model.h
//
// Copyright (c) 202
3-202
4 Xiaomi Corporation
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
0 → 100644
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace
sherpa_onnx
{
class
SpeakerEmbeddingExtractorNeMoImpl
:
public
SpeakerEmbeddingExtractorImpl
{
public
:
explicit
SpeakerEmbeddingExtractorNeMoImpl
(
const
SpeakerEmbeddingExtractorConfig
&
config
)
:
model_
(
config
)
{}
int32_t
Dim
()
const
override
{
return
model_
.
GetMetaData
().
output_dim
;
}
std
::
unique_ptr
<
OnlineStream
>
CreateStream
()
const
override
{
FeatureExtractorConfig
feat_config
;
const
auto
&
meta_data
=
model_
.
GetMetaData
();
feat_config
.
sampling_rate
=
meta_data
.
sample_rate
;
feat_config
.
feature_dim
=
meta_data
.
feat_dim
;
feat_config
.
normalize_samples
=
true
;
feat_config
.
snip_edges
=
true
;
feat_config
.
frame_shift_ms
=
meta_data
.
window_stride_ms
;
feat_config
.
frame_length_ms
=
meta_data
.
window_size_ms
;
feat_config
.
low_freq
=
0
;
feat_config
.
is_librosa
=
true
;
feat_config
.
remove_dc_offset
=
false
;
feat_config
.
window_type
=
meta_data
.
window_type
;
return
std
::
make_unique
<
OnlineStream
>
(
feat_config
);
}
bool
IsReady
(
OnlineStream
*
s
)
const
override
{
return
s
->
GetNumProcessedFrames
()
<
s
->
NumFramesReady
();
}
std
::
vector
<
float
>
Compute
(
OnlineStream
*
s
)
const
override
{
int32_t
num_frames
=
s
->
NumFramesReady
()
-
s
->
GetNumProcessedFrames
();
if
(
num_frames
<=
0
)
{
SHERPA_ONNX_LOGE
(
"Please make sure IsReady(s) returns true. num_frames: %d"
,
num_frames
);
return
{};
}
std
::
vector
<
float
>
features
=
s
->
GetFrames
(
s
->
GetNumProcessedFrames
(),
num_frames
);
s
->
GetNumProcessedFrames
()
+=
num_frames
;
int32_t
feat_dim
=
features
.
size
()
/
num_frames
;
const
auto
&
meta_data
=
model_
.
GetMetaData
();
if
(
!
meta_data
.
feature_normalize_type
.
empty
())
{
if
(
meta_data
.
feature_normalize_type
==
"per_feature"
)
{
NormalizePerFeature
(
features
.
data
(),
num_frames
,
feat_dim
);
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported feature_normalize_type: %s"
,
meta_data
.
feature_normalize_type
.
c_str
());
exit
(
-
1
);
}
}
if
(
num_frames
%
16
!=
0
)
{
int32_t
pad
=
16
-
num_frames
%
16
;
features
.
resize
((
num_frames
+
pad
)
*
feat_dim
);
}
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
std
::
array
<
int64_t
,
3
>
x_shape
{
1
,
num_frames
,
feat_dim
};
Ort
::
Value
x
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
features
.
data
(),
features
.
size
(),
x_shape
.
data
(),
x_shape
.
size
());
x
=
Transpose12
(
model_
.
Allocator
(),
&
x
);
int64_t
x_lens
=
num_frames
;
std
::
array
<
int64_t
,
1
>
x_lens_shape
{
1
};
Ort
::
Value
x_lens_tensor
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
&
x_lens
,
1
,
x_lens_shape
.
data
(),
x_lens_shape
.
size
());
Ort
::
Value
embedding
=
model_
.
Compute
(
std
::
move
(
x
),
std
::
move
(
x_lens_tensor
));
std
::
vector
<
int64_t
>
embedding_shape
=
embedding
.
GetTensorTypeAndShapeInfo
().
GetShape
();
std
::
vector
<
float
>
ans
(
embedding_shape
[
1
]);
std
::
copy
(
embedding
.
GetTensorData
<
float
>
(),
embedding
.
GetTensorData
<
float
>
()
+
ans
.
size
(),
ans
.
begin
());
return
ans
;
}
private
:
void
NormalizePerFeature
(
float
*
p
,
int32_t
num_frames
,
int32_t
feat_dim
)
const
{
auto
m
=
Eigen
::
Map
<
Eigen
::
Matrix
<
float
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
(
p
,
num_frames
,
feat_dim
);
auto
EX
=
m
.
colwise
().
mean
();
auto
EX2
=
m
.
array
().
pow
(
2
).
colwise
().
sum
()
/
num_frames
;
auto
variance
=
EX2
-
EX
.
array
().
pow
(
2
);
auto
stddev
=
variance
.
array
().
sqrt
();
m
=
(
m
.
rowwise
()
-
EX
).
array
().
rowwise
()
/
stddev
.
array
();
}
private
:
SpeakerEmbeddingExtractorNeMoModel
model_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
0 → 100644
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
#include <cstdint>
#include <string>
namespace
sherpa_onnx
{
struct
SpeakerEmbeddingExtractorNeMoModelMetaData
{
int32_t
output_dim
=
0
;
int32_t
feat_dim
=
80
;
int32_t
sample_rate
=
0
;
int32_t
window_size_ms
=
25
;
int32_t
window_stride_ms
=
25
;
// Chinese, English, etc.
std
::
string
language
;
// for 3d-speaker, it is global-mean
std
::
string
feature_normalize_type
;
std
::
string
window_type
=
"hann"
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
0 → 100644
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h"
namespace
sherpa_onnx
{
class
SpeakerEmbeddingExtractorNeMoModel
::
Impl
{
public
:
explicit
Impl
(
const
SpeakerEmbeddingExtractorConfig
&
config
)
:
config_
(
config
),
env_
(
ORT_LOGGING_LEVEL_ERROR
),
sess_opts_
(
GetSessionOptions
(
config
)),
allocator_
{}
{
{
auto
buf
=
ReadFile
(
config
.
model
);
Init
(
buf
.
data
(),
buf
.
size
());
}
}
Ort
::
Value
Compute
(
Ort
::
Value
x
,
Ort
::
Value
x_lens
)
const
{
std
::
array
<
Ort
::
Value
,
2
>
inputs
=
{
std
::
move
(
x
),
std
::
move
(
x_lens
)};
// output_names_ptr_[0] is logits
// output_names_ptr_[1] is embeddings
// so we use output_names_ptr_.data() + 1 here to extract only the
// embeddings
auto
outputs
=
sess_
->
Run
({},
input_names_ptr_
.
data
(),
inputs
.
data
(),
inputs
.
size
(),
output_names_ptr_
.
data
()
+
1
,
1
);
return
std
::
move
(
outputs
[
0
]);
}
OrtAllocator
*
Allocator
()
const
{
return
allocator_
;
}
const
SpeakerEmbeddingExtractorNeMoModelMetaData
&
GetMetaData
()
const
{
return
meta_data_
;
}
private
:
void
Init
(
void
*
model_data
,
size_t
model_data_length
)
{
sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
sess_
.
get
(),
&
input_names_
,
&
input_names_ptr_
);
GetOutputNames
(
sess_
.
get
(),
&
output_names_
,
&
output_names_ptr_
);
// get meta data
Ort
::
ModelMetadata
meta_data
=
sess_
->
GetModelMetadata
();
if
(
config_
.
debug
)
{
std
::
ostringstream
os
;
PrintModelMetadata
(
os
,
meta_data
);
SHERPA_ONNX_LOGE
(
"%s"
,
os
.
str
().
c_str
());
}
Ort
::
AllocatorWithDefaultOptions
allocator
;
// used in the macro below
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
output_dim
,
"output_dim"
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
feat_dim
,
"feat_dim"
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
sample_rate
,
"sample_rate"
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
window_size_ms
,
"window_size_ms"
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
window_stride_ms
,
"window_stride_ms"
);
SHERPA_ONNX_READ_META_DATA_STR
(
meta_data_
.
language
,
"language"
);
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT
(
meta_data_
.
feature_normalize_type
,
"feature_normalize_type"
,
""
);
SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT
(
meta_data_
.
window_type
,
"window_type"
,
"povey"
);
std
::
string
framework
;
SHERPA_ONNX_READ_META_DATA_STR
(
framework
,
"framework"
);
if
(
framework
!=
"nemo"
)
{
SHERPA_ONNX_LOGE
(
"Expect a NeMo model, given: %s"
,
framework
.
c_str
());
exit
(
-
1
);
}
}
private
:
SpeakerEmbeddingExtractorConfig
config_
;
Ort
::
Env
env_
;
Ort
::
SessionOptions
sess_opts_
;
Ort
::
AllocatorWithDefaultOptions
allocator_
;
std
::
unique_ptr
<
Ort
::
Session
>
sess_
;
std
::
vector
<
std
::
string
>
input_names_
;
std
::
vector
<
const
char
*>
input_names_ptr_
;
std
::
vector
<
std
::
string
>
output_names_
;
std
::
vector
<
const
char
*>
output_names_ptr_
;
SpeakerEmbeddingExtractorNeMoModelMetaData
meta_data_
;
};
SpeakerEmbeddingExtractorNeMoModel
::
SpeakerEmbeddingExtractorNeMoModel
(
const
SpeakerEmbeddingExtractorConfig
&
config
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
))
{}
SpeakerEmbeddingExtractorNeMoModel
::~
SpeakerEmbeddingExtractorNeMoModel
()
=
default
;
const
SpeakerEmbeddingExtractorNeMoModelMetaData
&
SpeakerEmbeddingExtractorNeMoModel
::
GetMetaData
()
const
{
return
impl_
->
GetMetaData
();
}
Ort
::
Value
SpeakerEmbeddingExtractorNeMoModel
::
Compute
(
Ort
::
Value
x
,
Ort
::
Value
x_lens
)
const
{
return
impl_
->
Compute
(
std
::
move
(
x
),
std
::
move
(
x_lens
));
}
OrtAllocator
*
SpeakerEmbeddingExtractorNeMoModel
::
Allocator
()
const
{
return
impl_
->
Allocator
();
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h
0 → 100644
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_
#include <memory>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace
sherpa_onnx
{
class
SpeakerEmbeddingExtractorNeMoModel
{
public
:
explicit
SpeakerEmbeddingExtractorNeMoModel
(
const
SpeakerEmbeddingExtractorConfig
&
config
);
~
SpeakerEmbeddingExtractorNeMoModel
();
const
SpeakerEmbeddingExtractorNeMoModelMetaData
&
GetMetaData
()
const
;
/**
* @param x A float32 tensor of shape (N, C, T)
* @param x_len A int64 tensor of shape (N,)
* @return A float32 tensor of shape (N, C)
*/
Ort
::
Value
Compute
(
Ort
::
Value
x
,
Ort
::
Value
x_len
)
const
;
OrtAllocator
*
Allocator
()
const
;
private
:
class
Impl
;
std
::
unique_ptr
<
Impl
>
impl_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_
...
...
sherpa-onnx/csrc/speaker-embedding-extractor.cc
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor.cc
//
// Copyright (c) 202
3
Xiaomi Corporation
// Copyright (c) 202
4
Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
...
...
sherpa-onnx/csrc/speaker-embedding-extractor.h
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-extractor.h
//
// Copyright (c) 202
3
Xiaomi Corporation
// Copyright (c) 202
4
Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
...
...
sherpa-onnx/csrc/speaker-embedding-manager-test.cc
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-manager-test.cc
//
// Copyright (c) 202
3
Jingzhao Ou (jingzhao.ou@gmail.com)
// Copyright (c) 202
4
Jingzhao Ou (jingzhao.ou@gmail.com)
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
...
...
sherpa-onnx/csrc/speaker-embedding-manager.cc
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-manager.cc
//
// Copyright (c) 202
3
Xiaomi Corporation
// Copyright (c) 202
4
Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
...
...
sherpa-onnx/csrc/speaker-embedding-manager.h
查看文件 @
2024e96
// sherpa-onnx/csrc/speaker-embedding-manager.h
//
// Copyright (c) 202
3
Xiaomi Corporation
// Copyright (c) 202
4
Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
...
...
sherpa-onnx/python/tests/test_speaker_recognition.py
查看文件 @
2024e96
...
...
@@ -56,7 +56,7 @@ def load_speaker_embedding_model(model_filename):
return
extractor
def
test_
wespeaker_model
(
model_filename
:
str
):
def
test_
zh_models
(
model_filename
:
str
):
model_filename
=
str
(
model_filename
)
if
"en"
in
model_filename
:
print
(
f
"skip {model_filename}"
)
...
...
@@ -114,8 +114,9 @@ def test_wespeaker_model(model_filename: str):
assert
ans
==
name
,
(
name
,
ans
)
def
test_3dspeaker_model
(
model_filename
:
str
):
extractor
=
load_speaker_embedding_model
(
str
(
model_filename
))
def
test_en_and_zh_models
(
model_filename
:
str
):
model_filename
=
str
(
model_filename
)
extractor
=
load_speaker_embedding_model
(
model_filename
)
manager
=
sherpa_onnx
.
SpeakerEmbeddingManager
(
extractor
.
dim
)
filenames
=
[
...
...
@@ -124,7 +125,14 @@ def test_3dspeaker_model(model_filename: str):
"speaker1_a_en_16k"
,
"speaker2_a_en_16k"
,
]
is_en
=
"en"
in
model_filename
for
filename
in
filenames
:
if
is_en
and
"cn"
in
filename
:
continue
if
not
is_en
and
"en"
in
filename
:
continue
name
=
filename
.
rsplit
(
"_"
,
maxsplit
=
1
)[
0
]
data
,
sample_rate
=
read_wave
(
f
"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
...
...
@@ -145,6 +153,11 @@ def test_3dspeaker_model(model_filename: str):
"speaker1_b_en_16k"
,
]
for
filename
in
filenames
:
if
is_en
and
"cn"
in
filename
:
continue
if
not
is_en
and
"en"
in
filename
:
continue
print
(
filename
)
name
=
filename
.
rsplit
(
"_"
,
maxsplit
=
1
)[
0
]
name
=
name
.
replace
(
"b_cn"
,
"a_cn"
)
...
...
@@ -178,7 +191,8 @@ class TestSpeakerRecognition(unittest.TestCase):
return
for
filename
in
model_dir
.
glob
(
"*.onnx"
):
print
(
filename
)
test_wespeaker_model
(
filename
)
test_zh_models
(
filename
)
test_en_and_zh_models
(
filename
)
def
test_3dpeaker_models
(
self
):
model_dir
=
Path
(
d
)
/
"3dspeaker"
...
...
@@ -187,7 +201,16 @@ class TestSpeakerRecognition(unittest.TestCase):
return
for
filename
in
model_dir
.
glob
(
"*.onnx"
):
print
(
filename
)
test_3dspeaker_model
(
filename
)
test_en_and_zh_models
(
filename
)
def
test_nemo_models
(
self
):
model_dir
=
Path
(
d
)
/
"nemo"
if
not
model_dir
.
is_dir
():
print
(
f
"{model_dir} does not exist - skip it"
)
return
for
filename
in
model_dir
.
glob
(
"*.onnx"
):
print
(
filename
)
test_en_and_zh_models
(
filename
)
if
__name__
==
"__main__"
:
...
...
请
注册
或
登录
后发表评论