Committed by
GitHub
Add C++ runtime for speaker verification models from NeMo (#527)
正在显示
20 个修改的文件
包含
405 行增加
和
24 行删除
| @@ -57,5 +57,19 @@ done | @@ -57,5 +57,19 @@ done | ||
| 57 | ls -lh | 57 | ls -lh |
| 58 | popd | 58 | popd |
| 59 | 59 | ||
# Fetch the NeMo speaker embedding models into $d/nemo so that the Python
# speaker recognition test that follows can find them.
log "Download NeMo models"
model_dir="$d/nemo"
mkdir -p "$model_dir"
pushd "$model_dir"
models=(
  nemo_en_titanet_large.onnx
  nemo_en_titanet_small.onnx
  nemo_en_speakerverification_speakernet.onnx
)
# Quote every expansion so that a path or file name containing spaces or glob
# characters is passed through as a single word.
# NOTE(review): "recongition" is spelled this way in the URL as written;
# presumably it matches the actual release tag - verify before "fixing" it.
for m in "${models[@]}"; do
  wget -q "https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m"
done
ls -lh
popd
| 60 | 74 | ||
| 61 | python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose | 75 | python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose |
| 1 | function(download_kaldi_native_fbank) | 1 | function(download_kaldi_native_fbank) |
| 2 | include(FetchContent) | 2 | include(FetchContent) |
| 3 | 3 | ||
| 4 | - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz") | ||
| 5 | - set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz") | ||
| 6 | - set(kaldi_native_fbank_HASH "SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283") | 4 | + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.6.tar.gz") |
| 5 | + set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.6.tar.gz") | ||
| 6 | + set(kaldi_native_fbank_HASH "SHA256=6202a00cd06ba8ff89beb7b6f85cda34e073e94f25fc29e37c519bff0706bf19") | ||
| 7 | 7 | ||
| 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) | 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) |
| 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) |
| @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) | @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) | ||
| 12 | # If you don't have access to the Internet, | 12 | # If you don't have access to the Internet, |
| 13 | # please pre-download kaldi-native-fbank | 13 | # please pre-download kaldi-native-fbank |
| 14 | set(possible_file_locations | 14 | set(possible_file_locations |
| 15 | - $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz | ||
| 16 | - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.5.tar.gz | ||
| 17 | - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.5.tar.gz | ||
| 18 | - /tmp/kaldi-native-fbank-1.18.5.tar.gz | ||
| 19 | - /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz | 15 | + $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.6.tar.gz |
| 16 | + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.6.tar.gz | ||
| 17 | + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.6.tar.gz | ||
| 18 | + /tmp/kaldi-native-fbank-1.18.6.tar.gz | ||
| 19 | + /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.6.tar.gz | ||
| 20 | ) | 20 | ) |
| 21 | 21 | ||
| 22 | foreach(f IN LISTS possible_file_locations) | 22 | foreach(f IN LISTS possible_file_locations) |
| @@ -100,6 +100,7 @@ set(sources | @@ -100,6 +100,7 @@ set(sources | ||
| 100 | list(APPEND sources | 100 | list(APPEND sources |
| 101 | speaker-embedding-extractor-impl.cc | 101 | speaker-embedding-extractor-impl.cc |
| 102 | speaker-embedding-extractor-model.cc | 102 | speaker-embedding-extractor-model.cc |
| 103 | + speaker-embedding-extractor-nemo-model.cc | ||
| 103 | speaker-embedding-extractor.cc | 104 | speaker-embedding-extractor.cc |
| 104 | speaker-embedding-manager.cc | 105 | speaker-embedding-manager.cc |
| 105 | ) | 106 | ) |
| @@ -41,8 +41,12 @@ class FeatureExtractor::Impl { | @@ -41,8 +41,12 @@ class FeatureExtractor::Impl { | ||
| 41 | public: | 41 | public: |
| 42 | explicit Impl(const FeatureExtractorConfig &config) : config_(config) { | 42 | explicit Impl(const FeatureExtractorConfig &config) : config_(config) { |
| 43 | opts_.frame_opts.dither = 0; | 43 | opts_.frame_opts.dither = 0; |
| 44 | - opts_.frame_opts.snip_edges = false; | 44 | + opts_.frame_opts.snip_edges = config.snip_edges; |
| 45 | opts_.frame_opts.samp_freq = config.sampling_rate; | 45 | opts_.frame_opts.samp_freq = config.sampling_rate; |
| 46 | + opts_.frame_opts.frame_shift_ms = config.frame_shift_ms; | ||
| 47 | + opts_.frame_opts.frame_length_ms = config.frame_length_ms; | ||
| 48 | + opts_.frame_opts.remove_dc_offset = config.remove_dc_offset; | ||
| 49 | + opts_.frame_opts.window_type = config.window_type; | ||
| 46 | 50 | ||
| 47 | opts_.mel_opts.num_bins = config.feature_dim; | 51 | opts_.mel_opts.num_bins = config.feature_dim; |
| 48 | 52 | ||
| @@ -52,6 +56,9 @@ class FeatureExtractor::Impl { | @@ -52,6 +56,9 @@ class FeatureExtractor::Impl { | ||
| 52 | // https://github.com/k2-fsa/sherpa-onnx/issues/514 | 56 | // https://github.com/k2-fsa/sherpa-onnx/issues/514 |
| 53 | opts_.mel_opts.high_freq = -400; | 57 | opts_.mel_opts.high_freq = -400; |
| 54 | 58 | ||
| 59 | + opts_.mel_opts.low_freq = config.low_freq; | ||
| 60 | + opts_.mel_opts.is_librosa = config.is_librosa; | ||
| 61 | + | ||
| 55 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 62 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| 56 | } | 63 | } |
| 57 | 64 |
| @@ -28,6 +28,14 @@ struct FeatureExtractorConfig { | @@ -28,6 +28,14 @@ struct FeatureExtractorConfig { | ||
| 28 | // If false, we will multiply the inputs by 32768 | 28 | // If false, we will multiply the inputs by 32768 |
| 29 | bool normalize_samples = true; | 29 | bool normalize_samples = true; |
| 30 | 30 | ||
| 31 | + bool snip_edges = false; | ||
| 32 | + float frame_shift_ms = 10.0f; // in milliseconds. | ||
| 33 | + float frame_length_ms = 25.0f; // in milliseconds. | ||
| 34 | + int32_t low_freq = 20; | ||
| 35 | + bool is_librosa = false; | ||
| 36 | + bool remove_dc_offset = true; // Subtract mean of wave before FFT. | ||
| 37 | + std::string window_type = "povey"; // e.g. Hamming window | ||
| 38 | + | ||
| 31 | std::string ToString() const; | 39 | std::string ToString() const; |
| 32 | 40 | ||
| 33 | void Register(ParseOptions *po); | 41 | void Register(ParseOptions *po); |
| 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h | 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ |
| 6 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ | 6 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ |
| 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc | 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | #include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" | 4 | #include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" |
| 5 | 5 | ||
| 6 | #include "sherpa-onnx/csrc/macros.h" | 6 | #include "sherpa-onnx/csrc/macros.h" |
| 7 | #include "sherpa-onnx/csrc/onnx-utils.h" | 7 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 8 | #include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h" | 8 | #include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h" |
| 9 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h" | ||
| 9 | 10 | ||
| 10 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| 11 | 12 | ||
| @@ -14,6 +15,7 @@ namespace { | @@ -14,6 +15,7 @@ namespace { | ||
| 14 | enum class ModelType { | 15 | enum class ModelType { |
| 15 | kWeSpeaker, | 16 | kWeSpeaker, |
| 16 | k3dSpeaker, | 17 | k3dSpeaker, |
| 18 | + kNeMo, | ||
| 17 | kUnkown, | 19 | kUnkown, |
| 18 | }; | 20 | }; |
| 19 | 21 | ||
| @@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 52 | return ModelType::kWeSpeaker; | 54 | return ModelType::kWeSpeaker; |
| 53 | } else if (model_type.get() == std::string("3d-speaker")) { | 55 | } else if (model_type.get() == std::string("3d-speaker")) { |
| 54 | return ModelType::k3dSpeaker; | 56 | return ModelType::k3dSpeaker; |
| 57 | + } else if (model_type.get() == std::string("nemo")) { | ||
| 58 | + return ModelType::kNeMo; | ||
| 55 | } else { | 59 | } else { |
| 56 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 60 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); |
| 57 | return ModelType::kUnkown; | 61 | return ModelType::kUnkown; |
| @@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create( | @@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create( | ||
| 74 | // fall through | 78 | // fall through |
| 75 | case ModelType::k3dSpeaker: | 79 | case ModelType::k3dSpeaker: |
| 76 | return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config); | 80 | return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config); |
| 81 | + case ModelType::kNeMo: | ||
| 82 | + return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config); | ||
| 77 | case ModelType::kUnkown: | 83 | case ModelType::kUnkown: |
| 78 | SHERPA_ONNX_LOGE( | 84 | SHERPA_ONNX_LOGE( |
| 79 | "Unknown model type in for speaker embedding extractor!"); | 85 | "Unknown model type in for speaker embedding extractor!"); |
| 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-impl.h | 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-impl.h |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ |
| 6 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ | 6 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ |
| 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-model.cc | 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-model.cc |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023-2024 Xiaomi Corporation | 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h" | 5 | #include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h" |
| 6 | 6 |
| 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-model.h | 1 | // sherpa-onnx/csrc/speaker-embedding-extractor-model.h |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023-2024 Xiaomi Corporation | 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ | 4 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ |
| 5 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ | 5 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ |
| 6 | 6 |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <memory> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "Eigen/Dense" | ||
| 13 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" | ||
| 14 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h" | ||
| 15 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { | ||
| 20 | + public: | ||
| 21 | + explicit SpeakerEmbeddingExtractorNeMoImpl( | ||
| 22 | + const SpeakerEmbeddingExtractorConfig &config) | ||
| 23 | + : model_(config) {} | ||
| 24 | + | ||
| 25 | + int32_t Dim() const override { return model_.GetMetaData().output_dim; } | ||
| 26 | + | ||
| 27 | + std::unique_ptr<OnlineStream> CreateStream() const override { | ||
| 28 | + FeatureExtractorConfig feat_config; | ||
| 29 | + const auto &meta_data = model_.GetMetaData(); | ||
| 30 | + feat_config.sampling_rate = meta_data.sample_rate; | ||
| 31 | + feat_config.feature_dim = meta_data.feat_dim; | ||
| 32 | + feat_config.normalize_samples = true; | ||
| 33 | + feat_config.snip_edges = true; | ||
| 34 | + feat_config.frame_shift_ms = meta_data.window_stride_ms; | ||
| 35 | + feat_config.frame_length_ms = meta_data.window_size_ms; | ||
| 36 | + feat_config.low_freq = 0; | ||
| 37 | + feat_config.is_librosa = true; | ||
| 38 | + feat_config.remove_dc_offset = false; | ||
| 39 | + feat_config.window_type = meta_data.window_type; | ||
| 40 | + | ||
| 41 | + return std::make_unique<OnlineStream>(feat_config); | ||
| 42 | + } | ||
| 43 | + | ||
| 44 | + bool IsReady(OnlineStream *s) const override { | ||
| 45 | + return s->GetNumProcessedFrames() < s->NumFramesReady(); | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + std::vector<float> Compute(OnlineStream *s) const override { | ||
| 49 | + int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames(); | ||
| 50 | + if (num_frames <= 0) { | ||
| 51 | + SHERPA_ONNX_LOGE( | ||
| 52 | + "Please make sure IsReady(s) returns true. num_frames: %d", | ||
| 53 | + num_frames); | ||
| 54 | + return {}; | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + std::vector<float> features = | ||
| 58 | + s->GetFrames(s->GetNumProcessedFrames(), num_frames); | ||
| 59 | + | ||
| 60 | + s->GetNumProcessedFrames() += num_frames; | ||
| 61 | + | ||
| 62 | + int32_t feat_dim = features.size() / num_frames; | ||
| 63 | + | ||
| 64 | + const auto &meta_data = model_.GetMetaData(); | ||
| 65 | + if (!meta_data.feature_normalize_type.empty()) { | ||
| 66 | + if (meta_data.feature_normalize_type == "per_feature") { | ||
| 67 | + NormalizePerFeature(features.data(), num_frames, feat_dim); | ||
| 68 | + } else { | ||
| 69 | + SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s", | ||
| 70 | + meta_data.feature_normalize_type.c_str()); | ||
| 71 | + exit(-1); | ||
| 72 | + } | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + if (num_frames % 16 != 0) { | ||
| 76 | + int32_t pad = 16 - num_frames % 16; | ||
| 77 | + features.resize((num_frames + pad) * feat_dim); | ||
| 78 | + } | ||
| 79 | + | ||
| 80 | + auto memory_info = | ||
| 81 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 82 | + | ||
| 83 | + std::array<int64_t, 3> x_shape{1, num_frames, feat_dim}; | ||
| 84 | + Ort::Value x = | ||
| 85 | + Ort::Value::CreateTensor(memory_info, features.data(), features.size(), | ||
| 86 | + x_shape.data(), x_shape.size()); | ||
| 87 | + | ||
| 88 | + x = Transpose12(model_.Allocator(), &x); | ||
| 89 | + | ||
| 90 | + int64_t x_lens = num_frames; | ||
| 91 | + std::array<int64_t, 1> x_lens_shape{1}; | ||
| 92 | + Ort::Value x_lens_tensor = Ort::Value::CreateTensor( | ||
| 93 | + memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size()); | ||
| 94 | + | ||
| 95 | + Ort::Value embedding = | ||
| 96 | + model_.Compute(std::move(x), std::move(x_lens_tensor)); | ||
| 97 | + std::vector<int64_t> embedding_shape = | ||
| 98 | + embedding.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 99 | + | ||
| 100 | + std::vector<float> ans(embedding_shape[1]); | ||
| 101 | + std::copy(embedding.GetTensorData<float>(), | ||
| 102 | + embedding.GetTensorData<float>() + ans.size(), ans.begin()); | ||
| 103 | + | ||
| 104 | + return ans; | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + private: | ||
| 108 | + void NormalizePerFeature(float *p, int32_t num_frames, | ||
| 109 | + int32_t feat_dim) const { | ||
| 110 | + auto m = Eigen::Map< | ||
| 111 | + Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>( | ||
| 112 | + p, num_frames, feat_dim); | ||
| 113 | + | ||
| 114 | + auto EX = m.colwise().mean(); | ||
| 115 | + auto EX2 = m.array().pow(2).colwise().sum() / num_frames; | ||
| 116 | + auto variance = EX2 - EX.array().pow(2); | ||
| 117 | + auto stddev = variance.array().sqrt(); | ||
| 118 | + | ||
| 119 | + m = (m.rowwise() - EX).array().rowwise() / stddev.array(); | ||
| 120 | + } | ||
| 121 | + | ||
| 122 | + private: | ||
| 123 | + SpeakerEmbeddingExtractorNeMoModel model_; | ||
| 124 | +}; | ||
| 125 | + | ||
| 126 | +} // namespace sherpa_onnx | ||
| 127 | + | ||
| 128 | +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ | ||
| 6 | + | ||
| 7 | +#include <cstdint> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
// Meta data describing a NeMo speaker embedding model and the fbank features
// it expects. The values below are only defaults.
// NOTE(review): presumably the model loader overwrites them with the values
// stored in the ONNX file - confirm against the loader.
struct SpeakerEmbeddingExtractorNeMoModelMetaData {
  // Dimension of the embedding produced by the model.
  int32_t output_dim = 0;

  // Number of fbank bins.
  int32_t feat_dim = 80;

  // Expected sample rate of the input audio.
  int32_t sample_rate = 0;

  // Length of one analysis window, in milliseconds.
  int32_t window_size_ms = 25;

  // Shift between consecutive windows, in milliseconds. The default used to
  // be 25, i.e. equal to the window size; 10 matches the default frame shift
  // of FeatureExtractorConfig.
  int32_t window_stride_ms = 10;

  // Chinese, English, etc.
  std::string language;

  // How features are normalized before they are fed to the model. The NeMo
  // extractor handles "per_feature" (zero mean and unit variance for each
  // feature dimension) and skips normalization when this is empty.
  std::string feature_normalize_type;

  // Name of the window function applied to each frame.
  std::string window_type = "hann";
};
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 13 | +#include "sherpa-onnx/csrc/session.h" | ||
| 14 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +class SpeakerEmbeddingExtractorNeMoModel::Impl { | ||
| 19 | + public: | ||
| 20 | + explicit Impl(const SpeakerEmbeddingExtractorConfig &config) | ||
| 21 | + : config_(config), | ||
| 22 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 23 | + sess_opts_(GetSessionOptions(config)), | ||
| 24 | + allocator_{} { | ||
| 25 | + { | ||
| 26 | + auto buf = ReadFile(config.model); | ||
| 27 | + Init(buf.data(), buf.size()); | ||
| 28 | + } | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + Ort::Value Compute(Ort::Value x, Ort::Value x_lens) const { | ||
| 32 | + std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)}; | ||
| 33 | + | ||
| 34 | + // output_names_ptr_[0] is logits | ||
| 35 | + // output_names_ptr_[1] is embeddings | ||
| 36 | + // so we use output_names_ptr_.data() + 1 here to extract only the | ||
| 37 | + // embeddings | ||
| 38 | + auto outputs = sess_->Run({}, input_names_ptr_.data(), inputs.data(), | ||
| 39 | + inputs.size(), output_names_ptr_.data() + 1, 1); | ||
| 40 | + return std::move(outputs[0]); | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 44 | + | ||
| 45 | + const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const { | ||
| 46 | + return meta_data_; | ||
| 47 | + } | ||
| 48 | + | ||
| 49 | + private: | ||
| 50 | + void Init(void *model_data, size_t model_data_length) { | ||
| 51 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 52 | + sess_opts_); | ||
| 53 | + | ||
| 54 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 55 | + | ||
| 56 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 57 | + | ||
| 58 | + // get meta data | ||
| 59 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 60 | + if (config_.debug) { | ||
| 61 | + std::ostringstream os; | ||
| 62 | + PrintModelMetadata(os, meta_data); | ||
| 63 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 64 | + } | ||
| 65 | + | ||
| 66 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 67 | + SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim"); | ||
| 68 | + SHERPA_ONNX_READ_META_DATA(meta_data_.feat_dim, "feat_dim"); | ||
| 69 | + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); | ||
| 70 | + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size_ms, "window_size_ms"); | ||
| 71 | + SHERPA_ONNX_READ_META_DATA(meta_data_.window_stride_ms, "window_stride_ms"); | ||
| 72 | + SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); | ||
| 73 | + | ||
| 74 | + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT( | ||
| 75 | + meta_data_.feature_normalize_type, "feature_normalize_type", ""); | ||
| 76 | + | ||
| 77 | + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.window_type, | ||
| 78 | + "window_type", "povey"); | ||
| 79 | + | ||
| 80 | + std::string framework; | ||
| 81 | + SHERPA_ONNX_READ_META_DATA_STR(framework, "framework"); | ||
| 82 | + if (framework != "nemo") { | ||
| 83 | + SHERPA_ONNX_LOGE("Expect a NeMo model, given: %s", framework.c_str()); | ||
| 84 | + exit(-1); | ||
| 85 | + } | ||
| 86 | + } | ||
| 87 | + | ||
| 88 | + private: | ||
| 89 | + SpeakerEmbeddingExtractorConfig config_; | ||
| 90 | + Ort::Env env_; | ||
| 91 | + Ort::SessionOptions sess_opts_; | ||
| 92 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 93 | + | ||
| 94 | + std::unique_ptr<Ort::Session> sess_; | ||
| 95 | + | ||
| 96 | + std::vector<std::string> input_names_; | ||
| 97 | + std::vector<const char *> input_names_ptr_; | ||
| 98 | + | ||
| 99 | + std::vector<std::string> output_names_; | ||
| 100 | + std::vector<const char *> output_names_ptr_; | ||
| 101 | + | ||
| 102 | + SpeakerEmbeddingExtractorNeMoModelMetaData meta_data_; | ||
| 103 | +}; | ||
| 104 | + | ||
| 105 | +SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel( | ||
| 106 | + const SpeakerEmbeddingExtractorConfig &config) | ||
| 107 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 108 | + | ||
| 109 | +SpeakerEmbeddingExtractorNeMoModel::~SpeakerEmbeddingExtractorNeMoModel() = | ||
| 110 | + default; | ||
| 111 | + | ||
| 112 | +const SpeakerEmbeddingExtractorNeMoModelMetaData & | ||
| 113 | +SpeakerEmbeddingExtractorNeMoModel::GetMetaData() const { | ||
| 114 | + return impl_->GetMetaData(); | ||
| 115 | +} | ||
| 116 | + | ||
| 117 | +Ort::Value SpeakerEmbeddingExtractorNeMoModel::Compute( | ||
| 118 | + Ort::Value x, Ort::Value x_lens) const { | ||
| 119 | + return impl_->Compute(std::move(x), std::move(x_lens)); | ||
| 120 | +} | ||
| 121 | + | ||
| 122 | +OrtAllocator *SpeakerEmbeddingExtractorNeMoModel::Allocator() const { | ||
| 123 | + return impl_->Allocator(); | ||
| 124 | +} | ||
| 125 | + | ||
| 126 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | + | ||
| 9 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 10 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h" | ||
| 11 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class SpeakerEmbeddingExtractorNeMoModel { | ||
| 16 | + public: | ||
| 17 | + explicit SpeakerEmbeddingExtractorNeMoModel( | ||
| 18 | + const SpeakerEmbeddingExtractorConfig &config); | ||
| 19 | + | ||
| 20 | + ~SpeakerEmbeddingExtractorNeMoModel(); | ||
| 21 | + | ||
| 22 | + const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const; | ||
| 23 | + | ||
| 24 | + /** | ||
| 25 | + * @param x A float32 tensor of shape (N, C, T) | ||
| 26 | + * @param x_len A int64 tensor of shape (N,) | ||
| 27 | + * @return A float32 tensor of shape (N, C) | ||
| 28 | + */ | ||
| 29 | + Ort::Value Compute(Ort::Value x, Ort::Value x_len) const; | ||
| 30 | + | ||
| 31 | + OrtAllocator *Allocator() const; | ||
| 32 | + | ||
| 33 | + private: | ||
| 34 | + class Impl; | ||
| 35 | + std::unique_ptr<Impl> impl_; | ||
| 36 | +}; | ||
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx | ||
| 39 | + | ||
| 40 | +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_ |
| 1 | // sherpa-onnx/csrc/speaker-embedding-extractor.cc | 1 | // sherpa-onnx/csrc/speaker-embedding-extractor.cc |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | 5 | #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" |
| 6 | 6 |
| 1 | // sherpa-onnx/csrc/speaker-embedding-extractor.h | 1 | // sherpa-onnx/csrc/speaker-embedding-extractor.h |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ |
| 6 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ | 6 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ |
| 1 | // sherpa-onnx/csrc/speaker-embedding-manager-test.cc | 1 | // sherpa-onnx/csrc/speaker-embedding-manager-test.cc |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) | 3 | +// Copyright (c) 2024 Jingzhao Ou (jingzhao.ou@gmail.com) |
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/speaker-embedding-manager.h" | 5 | #include "sherpa-onnx/csrc/speaker-embedding-manager.h" |
| 6 | 6 |
| 1 | // sherpa-onnx/csrc/speaker-embedding-manager.cc | 1 | // sherpa-onnx/csrc/speaker-embedding-manager.cc |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/speaker-embedding-manager.h" | 5 | #include "sherpa-onnx/csrc/speaker-embedding-manager.h" |
| 6 | 6 |
| 1 | // sherpa-onnx/csrc/speaker-embedding-manager.h | 1 | // sherpa-onnx/csrc/speaker-embedding-manager.h |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | 3 | +// Copyright (c) 2024 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ |
| 6 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ | 6 | #define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ |
| @@ -56,7 +56,7 @@ def load_speaker_embedding_model(model_filename): | @@ -56,7 +56,7 @@ def load_speaker_embedding_model(model_filename): | ||
| 56 | return extractor | 56 | return extractor |
| 57 | 57 | ||
| 58 | 58 | ||
| 59 | -def test_wespeaker_model(model_filename: str): | 59 | +def test_zh_models(model_filename: str): |
| 60 | model_filename = str(model_filename) | 60 | model_filename = str(model_filename) |
| 61 | if "en" in model_filename: | 61 | if "en" in model_filename: |
| 62 | print(f"skip {model_filename}") | 62 | print(f"skip {model_filename}") |
| @@ -114,8 +114,9 @@ def test_wespeaker_model(model_filename: str): | @@ -114,8 +114,9 @@ def test_wespeaker_model(model_filename: str): | ||
| 114 | assert ans == name, (name, ans) | 114 | assert ans == name, (name, ans) |
| 115 | 115 | ||
| 116 | 116 | ||
| 117 | -def test_3dspeaker_model(model_filename: str): | ||
| 118 | - extractor = load_speaker_embedding_model(str(model_filename)) | 117 | +def test_en_and_zh_models(model_filename: str): |
| 118 | + model_filename = str(model_filename) | ||
| 119 | + extractor = load_speaker_embedding_model(model_filename) | ||
| 119 | manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim) | 120 | manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim) |
| 120 | 121 | ||
| 121 | filenames = [ | 122 | filenames = [ |
| @@ -124,7 +125,14 @@ def test_3dspeaker_model(model_filename: str): | @@ -124,7 +125,14 @@ def test_3dspeaker_model(model_filename: str): | ||
| 124 | "speaker1_a_en_16k", | 125 | "speaker1_a_en_16k", |
| 125 | "speaker2_a_en_16k", | 126 | "speaker2_a_en_16k", |
| 126 | ] | 127 | ] |
| 128 | + is_en = "en" in model_filename | ||
| 127 | for filename in filenames: | 129 | for filename in filenames: |
| 130 | + if is_en and "cn" in filename: | ||
| 131 | + continue | ||
| 132 | + | ||
| 133 | + if not is_en and "en" in filename: | ||
| 134 | + continue | ||
| 135 | + | ||
| 128 | name = filename.rsplit("_", maxsplit=1)[0] | 136 | name = filename.rsplit("_", maxsplit=1)[0] |
| 129 | data, sample_rate = read_wave( | 137 | data, sample_rate = read_wave( |
| 130 | f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav" | 138 | f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav" |
| @@ -145,6 +153,11 @@ def test_3dspeaker_model(model_filename: str): | @@ -145,6 +153,11 @@ def test_3dspeaker_model(model_filename: str): | ||
| 145 | "speaker1_b_en_16k", | 153 | "speaker1_b_en_16k", |
| 146 | ] | 154 | ] |
| 147 | for filename in filenames: | 155 | for filename in filenames: |
| 156 | + if is_en and "cn" in filename: | ||
| 157 | + continue | ||
| 158 | + | ||
| 159 | + if not is_en and "en" in filename: | ||
| 160 | + continue | ||
| 148 | print(filename) | 161 | print(filename) |
| 149 | name = filename.rsplit("_", maxsplit=1)[0] | 162 | name = filename.rsplit("_", maxsplit=1)[0] |
| 150 | name = name.replace("b_cn", "a_cn") | 163 | name = name.replace("b_cn", "a_cn") |
| @@ -178,7 +191,8 @@ class TestSpeakerRecognition(unittest.TestCase): | @@ -178,7 +191,8 @@ class TestSpeakerRecognition(unittest.TestCase): | ||
| 178 | return | 191 | return |
| 179 | for filename in model_dir.glob("*.onnx"): | 192 | for filename in model_dir.glob("*.onnx"): |
| 180 | print(filename) | 193 | print(filename) |
| 181 | - test_wespeaker_model(filename) | 194 | + test_zh_models(filename) |
| 195 | + test_en_and_zh_models(filename) | ||
| 182 | 196 | ||
| 183 | def test_3dpeaker_models(self): | 197 | def test_3dpeaker_models(self): |
| 184 | model_dir = Path(d) / "3dspeaker" | 198 | model_dir = Path(d) / "3dspeaker" |
| @@ -187,7 +201,16 @@ class TestSpeakerRecognition(unittest.TestCase): | @@ -187,7 +201,16 @@ class TestSpeakerRecognition(unittest.TestCase): | ||
| 187 | return | 201 | return |
| 188 | for filename in model_dir.glob("*.onnx"): | 202 | for filename in model_dir.glob("*.onnx"): |
| 189 | print(filename) | 203 | print(filename) |
| 190 | - test_3dspeaker_model(filename) | 204 | + test_en_and_zh_models(filename) |
| 205 | + | ||
| 206 | + def test_nemo_models(self): | ||
| 207 | + model_dir = Path(d) / "nemo" | ||
| 208 | + if not model_dir.is_dir(): | ||
| 209 | + print(f"{model_dir} does not exist - skip it") | ||
| 210 | + return | ||
| 211 | + for filename in model_dir.glob("*.onnx"): | ||
| 212 | + print(filename) | ||
| 213 | + test_en_and_zh_models(filename) | ||
| 191 | 214 | ||
| 192 | 215 | ||
| 193 | if __name__ == "__main__": | 216 | if __name__ == "__main__": |
-
请 注册 或 登录 后发表评论