Add C++ runtime for models from 3d-speaker (#523)
Committed by GitHub

Showing 13 changed files with 337 additions and 46 deletions.
.github/scripts/test-speaker-recognition-python.sh (new file):

#!/usr/bin/env bash

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

d=/tmp/sr-models
mkdir -p $d

pushd $d
log "Download test waves"
git clone https://github.com/csukuangfj/sr-data
popd

log "Download wespeaker models"
model_dir=$d/wespeaker
mkdir -p $model_dir
pushd $model_dir
models=(
en_voxceleb_CAM++.onnx
en_voxceleb_CAM++_LM.onnx
en_voxceleb_resnet152_LM.onnx
en_voxceleb_resnet221_LM.onnx
en_voxceleb_resnet293_LM.onnx
en_voxceleb_resnet34.onnx
en_voxceleb_resnet34_LM.onnx
zh_cnceleb_resnet34.onnx
zh_cnceleb_resnet34_LM.onnx
)
# Note: "speaker-recongition-models" below matches the spelling of the
# upstream release tag, so it must not be "corrected" here.
for m in "${models[@]}"; do
  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m
done
ls -lh
popd

log "Download 3d-speaker models"
model_dir=$d/3dspeaker
mkdir -p $model_dir
pushd $model_dir
models=(
speech_campplus_sv_en_voxceleb_16k.onnx
speech_campplus_sv_zh-cn_16k-common.onnx
speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx
speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
speech_eres2net_sv_en_voxceleb_16k.onnx
speech_eres2net_sv_zh-cn_16k-common.onnx
)
for m in "${models[@]}"; do
  wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m
done
ls -lh
popd

python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
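As a quick sanity check after this script runs, one of the downloaded models can be loaded from Python. This is a sketch, not part of the PR; the model path assumes the /tmp/sr-models layout created above, and the API calls mirror the ones used in test_speaker_recognition.py further below:

    import sherpa_onnx

    config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model="/tmp/sr-models/3dspeaker/speech_campplus_sv_zh-cn_16k-common.onnx",
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    assert config.validate(), config

    extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
    print(extractor.dim)  # embedding dimension read from the model metadata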
CI workflow under .github/workflows (exact filename not shown in this extract):

@@ -76,6 +76,7 @@ jobs:
       - name: Test sherpa-onnx
         shell: bash
         run: |
+          .github/scripts/test-speaker-recognition-python.sh
           .github/scripts/test-python.sh

       - uses: actions/upload-artifact@v3
sherpa-onnx/csrc/CMakeLists.txt:

@@ -99,7 +99,7 @@ set(sources
 # speaker embedding extractor
 list(APPEND sources
   speaker-embedding-extractor-impl.cc
-  speaker-embedding-extractor-wespeaker-model.cc
+  speaker-embedding-extractor-model.cc
   speaker-embedding-extractor.cc
   speaker-embedding-manager.cc
 )
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h → sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h (renamed):

-// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h
+// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
 //
 // Copyright (c) 2023 Xiaomi Corporation

-#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
-#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
+#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
+#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
 #include <algorithm>
 #include <memory>
 #include <utility>
 #include <vector>

+#include "Eigen/Dense"
 #include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
-#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
+#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"

 namespace sherpa_onnx {

-class SpeakerEmbeddingExtractorWeSpeakerImpl
+class SpeakerEmbeddingExtractorGeneralImpl
     : public SpeakerEmbeddingExtractorImpl {
  public:
-  explicit SpeakerEmbeddingExtractorWeSpeakerImpl(
+  explicit SpeakerEmbeddingExtractorGeneralImpl(
       const SpeakerEmbeddingExtractorConfig &config)
       : model_(config) {}

@@ -25,7 +26,7 @@

   std::unique_ptr<OnlineStream> CreateStream() const override {
     FeatureExtractorConfig feat_config;
-    auto meta_data = model_.GetMetaData();
+    const auto &meta_data = model_.GetMetaData();
     feat_config.sampling_rate = meta_data.sample_rate;
     feat_config.normalize_samples = meta_data.normalize_samples;

@@ -52,6 +53,17 @@

     int32_t feat_dim = features.size() / num_frames;

+    const auto &meta_data = model_.GetMetaData();
+    if (!meta_data.feature_normalize_type.empty()) {
+      if (meta_data.feature_normalize_type == "global-mean") {
+        SubtractGlobalMean(features.data(), num_frames, feat_dim);
+      } else {
+        SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
+                         meta_data.feature_normalize_type.c_str());
+        exit(-1);
+      }
+    }
+
     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

@@ -71,9 +83,19 @@

   }

  private:
-  SpeakerEmbeddingExtractorWeSpeakerModel model_;
+  void SubtractGlobalMean(float *p, int32_t num_frames,
+                          int32_t feat_dim) const {
+    auto m = Eigen::Map<
+        Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
+        p, num_frames, feat_dim);
+
+    m = m.rowwise() - m.colwise().mean();
+  }
+
+ private:
+  SpeakerEmbeddingExtractorModel model_;
 };

 }  // namespace sherpa_onnx

-#endif  // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
+#endif  // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
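The new SubtractGlobalMean helper implements the "global-mean" feature normalization required by 3d-speaker models: for a (num_frames, feat_dim) feature matrix, the per-dimension mean over all frames is subtracted in place, which is exactly what the Eigen expression m = m.rowwise() - m.colwise().mean() computes. A minimal NumPy sketch of the same operation, for illustration only:

    import numpy as np

    def subtract_global_mean(features: np.ndarray) -> np.ndarray:
        # features: (num_frames, feat_dim); subtract the column-wise mean
        return features - features.mean(axis=0, keepdims=True)

    # After normalization, every feature dimension has (near-)zero mean.
    feats = np.random.rand(100, 80).astype(np.float32)
    assert np.allclose(subtract_global_mean(feats).mean(axis=0), 0, atol=1e-5)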
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc:

@@ -5,7 +5,7 @@

 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
-#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h"
+#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"

 namespace sherpa_onnx {

@@ -13,6 +13,7 @@ namespace {

 enum class ModelType {
   kWeSpeaker,
+  k3dSpeaker,
   kUnkown,
 };

@@ -49,6 +50,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,

   if (model_type.get() == std::string("wespeaker")) {
     return ModelType::kWeSpeaker;
+  } else if (model_type.get() == std::string("3d-speaker")) {
+    return ModelType::k3dSpeaker;
   } else {
     SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
     return ModelType::kUnkown;

@@ -68,7 +71,9 @@ SpeakerEmbeddingExtractorImpl::Create(

   switch (model_type) {
     case ModelType::kWeSpeaker:
-      return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config);
+      // fall through
+    case ModelType::k3dSpeaker:
+      return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
     case ModelType::kUnkown:
       SHERPA_ONNX_LOGE(
           "Unknown model type in for speaker embedding extractor!");
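Both branches of the switch now construct the same SpeakerEmbeddingExtractorGeneralImpl; the two frameworks are distinguished only by the metadata baked into the ONNX file. As a sketch, that metadata can be inspected with the onnx Python package (the "framework" key is the one read in speaker-embedding-extractor-model.cc below; the model path is an assumption):

    import onnx

    m = onnx.load("/tmp/sr-models/3dspeaker/speech_campplus_sv_zh-cn_16k-common.onnx")
    meta = {p.key: p.value for p in m.metadata_props}
    print(meta.get("framework"))  # expected: "wespeaker" or "3d-speaker"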
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h → sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h (renamed):

-// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h
+// sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h
 //
 // Copyright (c) 2023 Xiaomi Corporation
-#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
-#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
+#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
+#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_

 #include <cstdint>
 #include <string>

 namespace sherpa_onnx {

-struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
+struct SpeakerEmbeddingExtractorModelMetaData {
   int32_t output_dim = 0;
   int32_t sample_rate = 0;
-  int32_t normalize_samples = 0;
+
+  // for wespeaker models, it is 0;
+  // for 3d-speaker models, it is 1
+  int32_t normalize_samples = 1;
+
+  // Chinese, English, etc.
   std::string language;
+
+  // for 3d-speaker, it is global-mean
+  std::string feature_normalize_type;
 };

 }  // namespace sherpa_onnx
-#endif  // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
+#endif  // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
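Note that the default for normalize_samples flips from 0 to 1 here. The flag controls the sample scaling fed to the feature extractor: wespeaker models compute fbank features on int16-range samples, while 3d-speaker models expect samples in [-1, 1]. A hedged sketch of the presumed interpretation (the actual scaling lives inside the feature extractor, which is not part of this diff):

    import numpy as np

    def prepare_samples(samples: np.ndarray, normalize_samples: int) -> np.ndarray:
        # samples: float32 in [-1, 1], as returned by read_wave() in the test below
        if normalize_samples:
            return samples  # 3d-speaker: keep normalized samples (assumption)
        return samples * 32768.0  # wespeaker: rescale to int16 range (assumption)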
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc → sherpa-onnx/csrc/speaker-embedding-extractor-model.cc (renamed):

-// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
+// sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
 //
-// Copyright (c) 2023 Xiaomi Corporation
+// Copyright (c) 2023-2024 Xiaomi Corporation

-#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
+#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"

 #include <string>
 #include <utility>

@@ -11,11 +11,11 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/session.h"
-#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
+#include "sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h"

 namespace sherpa_onnx {

-class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
+class SpeakerEmbeddingExtractorModel::Impl {
  public:
   explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
       : config_(config),

@@ -37,7 +37,7 @@
     return std::move(outputs[0]);
   }

-  const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const {
+  const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const {
     return meta_data_;
   }

@@ -65,10 +65,13 @@
                                  "normalize_samples");
     SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");

+    SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(
+        meta_data_.feature_normalize_type, "feature_normalize_type", "");
+
     std::string framework;
     SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
-    if (framework != "wespeaker") {
-      SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s",
+    if (framework != "wespeaker" && framework != "3d-speaker") {
+      SHERPA_ONNX_LOGE("Expect a wespeaker or a 3d-speaker model, given: %s",
                        framework.c_str());
       exit(-1);
     }

@@ -88,24 +91,21 @@
   std::vector<std::string> output_names_;
   std::vector<const char *> output_names_ptr_;

-  SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_;
+  SpeakerEmbeddingExtractorModelMetaData meta_data_;
 };

-SpeakerEmbeddingExtractorWeSpeakerModel::
-    SpeakerEmbeddingExtractorWeSpeakerModel(
-        const SpeakerEmbeddingExtractorConfig &config)
+SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel(
+    const SpeakerEmbeddingExtractorConfig &config)
     : impl_(std::make_unique<Impl>(config)) {}

-SpeakerEmbeddingExtractorWeSpeakerModel::
-    ~SpeakerEmbeddingExtractorWeSpeakerModel() = default;
+SpeakerEmbeddingExtractorModel::~SpeakerEmbeddingExtractorModel() = default;

-const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &
-SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const {
+const SpeakerEmbeddingExtractorModelMetaData &
+SpeakerEmbeddingExtractorModel::GetMetaData() const {
   return impl_->GetMetaData();
 }

-Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(
-    Ort::Value x) const {
+Ort::Value SpeakerEmbeddingExtractorModel::Compute(Ort::Value x) const {
   return impl_->Compute(std::move(x));
 }
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h → sherpa-onnx/csrc/speaker-embedding-extractor-model.h (renamed):

-// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h
+// sherpa-onnx/csrc/speaker-embedding-extractor-model.h
 //
-// Copyright (c) 2023 Xiaomi Corporation
-#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
-#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
+// Copyright (c) 2023-2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
+#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_

 #include <memory>

 #include "onnxruntime_cxx_api.h"  // NOLINT
-#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
+#include "sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h"
 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h"

 namespace sherpa_onnx {

-class SpeakerEmbeddingExtractorWeSpeakerModel {
+class SpeakerEmbeddingExtractorModel {
  public:
-  explicit SpeakerEmbeddingExtractorWeSpeakerModel(
+  explicit SpeakerEmbeddingExtractorModel(
       const SpeakerEmbeddingExtractorConfig &config);

-  ~SpeakerEmbeddingExtractorWeSpeakerModel();
+  ~SpeakerEmbeddingExtractorModel();

-  const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const;
+  const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const;

   /**
    * @param x A float32 tensor of shape (N, T, C)

@@ -34,4 +34,4 @@

 }  // namespace sherpa_onnx

-#endif  // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
+#endif  // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
sherpa-onnx/python/tests/CMakeLists.txt:

@@ -23,6 +23,7 @@ set(py_test_files
   test_offline_recognizer.py
   test_online_recognizer.py
   test_online_transducer_model_config.py
+  test_speaker_recognition.py
   test_text2token.py
 )
Mode changed 100644 → 100755 (made executable):
  sherpa-onnx/python/tests/test_feature_extractor_config.py
  sherpa-onnx/python/tests/test_online_transducer_model_config.py
sherpa-onnx/python/tests/test_speaker_recognition.py (new file):

# sherpa-onnx/python/tests/test_speaker_recognition.py
#
# Copyright (c) 2024 Xiaomi Corporation
#
# To run this single test, use
#
#  ctest --verbose -R test_speaker_recognition_py

import unittest
import wave
from collections import defaultdict
from pathlib import Path
from typing import Tuple

import numpy as np
import sherpa_onnx

d = "/tmp/sr-models"


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """
    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample should
        be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
         normalized to the range [-1, 1].
       - sample rate of the wave file
    """

    with wave.open(wave_filename) as f:
        assert f.getnchannels() == 1, f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
        num_samples = f.getnframes()
        samples = f.readframes(num_samples)
        samples_int16 = np.frombuffer(samples, dtype=np.int16)
        samples_float32 = samples_int16.astype(np.float32)

        samples_float32 = samples_float32 / 32768
        return samples_float32, f.getframerate()


def load_speaker_embedding_model(model_filename):
    config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=model_filename,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    if not config.validate():
        raise ValueError(f"Invalid config. {config}")
    extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
    return extractor


def test_wespeaker_model(model_filename: str):
    model_filename = str(model_filename)
    if "en" in model_filename:
        print(f"skip {model_filename}")
        return
    extractor = load_speaker_embedding_model(model_filename)
    filenames = [
        "leijun-sr-1",
        "leijun-sr-2",
        "fangjun-sr-1",
        "fangjun-sr-2",
        "fangjun-sr-3",
    ]
    tmp = defaultdict(list)
    for filename in filenames:
        print(filename)
        name = filename.split("-", maxsplit=1)[0]
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/enroll/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        tmp[name].append(embedding)

    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, embedding_list in tmp.items():
        print(name, len(embedding_list))
        embedding = sum(embedding_list) / len(embedding_list)
        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "leijun-test-sr-1",
        "leijun-test-sr-2",
        "leijun-test-sr-3",
        "fangjun-test-sr-1",
        "fangjun-test-sr-2",
    ]
    for filename in filenames:
        name = filename.split("-", maxsplit=1)[0]
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/test/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav")

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)


def test_3dspeaker_model(model_filename: str):
    extractor = load_speaker_embedding_model(str(model_filename))
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)

    filenames = [
        "speaker1_a_cn_16k",
        "speaker2_a_cn_16k",
        "speaker1_a_en_16k",
        "speaker2_a_en_16k",
    ]
    for filename in filenames:
        name = filename.rsplit("_", maxsplit=1)[0]
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)

        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "speaker1_b_cn_16k",
        "speaker1_b_en_16k",
    ]
    for filename in filenames:
        print(filename)
        name = filename.rsplit("_", maxsplit=1)[0]
        name = name.replace("b_cn", "a_cn")
        name = name.replace("b_en", "a_en")
        print(name)

        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(
                f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}"
            )

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)


class TestSpeakerRecognition(unittest.TestCase):
    def test_wespeaker_models(self):
        model_dir = Path(d) / "wespeaker"
        if not model_dir.is_dir():
            print(f"{model_dir} does not exist - skip it")
            return
        for filename in model_dir.glob("*.onnx"):
            print(filename)
            test_wespeaker_model(filename)

    def test_3dpeaker_models(self):
        model_dir = Path(d) / "3dspeaker"
        if not model_dir.is_dir():
            print(f"{model_dir} does not exist - skip it")
            return
        for filename in model_dir.glob("*.onnx"):
            print(filename)
            test_3dspeaker_model(filename)


if __name__ == "__main__":
    unittest.main()
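Both manager.verify() and manager.search() are called with threshold=0.5. The thresholded score is presumably the cosine similarity between the stored and the query embeddings; the manager's internal scoring is not shown in this diff, but the equivalent score can be computed directly as a sketch:

    import numpy as np

    def cosine_score(a: np.ndarray, b: np.ndarray) -> float:
        # cosine similarity in [-1, 1]; higher means more likely the same speaker
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))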
Mode changed 100644 → 100755 (made executable):
  sherpa-onnx/python/tests/test_text2token.py