Fangjun Kuang
Committed by GitHub

Add C++ runtime for speaker verification models from NeMo (#527)

@@ -57,5 +57,19 @@ done @@ -57,5 +57,19 @@ done
57 ls -lh 57 ls -lh
58 popd 58 popd
59 59
  60 +log "Download NeMo models"
  61 +model_dir=$d/nemo
  62 +mkdir -p $model_dir
  63 +pushd $model_dir
  64 +models=(
  65 +nemo_en_titanet_large.onnx
  66 +nemo_en_titanet_small.onnx
  67 +nemo_en_speakerverification_speakernet.onnx
  68 +)
  69 +for m in ${models[@]}; do
  70 + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m
  71 +done
  72 +ls -lh
  73 +popd
60 74
61 python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose 75 python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
1 function(download_kaldi_native_fbank) 1 function(download_kaldi_native_fbank)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz")  
5 - set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz")  
6 - set(kaldi_native_fbank_HASH "SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283") 4 + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.6.tar.gz")
  5 + set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.6.tar.gz")
  6 + set(kaldi_native_fbank_HASH "SHA256=6202a00cd06ba8ff89beb7b6f85cda34e073e94f25fc29e37c519bff0706bf19")
7 7
8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) 8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
12 # If you don't have access to the Internet, 12 # If you don't have access to the Internet,
13 # please pre-download kaldi-native-fbank 13 # please pre-download kaldi-native-fbank
14 set(possible_file_locations 14 set(possible_file_locations
15 - $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz  
16 - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.5.tar.gz  
17 - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.5.tar.gz  
18 - /tmp/kaldi-native-fbank-1.18.5.tar.gz  
19 - /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz 15 + $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.6.tar.gz
  16 + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.6.tar.gz
  17 + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.6.tar.gz
  18 + /tmp/kaldi-native-fbank-1.18.6.tar.gz
  19 + /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.6.tar.gz
20 ) 20 )
21 21
22 foreach(f IN LISTS possible_file_locations) 22 foreach(f IN LISTS possible_file_locations)
@@ -100,6 +100,7 @@ set(sources @@ -100,6 +100,7 @@ set(sources
100 list(APPEND sources 100 list(APPEND sources
101 speaker-embedding-extractor-impl.cc 101 speaker-embedding-extractor-impl.cc
102 speaker-embedding-extractor-model.cc 102 speaker-embedding-extractor-model.cc
  103 + speaker-embedding-extractor-nemo-model.cc
103 speaker-embedding-extractor.cc 104 speaker-embedding-extractor.cc
104 speaker-embedding-manager.cc 105 speaker-embedding-manager.cc
105 ) 106 )
@@ -41,8 +41,12 @@ class FeatureExtractor::Impl { @@ -41,8 +41,12 @@ class FeatureExtractor::Impl {
41 public: 41 public:
42 explicit Impl(const FeatureExtractorConfig &config) : config_(config) { 42 explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
43 opts_.frame_opts.dither = 0; 43 opts_.frame_opts.dither = 0;
44 - opts_.frame_opts.snip_edges = false; 44 + opts_.frame_opts.snip_edges = config.snip_edges;
45 opts_.frame_opts.samp_freq = config.sampling_rate; 45 opts_.frame_opts.samp_freq = config.sampling_rate;
  46 + opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
  47 + opts_.frame_opts.frame_length_ms = config.frame_length_ms;
  48 + opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
  49 + opts_.frame_opts.window_type = config.window_type;
46 50
47 opts_.mel_opts.num_bins = config.feature_dim; 51 opts_.mel_opts.num_bins = config.feature_dim;
48 52
@@ -52,6 +56,9 @@ class FeatureExtractor::Impl { @@ -52,6 +56,9 @@ class FeatureExtractor::Impl {
52 // https://github.com/k2-fsa/sherpa-onnx/issues/514 56 // https://github.com/k2-fsa/sherpa-onnx/issues/514
53 opts_.mel_opts.high_freq = -400; 57 opts_.mel_opts.high_freq = -400;
54 58
  59 + opts_.mel_opts.low_freq = config.low_freq;
  60 + opts_.mel_opts.is_librosa = config.is_librosa;
  61 +
55 fbank_ = std::make_unique<knf::OnlineFbank>(opts_); 62 fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
56 } 63 }
57 64
@@ -28,6 +28,14 @@ struct FeatureExtractorConfig { @@ -28,6 +28,14 @@ struct FeatureExtractorConfig {
28 // If false, we will multiply the inputs by 32768 28 // If false, we will multiply the inputs by 32768
29 bool normalize_samples = true; 29 bool normalize_samples = true;
30 30
  31 + bool snip_edges = false;
  32 + float frame_shift_ms = 10.0f; // in milliseconds.
  33 + float frame_length_ms = 25.0f; // in milliseconds.
  34 + int32_t low_freq = 20;
  35 + bool is_librosa = false;
  36 + bool remove_dc_offset = true; // Subtract mean of wave before FFT.
  38 + std::string window_type = "povey"; // e.g., povey, hann, hamming 37 + std::string window_type = "povey"; // e.g., povey, hann, hamming
  38 +
31 std::string ToString() const; 39 std::string ToString() const;
32 40
33 void Register(ParseOptions *po); 41 void Register(ParseOptions *po);
1 // sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h 1 // sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation 3 +// Copyright (c) 2024 Xiaomi Corporation
4 4
5 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ 5 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
6 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ 6 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
1 // sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc 1 // sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation 3 +// Copyright (c) 2024 Xiaomi Corporation
4 #include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" 4 #include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
5 5
6 #include "sherpa-onnx/csrc/macros.h" 6 #include "sherpa-onnx/csrc/macros.h"
7 #include "sherpa-onnx/csrc/onnx-utils.h" 7 #include "sherpa-onnx/csrc/onnx-utils.h"
8 #include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h" 8 #include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"
  9 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h"
9 10
10 namespace sherpa_onnx { 11 namespace sherpa_onnx {
11 12
@@ -14,6 +15,7 @@ namespace { @@ -14,6 +15,7 @@ namespace {
14 enum class ModelType { 15 enum class ModelType {
15 kWeSpeaker, 16 kWeSpeaker,
16 k3dSpeaker, 17 k3dSpeaker,
  18 + kNeMo,
17 kUnkown, 19 kUnkown,
18 }; 20 };
19 21
@@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
52 return ModelType::kWeSpeaker; 54 return ModelType::kWeSpeaker;
53 } else if (model_type.get() == std::string("3d-speaker")) { 55 } else if (model_type.get() == std::string("3d-speaker")) {
54 return ModelType::k3dSpeaker; 56 return ModelType::k3dSpeaker;
  57 + } else if (model_type.get() == std::string("nemo")) {
  58 + return ModelType::kNeMo;
55 } else { 59 } else {
56 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 60 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
57 return ModelType::kUnkown; 61 return ModelType::kUnkown;
@@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create( @@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create(
74 // fall through 78 // fall through
75 case ModelType::k3dSpeaker: 79 case ModelType::k3dSpeaker:
76 return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config); 80 return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
  81 + case ModelType::kNeMo:
  82 + return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
77 case ModelType::kUnkown: 83 case ModelType::kUnkown:
78 SHERPA_ONNX_LOGE( 84 SHERPA_ONNX_LOGE(
79 "Unknown model type in for speaker embedding extractor!"); 85 "Unknown model type in for speaker embedding extractor!");
1 // sherpa-onnx/csrc/speaker-embedding-extractor-impl.h 1 // sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation 3 +// Copyright (c) 2024 Xiaomi Corporation
4 4
5 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ 5 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
6 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ 6 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
1 // sherpa-onnx/csrc/speaker-embedding-extractor-model.cc 1 // sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
2 // 2 //
3 -// Copyright (c) 2023-2024 Xiaomi Corporation 3 +// Copyright (c) 2024 Xiaomi Corporation
4 4
5 #include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h" 5 #include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"
6 6
1 // sherpa-onnx/csrc/speaker-embedding-extractor-model.h 1 // sherpa-onnx/csrc/speaker-embedding-extractor-model.h
2 // 2 //
3 -// Copyright (c) 2023-2024 Xiaomi Corporation 3 +// Copyright (c) 2024 Xiaomi Corporation
4 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ 4 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
5 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ 5 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
6 6
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
  7 +#include <algorithm>
  8 +#include <memory>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "Eigen/Dense"
  13 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
  14 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
  15 +#include "sherpa-onnx/csrc/transpose.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
  19 +class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
  20 + public:
  21 + explicit SpeakerEmbeddingExtractorNeMoImpl(
  22 + const SpeakerEmbeddingExtractorConfig &config)
  23 + : model_(config) {}
  24 +
  25 + int32_t Dim() const override { return model_.GetMetaData().output_dim; }
  26 +
  27 + std::unique_ptr<OnlineStream> CreateStream() const override {
  28 + FeatureExtractorConfig feat_config;
  29 + const auto &meta_data = model_.GetMetaData();
  30 + feat_config.sampling_rate = meta_data.sample_rate;
  31 + feat_config.feature_dim = meta_data.feat_dim;
  32 + feat_config.normalize_samples = true;
  33 + feat_config.snip_edges = true;
  34 + feat_config.frame_shift_ms = meta_data.window_stride_ms;
  35 + feat_config.frame_length_ms = meta_data.window_size_ms;
  36 + feat_config.low_freq = 0;
  37 + feat_config.is_librosa = true;
  38 + feat_config.remove_dc_offset = false;
  39 + feat_config.window_type = meta_data.window_type;
  40 +
  41 + return std::make_unique<OnlineStream>(feat_config);
  42 + }
  43 +
  44 + bool IsReady(OnlineStream *s) const override {
  45 + return s->GetNumProcessedFrames() < s->NumFramesReady();
  46 + }
  47 +
  48 + std::vector<float> Compute(OnlineStream *s) const override {
  49 + int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
  50 + if (num_frames <= 0) {
  51 + SHERPA_ONNX_LOGE(
  52 + "Please make sure IsReady(s) returns true. num_frames: %d",
  53 + num_frames);
  54 + return {};
  55 + }
  56 +
  57 + std::vector<float> features =
  58 + s->GetFrames(s->GetNumProcessedFrames(), num_frames);
  59 +
  60 + s->GetNumProcessedFrames() += num_frames;
  61 +
  62 + int32_t feat_dim = features.size() / num_frames;
  63 +
  64 + const auto &meta_data = model_.GetMetaData();
  65 + if (!meta_data.feature_normalize_type.empty()) {
  66 + if (meta_data.feature_normalize_type == "per_feature") {
  67 + NormalizePerFeature(features.data(), num_frames, feat_dim);
  68 + } else {
  69 + SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
  70 + meta_data.feature_normalize_type.c_str());
  71 + exit(-1);
  72 + }
  73 + }
  74 +
  75 + if (num_frames % 16 != 0) {
  76 + int32_t pad = 16 - num_frames % 16;
  77 + features.resize((num_frames + pad) * feat_dim);
  78 + }
  79 +
  80 + auto memory_info =
  81 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  82 +
  83 + std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
  84 + Ort::Value x =
  85 + Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
  86 + x_shape.data(), x_shape.size());
  87 +
  88 + x = Transpose12(model_.Allocator(), &x);
  89 +
  90 + int64_t x_lens = num_frames;
  91 + std::array<int64_t, 1> x_lens_shape{1};
  92 + Ort::Value x_lens_tensor = Ort::Value::CreateTensor(
  93 + memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size());
  94 +
  95 + Ort::Value embedding =
  96 + model_.Compute(std::move(x), std::move(x_lens_tensor));
  97 + std::vector<int64_t> embedding_shape =
  98 + embedding.GetTensorTypeAndShapeInfo().GetShape();
  99 +
  100 + std::vector<float> ans(embedding_shape[1]);
  101 + std::copy(embedding.GetTensorData<float>(),
  102 + embedding.GetTensorData<float>() + ans.size(), ans.begin());
  103 +
  104 + return ans;
  105 + }
  106 +
  107 + private:
  108 + void NormalizePerFeature(float *p, int32_t num_frames,
  109 + int32_t feat_dim) const {
  110 + auto m = Eigen::Map<
  111 + Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
  112 + p, num_frames, feat_dim);
  113 +
  114 + auto EX = m.colwise().mean();
  115 + auto EX2 = m.array().pow(2).colwise().sum() / num_frames;
  116 + auto variance = EX2 - EX.array().pow(2);
  117 + auto stddev = variance.array().sqrt();
  118 +
  119 + m = (m.rowwise() - EX).array().rowwise() / stddev.array();
  120 + }
  121 +
  122 + private:
  123 + SpeakerEmbeddingExtractorNeMoModel model_;
  124 +};
  125 +
  126 +} // namespace sherpa_onnx
  127 +
  128 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
  5 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
  6 +
  7 +#include <cstdint>
  8 +#include <string>
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +struct SpeakerEmbeddingExtractorNeMoModelMetaData {
  13 + int32_t output_dim = 0;
  14 + int32_t feat_dim = 80;
  15 + int32_t sample_rate = 0;
  16 + int32_t window_size_ms = 25;
  17 + int32_t window_stride_ms = 25;
  18 +
  19 + // Chinese, English, etc.
  20 + std::string language;
  21 +
  22 + // e.g., per_feature for NeMo models
  23 + std::string feature_normalize_type;
  24 + std::string window_type = "hann";
  25 +};
  26 +
  27 +} // namespace sherpa_onnx
  28 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
  6 +
  7 +#include <string>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +#include "sherpa-onnx/csrc/onnx-utils.h"
  13 +#include "sherpa-onnx/csrc/session.h"
  14 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
  18 +class SpeakerEmbeddingExtractorNeMoModel::Impl {
  19 + public:
  20 + explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
  21 + : config_(config),
  22 + env_(ORT_LOGGING_LEVEL_ERROR),
  23 + sess_opts_(GetSessionOptions(config)),
  24 + allocator_{} {
  25 + {
  26 + auto buf = ReadFile(config.model);
  27 + Init(buf.data(), buf.size());
  28 + }
  29 + }
  30 +
  31 + Ort::Value Compute(Ort::Value x, Ort::Value x_lens) const {
  32 + std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};
  33 +
  34 + // output_names_ptr_[0] is logits
  35 + // output_names_ptr_[1] is embeddings
  36 + // so we use output_names_ptr_.data() + 1 here to extract only the
  37 + // embeddings
  38 + auto outputs = sess_->Run({}, input_names_ptr_.data(), inputs.data(),
  39 + inputs.size(), output_names_ptr_.data() + 1, 1);
  40 + return std::move(outputs[0]);
  41 + }
  42 +
  43 + OrtAllocator *Allocator() const { return allocator_; }
  44 +
  45 + const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
  46 + return meta_data_;
  47 + }
  48 +
  49 + private:
  50 + void Init(void *model_data, size_t model_data_length) {
  51 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  52 + sess_opts_);
  53 +
  54 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  55 +
  56 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  57 +
  58 + // get meta data
  59 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  60 + if (config_.debug) {
  61 + std::ostringstream os;
  62 + PrintModelMetadata(os, meta_data);
  63 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  64 + }
  65 +
  66 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  67 + SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
  68 + SHERPA_ONNX_READ_META_DATA(meta_data_.feat_dim, "feat_dim");
  69 + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
  70 + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size_ms, "window_size_ms");
  71 + SHERPA_ONNX_READ_META_DATA(meta_data_.window_stride_ms, "window_stride_ms");
  72 + SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
  73 +
  74 + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(
  75 + meta_data_.feature_normalize_type, "feature_normalize_type", "");
  76 +
  77 + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.window_type,
  78 + "window_type", "povey");
  79 +
  80 + std::string framework;
  81 + SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
  82 + if (framework != "nemo") {
  83 + SHERPA_ONNX_LOGE("Expect a NeMo model, given: %s", framework.c_str());
  84 + exit(-1);
  85 + }
  86 + }
  87 +
  88 + private:
  89 + SpeakerEmbeddingExtractorConfig config_;
  90 + Ort::Env env_;
  91 + Ort::SessionOptions sess_opts_;
  92 + Ort::AllocatorWithDefaultOptions allocator_;
  93 +
  94 + std::unique_ptr<Ort::Session> sess_;
  95 +
  96 + std::vector<std::string> input_names_;
  97 + std::vector<const char *> input_names_ptr_;
  98 +
  99 + std::vector<std::string> output_names_;
  100 + std::vector<const char *> output_names_ptr_;
  101 +
  102 + SpeakerEmbeddingExtractorNeMoModelMetaData meta_data_;
  103 +};
  104 +
  105 +SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel(
  106 + const SpeakerEmbeddingExtractorConfig &config)
  107 + : impl_(std::make_unique<Impl>(config)) {}
  108 +
  109 +SpeakerEmbeddingExtractorNeMoModel::~SpeakerEmbeddingExtractorNeMoModel() =
  110 + default;
  111 +
  112 +const SpeakerEmbeddingExtractorNeMoModelMetaData &
  113 +SpeakerEmbeddingExtractorNeMoModel::GetMetaData() const {
  114 + return impl_->GetMetaData();
  115 +}
  116 +
  117 +Ort::Value SpeakerEmbeddingExtractorNeMoModel::Compute(
  118 + Ort::Value x, Ort::Value x_lens) const {
  119 + return impl_->Compute(std::move(x), std::move(x_lens));
  120 +}
  121 +
  122 +OrtAllocator *SpeakerEmbeddingExtractorNeMoModel::Allocator() const {
  123 + return impl_->Allocator();
  124 +}
  125 +
  126 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_
  6 +
  7 +#include <memory>
  8 +
  9 +#include "onnxruntime_cxx_api.h" // NOLINT
  10 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h"
  11 +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class SpeakerEmbeddingExtractorNeMoModel {
  16 + public:
  17 + explicit SpeakerEmbeddingExtractorNeMoModel(
  18 + const SpeakerEmbeddingExtractorConfig &config);
  19 +
  20 + ~SpeakerEmbeddingExtractorNeMoModel();
  21 +
  22 + const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const;
  23 +
  24 + /**
  25 + * @param x A float32 tensor of shape (N, C, T)
  26 + * @param x_len A int64 tensor of shape (N,)
  27 + * @return A float32 tensor of shape (N, C)
  28 + */
  29 + Ort::Value Compute(Ort::Value x, Ort::Value x_len) const;
  30 +
  31 + OrtAllocator *Allocator() const;
  32 +
  33 + private:
  34 + class Impl;
  35 + std::unique_ptr<Impl> impl_;
  36 +};
  37 +
  38 +} // namespace sherpa_onnx
  39 +
  40 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_
1 // sherpa-onnx/csrc/speaker-embedding-extractor.cc 1 // sherpa-onnx/csrc/speaker-embedding-extractor.cc
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation 3 +// Copyright (c) 2024 Xiaomi Corporation
4 4
5 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" 5 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
6 6
1 // sherpa-onnx/csrc/speaker-embedding-extractor.h 1 // sherpa-onnx/csrc/speaker-embedding-extractor.h
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation 3 +// Copyright (c) 2024 Xiaomi Corporation
4 4
5 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ 5 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
6 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ 6 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
1 // sherpa-onnx/csrc/speaker-embedding-manager-test.cc 1 // sherpa-onnx/csrc/speaker-embedding-manager-test.cc
2 // 2 //
3 -// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) 3 +// Copyright (c) 2024 Jingzhao Ou (jingzhao.ou@gmail.com)
4 4
5 #include "sherpa-onnx/csrc/speaker-embedding-manager.h" 5 #include "sherpa-onnx/csrc/speaker-embedding-manager.h"
6 6
1 // sherpa-onnx/csrc/speaker-embedding-manager.cc 1 // sherpa-onnx/csrc/speaker-embedding-manager.cc
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation 3 +// Copyright (c) 2024 Xiaomi Corporation
4 4
5 #include "sherpa-onnx/csrc/speaker-embedding-manager.h" 5 #include "sherpa-onnx/csrc/speaker-embedding-manager.h"
6 6
1 // sherpa-onnx/csrc/speaker-embedding-manager.h 1 // sherpa-onnx/csrc/speaker-embedding-manager.h
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation 3 +// Copyright (c) 2024 Xiaomi Corporation
4 4
5 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ 5 #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
6 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ 6 #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
@@ -56,7 +56,7 @@ def load_speaker_embedding_model(model_filename): @@ -56,7 +56,7 @@ def load_speaker_embedding_model(model_filename):
56 return extractor 56 return extractor
57 57
58 58
59 -def test_wespeaker_model(model_filename: str): 59 +def test_zh_models(model_filename: str):
60 model_filename = str(model_filename) 60 model_filename = str(model_filename)
61 if "en" in model_filename: 61 if "en" in model_filename:
62 print(f"skip {model_filename}") 62 print(f"skip {model_filename}")
@@ -114,8 +114,9 @@ def test_wespeaker_model(model_filename: str): @@ -114,8 +114,9 @@ def test_wespeaker_model(model_filename: str):
114 assert ans == name, (name, ans) 114 assert ans == name, (name, ans)
115 115
116 116
117 -def test_3dspeaker_model(model_filename: str):  
118 - extractor = load_speaker_embedding_model(str(model_filename)) 117 +def test_en_and_zh_models(model_filename: str):
  118 + model_filename = str(model_filename)
  119 + extractor = load_speaker_embedding_model(model_filename)
119 manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim) 120 manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
120 121
121 filenames = [ 122 filenames = [
@@ -124,7 +125,14 @@ def test_3dspeaker_model(model_filename: str): @@ -124,7 +125,14 @@ def test_3dspeaker_model(model_filename: str):
124 "speaker1_a_en_16k", 125 "speaker1_a_en_16k",
125 "speaker2_a_en_16k", 126 "speaker2_a_en_16k",
126 ] 127 ]
  128 + is_en = "en" in model_filename
127 for filename in filenames: 129 for filename in filenames:
  130 + if is_en and "cn" in filename:
  131 + continue
  132 +
  133 + if not is_en and "en" in filename:
  134 + continue
  135 +
128 name = filename.rsplit("_", maxsplit=1)[0] 136 name = filename.rsplit("_", maxsplit=1)[0]
129 data, sample_rate = read_wave( 137 data, sample_rate = read_wave(
130 f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav" 138
@@ -145,6 +153,11 @@ def test_3dspeaker_model(model_filename: str): @@ -145,6 +153,11 @@ def test_3dspeaker_model(model_filename: str):
145 "speaker1_b_en_16k", 153 "speaker1_b_en_16k",
146 ] 154 ]
147 for filename in filenames: 155 for filename in filenames:
  156 + if is_en and "cn" in filename:
  157 + continue
  158 +
  159 + if not is_en and "en" in filename:
  160 + continue
148 print(filename) 161 print(filename)
149 name = filename.rsplit("_", maxsplit=1)[0] 162 name = filename.rsplit("_", maxsplit=1)[0]
150 name = name.replace("b_cn", "a_cn") 163 name = name.replace("b_cn", "a_cn")
@@ -178,7 +191,8 @@ class TestSpeakerRecognition(unittest.TestCase): @@ -178,7 +191,8 @@ class TestSpeakerRecognition(unittest.TestCase):
178 return 191 return
179 for filename in model_dir.glob("*.onnx"): 192 for filename in model_dir.glob("*.onnx"):
180 print(filename) 193 print(filename)
181 - test_wespeaker_model(filename) 194 + test_zh_models(filename)
  195 + test_en_and_zh_models(filename)
182 196
183 def test_3dpeaker_models(self): 197 def test_3dpeaker_models(self):
184 model_dir = Path(d) / "3dspeaker" 198 model_dir = Path(d) / "3dspeaker"
@@ -187,7 +201,16 @@ class TestSpeakerRecognition(unittest.TestCase): @@ -187,7 +201,16 @@ class TestSpeakerRecognition(unittest.TestCase):
187 return 201 return
188 for filename in model_dir.glob("*.onnx"): 202 for filename in model_dir.glob("*.onnx"):
189 print(filename) 203 print(filename)
190 - test_3dspeaker_model(filename) 204 + test_en_and_zh_models(filename)
  205 +
  206 + def test_nemo_models(self):
  207 + model_dir = Path(d) / "nemo"
  208 + if not model_dir.is_dir():
  209 + print(f"{model_dir} does not exist - skip it")
  210 + return
  211 + for filename in model_dir.glob("*.onnx"):
  212 + print(filename)
  213 + test_en_and_zh_models(filename)
191 214
192 215
193 if __name__ == "__main__": 216 if __name__ == "__main__":