Fangjun Kuang
Committed by GitHub

Add C++ runtime for models from 3d-speaker (#523)

#!/usr/bin/env bash
# Download speaker-recognition test waves and models (wespeaker and
# 3d-speaker) into /tmp/sr-models, then run the Python speaker-recognition
# tests against them.

set -e

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

d=/tmp/sr-models
mkdir -p "$d"

pushd "$d"
log "Download test waves"
git clone https://github.com/csukuangfj/sr-data
popd

# NOTE: "speaker-recongition-models" below is the actual (misspelled) name of
# the upstream release tag; do not correct the spelling or the downloads 404.

log "Download wespeaker models"
model_dir=$d/wespeaker
mkdir -p "$model_dir"
pushd "$model_dir"
models=(
  en_voxceleb_CAM++.onnx
  en_voxceleb_CAM++_LM.onnx
  en_voxceleb_resnet152_LM.onnx
  en_voxceleb_resnet221_LM.onnx
  en_voxceleb_resnet293_LM.onnx
  en_voxceleb_resnet34.onnx
  en_voxceleb_resnet34_LM.onnx
  zh_cnceleb_resnet34.onnx
  zh_cnceleb_resnet34_LM.onnx
)
# Quote the array expansion so each model name stays a single word (SC2068).
for m in "${models[@]}"; do
  wget -q "https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m"
done
ls -lh
popd

log "Download 3d-speaker models"
model_dir=$d/3dspeaker
mkdir -p "$model_dir"
pushd "$model_dir"
models=(
  speech_campplus_sv_en_voxceleb_16k.onnx
  speech_campplus_sv_zh-cn_16k-common.onnx
  speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx
  speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
  speech_eres2net_sv_en_voxceleb_16k.onnx
  speech_eres2net_sv_zh-cn_16k-common.onnx
)
for m in "${models[@]}"; do
  wget -q "https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m"
done
ls -lh
popd


python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
@@ -76,6 +76,7 @@ jobs: @@ -76,6 +76,7 @@ jobs:
76 - name: Test sherpa-onnx 76 - name: Test sherpa-onnx
77 shell: bash 77 shell: bash
78 run: | 78 run: |
  79 + .github/scripts/test-speaker-recognition-python.sh
79 .github/scripts/test-python.sh 80 .github/scripts/test-python.sh
80 81
81 - uses: actions/upload-artifact@v3 82 - uses: actions/upload-artifact@v3
@@ -99,7 +99,7 @@ set(sources @@ -99,7 +99,7 @@ set(sources
99 # speaker embedding extractor 99 # speaker embedding extractor
100 list(APPEND sources 100 list(APPEND sources
101 speaker-embedding-extractor-impl.cc 101 speaker-embedding-extractor-impl.cc
102 - speaker-embedding-extractor-wespeaker-model.cc 102 + speaker-embedding-extractor-model.cc
103 speaker-embedding-extractor.cc 103 speaker-embedding-extractor.cc
104 speaker-embedding-manager.cc 104 speaker-embedding-manager.cc
105 ) 105 )
1 -// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h 1 +// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
5 -#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_  
6 -#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ 5 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
7 #include <algorithm> 7 #include <algorithm>
8 #include <memory> 8 #include <memory>
9 #include <utility> 9 #include <utility>
10 #include <vector> 10 #include <vector>
11 11
  12 +#include "Eigen/Dense"
12 #include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" 13 #include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
13 -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h" 14 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"
14 15
15 namespace sherpa_onnx { 16 namespace sherpa_onnx {
16 17
17 -class SpeakerEmbeddingExtractorWeSpeakerImpl 18 +class SpeakerEmbeddingExtractorGeneralImpl
18 : public SpeakerEmbeddingExtractorImpl { 19 : public SpeakerEmbeddingExtractorImpl {
19 public: 20 public:
20 - explicit SpeakerEmbeddingExtractorWeSpeakerImpl( 21 + explicit SpeakerEmbeddingExtractorGeneralImpl(
21 const SpeakerEmbeddingExtractorConfig &config) 22 const SpeakerEmbeddingExtractorConfig &config)
22 : model_(config) {} 23 : model_(config) {}
23 24
@@ -25,7 +26,7 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl @@ -25,7 +26,7 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl
25 26
26 std::unique_ptr<OnlineStream> CreateStream() const override { 27 std::unique_ptr<OnlineStream> CreateStream() const override {
27 FeatureExtractorConfig feat_config; 28 FeatureExtractorConfig feat_config;
28 - auto meta_data = model_.GetMetaData(); 29 + const auto &meta_data = model_.GetMetaData();
29 feat_config.sampling_rate = meta_data.sample_rate; 30 feat_config.sampling_rate = meta_data.sample_rate;
30 feat_config.normalize_samples = meta_data.normalize_samples; 31 feat_config.normalize_samples = meta_data.normalize_samples;
31 32
@@ -52,6 +53,17 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl @@ -52,6 +53,17 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl
52 53
53 int32_t feat_dim = features.size() / num_frames; 54 int32_t feat_dim = features.size() / num_frames;
54 55
  56 + const auto &meta_data = model_.GetMetaData();
  57 + if (!meta_data.feature_normalize_type.empty()) {
  58 + if (meta_data.feature_normalize_type == "global-mean") {
  59 + SubtractGlobalMean(features.data(), num_frames, feat_dim);
  60 + } else {
  61 + SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
  62 + meta_data.feature_normalize_type.c_str());
  63 + exit(-1);
  64 + }
  65 + }
  66 +
55 auto memory_info = 67 auto memory_info =
56 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 68 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
57 69
@@ -71,9 +83,19 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl @@ -71,9 +83,19 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl
71 } 83 }
72 84
73 private: 85 private:
74 - SpeakerEmbeddingExtractorWeSpeakerModel model_; 86 + void SubtractGlobalMean(float *p, int32_t num_frames,
  87 + int32_t feat_dim) const {
  88 + auto m = Eigen::Map<
  89 + Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
  90 + p, num_frames, feat_dim);
  91 +
  92 + m = m.rowwise() - m.colwise().mean();
  93 + }
  94 +
  95 + private:
  96 + SpeakerEmbeddingExtractorModel model_;
75 }; 97 };
76 98
77 } // namespace sherpa_onnx 99 } // namespace sherpa_onnx
78 100
79 -#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ 101 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
@@ -5,7 +5,7 @@ @@ -5,7 +5,7 @@
5 5
6 #include "sherpa-onnx/csrc/macros.h" 6 #include "sherpa-onnx/csrc/macros.h"
7 #include "sherpa-onnx/csrc/onnx-utils.h" 7 #include "sherpa-onnx/csrc/onnx-utils.h"
8 -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h" 8 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"
9 9
10 namespace sherpa_onnx { 10 namespace sherpa_onnx {
11 11
@@ -13,6 +13,7 @@ namespace { @@ -13,6 +13,7 @@ namespace {
13 13
14 enum class ModelType { 14 enum class ModelType {
15 kWeSpeaker, 15 kWeSpeaker,
  16 + k3dSpeaker,
16 kUnkown, 17 kUnkown,
17 }; 18 };
18 19
@@ -49,6 +50,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -49,6 +50,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
49 50
50 if (model_type.get() == std::string("wespeaker")) { 51 if (model_type.get() == std::string("wespeaker")) {
51 return ModelType::kWeSpeaker; 52 return ModelType::kWeSpeaker;
  53 + } else if (model_type.get() == std::string("3d-speaker")) {
  54 + return ModelType::k3dSpeaker;
52 } else { 55 } else {
53 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); 56 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
54 return ModelType::kUnkown; 57 return ModelType::kUnkown;
@@ -68,7 +71,9 @@ SpeakerEmbeddingExtractorImpl::Create( @@ -68,7 +71,9 @@ SpeakerEmbeddingExtractorImpl::Create(
68 71
69 switch (model_type) { 72 switch (model_type) {
70 case ModelType::kWeSpeaker: 73 case ModelType::kWeSpeaker:
71 - return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config); 74 + // fall through
  75 + case ModelType::k3dSpeaker:
  76 + return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
72 case ModelType::kUnkown: 77 case ModelType::kUnkown:
73 SHERPA_ONNX_LOGE( 78 SHERPA_ONNX_LOGE(
74 "Unknown model type in for speaker embedding extractor!"); 79 "Unknown model type in for speaker embedding extractor!");
1 -// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h 1 +// sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 -#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_  
5 -#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ 4 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
  5 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
6 6
7 #include <cstdint> 7 #include <cstdint>
8 #include <string> 8 #include <string>
9 9
10 namespace sherpa_onnx { 10 namespace sherpa_onnx {
11 11
12 -struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData { 12 +struct SpeakerEmbeddingExtractorModelMetaData {
13 int32_t output_dim = 0; 13 int32_t output_dim = 0;
14 int32_t sample_rate = 0; 14 int32_t sample_rate = 0;
15 - int32_t normalize_samples = 0; 15 +
  16 + // for wespeaker models, it is 0;
  17 + // for 3d-speaker models, it is 1
  18 + int32_t normalize_samples = 1;
  19 +
  20 + // Chinese, English, etc.
16 std::string language; 21 std::string language;
  22 +
  23 + // for 3d-speaker, it is global-mean
  24 + std::string feature_normalize_type;
17 }; 25 };
18 26
19 } // namespace sherpa_onnx 27 } // namespace sherpa_onnx
20 -#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ 28 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_
1 -// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc 1 +// sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation 3 +// Copyright (c) 2023-2024 Xiaomi Corporation
4 4
5 -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h" 5 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"
6 6
7 #include <string> 7 #include <string>
8 #include <utility> 8 #include <utility>
@@ -11,11 +11,11 @@ @@ -11,11 +11,11 @@
11 #include "sherpa-onnx/csrc/macros.h" 11 #include "sherpa-onnx/csrc/macros.h"
12 #include "sherpa-onnx/csrc/onnx-utils.h" 12 #include "sherpa-onnx/csrc/onnx-utils.h"
13 #include "sherpa-onnx/csrc/session.h" 13 #include "sherpa-onnx/csrc/session.h"
14 -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h" 14 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h"
15 15
16 namespace sherpa_onnx { 16 namespace sherpa_onnx {
17 17
18 -class SpeakerEmbeddingExtractorWeSpeakerModel::Impl { 18 +class SpeakerEmbeddingExtractorModel::Impl {
19 public: 19 public:
20 explicit Impl(const SpeakerEmbeddingExtractorConfig &config) 20 explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
21 : config_(config), 21 : config_(config),
@@ -37,7 +37,7 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl { @@ -37,7 +37,7 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
37 return std::move(outputs[0]); 37 return std::move(outputs[0]);
38 } 38 }
39 39
40 - const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const { 40 + const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const {
41 return meta_data_; 41 return meta_data_;
42 } 42 }
43 43
@@ -65,10 +65,13 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl { @@ -65,10 +65,13 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
65 "normalize_samples"); 65 "normalize_samples");
66 SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); 66 SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
67 67
  68 + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(
  69 + meta_data_.feature_normalize_type, "feature_normalize_type", "");
  70 +
68 std::string framework; 71 std::string framework;
69 SHERPA_ONNX_READ_META_DATA_STR(framework, "framework"); 72 SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
70 - if (framework != "wespeaker") {  
71 - SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s", 73 + if (framework != "wespeaker" && framework != "3d-speaker") {
  74 + SHERPA_ONNX_LOGE("Expect a wespeaker or a 3d-speaker model, given: %s",
72 framework.c_str()); 75 framework.c_str());
73 exit(-1); 76 exit(-1);
74 } 77 }
@@ -88,24 +91,21 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl { @@ -88,24 +91,21 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
88 std::vector<std::string> output_names_; 91 std::vector<std::string> output_names_;
89 std::vector<const char *> output_names_ptr_; 92 std::vector<const char *> output_names_ptr_;
90 93
91 - SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_; 94 + SpeakerEmbeddingExtractorModelMetaData meta_data_;
92 }; 95 };
93 96
94 -SpeakerEmbeddingExtractorWeSpeakerModel::  
95 - SpeakerEmbeddingExtractorWeSpeakerModel(  
96 - const SpeakerEmbeddingExtractorConfig &config) 97 +SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel(
  98 + const SpeakerEmbeddingExtractorConfig &config)
97 : impl_(std::make_unique<Impl>(config)) {} 99 : impl_(std::make_unique<Impl>(config)) {}
98 100
99 -SpeakerEmbeddingExtractorWeSpeakerModel::  
100 - ~SpeakerEmbeddingExtractorWeSpeakerModel() = default; 101 +SpeakerEmbeddingExtractorModel::~SpeakerEmbeddingExtractorModel() = default;
101 102
102 -const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &  
103 -SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const { 103 +const SpeakerEmbeddingExtractorModelMetaData &
  104 +SpeakerEmbeddingExtractorModel::GetMetaData() const {
104 return impl_->GetMetaData(); 105 return impl_->GetMetaData();
105 } 106 }
106 107
107 -Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(  
108 - Ort::Value x) const { 108 +Ort::Value SpeakerEmbeddingExtractorModel::Compute(Ort::Value x) const {
109 return impl_->Compute(std::move(x)); 109 return impl_->Compute(std::move(x));
110 } 110 }
111 111
1 -// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h 1 +// sherpa-onnx/csrc/speaker-embedding-extractor-model.h
2 // 2 //
3 -// Copyright (c) 2023 Xiaomi Corporation  
4 -#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_  
5 -#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ 3 +// Copyright (c) 2023-2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
6 6
7 #include <memory> 7 #include <memory>
8 8
9 #include "onnxruntime_cxx_api.h" // NOLINT 9 #include "onnxruntime_cxx_api.h" // NOLINT
10 -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h" 10 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h"
11 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" 11 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
12 12
13 namespace sherpa_onnx { 13 namespace sherpa_onnx {
14 14
15 -class SpeakerEmbeddingExtractorWeSpeakerModel { 15 +class SpeakerEmbeddingExtractorModel {
16 public: 16 public:
17 - explicit SpeakerEmbeddingExtractorWeSpeakerModel( 17 + explicit SpeakerEmbeddingExtractorModel(
18 const SpeakerEmbeddingExtractorConfig &config); 18 const SpeakerEmbeddingExtractorConfig &config);
19 19
20 - ~SpeakerEmbeddingExtractorWeSpeakerModel(); 20 + ~SpeakerEmbeddingExtractorModel();
21 21
22 - const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const; 22 + const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const;
23 23
24 /** 24 /**
25 * @param x A float32 tensor of shape (N, T, C) 25 * @param x A float32 tensor of shape (N, T, C)
@@ -34,4 +34,4 @@ class SpeakerEmbeddingExtractorWeSpeakerModel { @@ -34,4 +34,4 @@ class SpeakerEmbeddingExtractorWeSpeakerModel {
34 34
35 } // namespace sherpa_onnx 35 } // namespace sherpa_onnx
36 36
37 -#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ 37 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
@@ -23,6 +23,7 @@ set(py_test_files @@ -23,6 +23,7 @@ set(py_test_files
23 test_offline_recognizer.py 23 test_offline_recognizer.py
24 test_online_recognizer.py 24 test_online_recognizer.py
25 test_online_transducer_model_config.py 25 test_online_transducer_model_config.py
  26 + test_speaker_recognition.py
26 test_text2token.py 27 test_text2token.py
27 ) 28 )
28 29
  1 +# sherpa-onnx/python/tests/test_speaker_recognition.py
  2 +#
  3 +# Copyright (c) 2024 Xiaomi Corporation
  4 +#
  5 +# To run this single test, use
  6 +#
  7 +# ctest --verbose -R test_speaker_recognition_py
  8 +
  9 +import unittest
  10 +import wave
  11 +from collections import defaultdict
  12 +from pathlib import Path
  13 +from typing import Tuple
  14 +
  15 +import numpy as np
  16 +import sherpa_onnx
  17 +
  18 +d = "/tmp/sr-models"
  19 +
  20 +
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a single-channel, 16-bit PCM wave file.

    Args:
      wave_filename:
        Path to a wave file. It must have exactly one channel and 16-bit
        samples. Its sample rate does not need to be 16kHz.
    Returns:
      A tuple of:
        - a 1-D np.float32 array of samples normalized to [-1, 1]
        - the sample rate of the wave file
    """
    with wave.open(wave_filename) as wf:
        assert wf.getnchannels() == 1, wf.getnchannels()
        assert wf.getsampwidth() == 2, wf.getsampwidth()  # width is in bytes
        raw = wf.readframes(wf.getnframes())
        # int16 PCM -> float32 in [-1, 1]
        samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768
        return samples, wf.getframerate()
  44 +
  45 +
def load_speaker_embedding_model(model_filename):
    """Create a SpeakerEmbeddingExtractor for the given ONNX model file.

    Args:
      model_filename:
        Path to the speaker-embedding ONNX model.
    Returns:
      A sherpa_onnx.SpeakerEmbeddingExtractor running on CPU with 1 thread.
    Raises:
      ValueError: if the config fails validation.
    """
    config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=model_filename,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    if not config.validate():
        raise ValueError(f"Invalid config. {config}")
    return sherpa_onnx.SpeakerEmbeddingExtractor(config)
  57 +
  58 +
def test_wespeaker_model(model_filename: str):
    """Enroll and verify two speakers using a wespeaker model.

    Registers one averaged embedding per speaker from the enrollment waves,
    then verifies and searches with separate test waves.

    Args:
      model_filename:
        Path to a wespeaker ONNX model.
    Raises:
      RuntimeError: if registering or verifying a speaker fails.
    """
    model_filename = str(model_filename)
    if "en" in model_filename:
        # NOTE(review): English models are skipped — presumably the test
        # waves below are Chinese speech; confirm before enabling them.
        print(f"skip {model_filename}")
        return
    extractor = load_speaker_embedding_model(model_filename)
    filenames = [
        "leijun-sr-1",
        "leijun-sr-2",
        "fangjun-sr-1",
        "fangjun-sr-2",
        "fangjun-sr-3",
    ]
    tmp = defaultdict(list)
    for filename in filenames:
        print(filename)
        # The speaker name is the part before the first '-', e.g. "leijun".
        name = filename.split("-", maxsplit=1)[0]
        # Fix: interpolate the current filename into the path; otherwise the
        # same wave would be read for every enrollment entry.
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/enroll/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        tmp[name].append(embedding)

    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, embedding_list in tmp.items():
        print(name, len(embedding_list))
        # Register the mean of the speaker's enrollment embeddings.
        embedding = sum(embedding_list) / len(embedding_list)
        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "leijun-test-sr-1",
        "leijun-test-sr-2",
        "leijun-test-sr-3",
        "fangjun-test-sr-1",
        "fangjun-test-sr-2",
    ]
    for filename in filenames:
        name = filename.split("-", maxsplit=1)[0]
        # Fix: same missing-interpolation bug as above.
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/test/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav")

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
  115 +
  116 +
def test_3dspeaker_model(model_filename: str):
    """Enroll and verify speakers using a 3d-speaker model.

    Registers one embedding per (speaker, language) from the "a" waves,
    then verifies the matching "b" waves against them.

    Args:
      model_filename:
        Path to a 3d-speaker ONNX model.
    Raises:
      RuntimeError: if registering or verifying a speaker fails.
    """
    extractor = load_speaker_embedding_model(str(model_filename))
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)

    filenames = [
        "speaker1_a_cn_16k",
        "speaker2_a_cn_16k",
        "speaker1_a_en_16k",
        "speaker2_a_en_16k",
    ]
    for filename in filenames:
        # Registered name drops the trailing sample-rate suffix, e.g.
        # "speaker1_a_cn_16k" -> "speaker1_a_cn".
        name = filename.rsplit("_", maxsplit=1)[0]
        # Fix: interpolate the current filename into the path; otherwise the
        # same wave would be read for every entry.
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)

        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "speaker1_b_cn_16k",
        "speaker1_b_en_16k",
    ]
    for filename in filenames:
        print(filename)
        # The "b" waves must match against the registered "a" names.
        name = filename.rsplit("_", maxsplit=1)[0]
        name = name.replace("b_cn", "a_cn")
        name = name.replace("b_en", "a_en")
        print(name)

        # Fix: same missing-interpolation bug as above.
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(
                f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}"
            )

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
  171 +
  172 +
class TestSpeakerRecognition(unittest.TestCase):
    """Runs the speaker-recognition checks against every downloaded model.

    Each test is skipped silently when the corresponding model directory
    under /tmp/sr-models does not exist (e.g. outside CI).
    """

    def test_wespeaker_models(self):
        # Exercise every wespeaker ONNX model found in the download dir.
        model_dir = Path(d) / "wespeaker"
        if not model_dir.is_dir():
            print(f"{model_dir} does not exist - skip it")
            return
        for filename in model_dir.glob("*.onnx"):
            print(filename)
            test_wespeaker_model(filename)

    def test_3dspeaker_models(self):
        # Fix: renamed from test_3dpeaker_models (typo). unittest discovery
        # is unaffected since the test_ prefix is kept.
        model_dir = Path(d) / "3dspeaker"
        if not model_dir.is_dir():
            print(f"{model_dir} does not exist - skip it")
            return
        for filename in model_dir.glob("*.onnx"):
            print(filename)
            test_3dspeaker_model(filename)


if __name__ == "__main__":
    unittest.main()