Fangjun Kuang
Committed by GitHub

Add runtime support for wespeaker models (#516)

  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This script shows how to use Python APIs for speaker identification.
  5 +
  6 +Usage:
  7 +
  8 +(1) Prepare a text file containing speaker related files.
  9 +
  10 +Each line in the text file contains two columns. The first column is the
  11 +speaker name, while the second column contains the wave file of the speaker.
  12 +
  13 +If the text file contains multiple wave files for the same speaker, then the
  14 +embeddings of these files are averaged.
  15 +
  16 +An example text file is given below:
  17 +
  18 + foo /path/to/a.wav
  19 + bar /path/to/b.wav
  20 + foo /path/to/c.wav
  21 + foobar /path/to/d.wav
  22 +
  23 +Each wave file should contain only a single channel; the sample format
  24 +should be int16_t; the sample rate can be arbitrary.
  25 +
  26 +(2) Download a model for computing speaker embeddings
  27 +
  28 +Please visit
  29 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
  30 +to download a model. An example is given below:
  31 +
  32 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/zh_cnceleb_resnet34.onnx
  33 +
  34 +Note that `zh` means Chinese, while `en` means English.
  35 +
  36 +(3) Run this script
  37 +
  38 +Assume the filename of the text file is speaker.txt.
  39 +
  40 +python3 ./python-api-examples/speaker-identification.py \
  41 + --speaker-file ./speaker.txt \
  42 + --model ./zh_cnceleb_resnet34.onnx
  43 +"""
import argparse
import queue
import sys
import threading
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import sherpa_onnx
import torchaudio

try:
    import sounddevice as sd
except ImportError:
    print("Please install sounddevice first. You can use")
    print()
    print(" pip install sounddevice")
    print()
    print("to install it")
    sys.exit(-1)
  64 +
  65 +
def get_args():
    """Parse the command-line arguments of this script.

    Returns:
      An ``argparse.Namespace`` with the attributes ``speaker_file``,
      ``model``, ``threshold``, ``num_threads``, ``debug`` and ``provider``.
    """

    def str2bool(value):
        # ``type=bool`` would turn every non-empty string, including
        # "False", into True. Parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"Expected a boolean value, given: {value}"
        )

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--speaker-file",
        type=str,
        required=True,
        help="""Path to the speaker file. Read the help doc at the beginning of this
        file for the format.""",
    )

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the model file.",
    )

    parser.add_argument("--threshold", type=float, default=0.6)

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="True to show debug messages",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    return parser.parse_args()
  110 +
  111 +
def load_speaker_embedding_model(args):
    """Build a speaker embedding extractor from the parsed arguments.

    Args:
      args: Parsed arguments; ``model``, ``num_threads``, ``debug`` and
        ``provider`` are read.

    Returns:
      A ``sherpa_onnx.SpeakerEmbeddingExtractor``.

    Raises:
      ValueError: If the extractor config fails validation.
    """
    extractor_config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=args.model,
        num_threads=args.num_threads,
        debug=args.debug,
        provider=args.provider,
    )

    if extractor_config.validate():
        return sherpa_onnx.SpeakerEmbeddingExtractor(extractor_config)

    raise ValueError(f"Invalid config. {extractor_config}")
  123 +
  124 +
def load_speaker_file(args) -> Dict[str, List[str]]:
    """Read the speaker file and group wave filenames by speaker name.

    Each non-empty line has two columns: the speaker name, then the wave
    file of that speaker. Only the first run of whitespace separates the
    columns, so a wave path may itself contain spaces.

    Args:
      args: Parsed arguments; only ``speaker_file`` is read.

    Returns:
      A dict mapping each speaker name to its wave filenames, in file order.

    Raises:
      ValueError: If the file does not exist or a line lacks two columns.
    """
    if not Path(args.speaker_file).is_file():
        raise ValueError(f"--speaker-file {args.speaker_file} does not exist")

    ans = defaultdict(list)
    with open(args.speaker_file) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            # maxsplit=1 keeps whitespace inside the path intact
            fields = line.split(maxsplit=1)
            if len(fields) != 2:
                raise ValueError(f"Invalid line: {line}. Fields: {fields}")

            speaker_name, filename = fields
            ans[speaker_name].append(filename)
    return ans
  143 +
  144 +
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Load a wave file with torchaudio.

    Args:
      filename: Path to the wave file.

    Returns:
      A tuple ``(samples, sample_rate)``. ``samples`` is the first row of
      the tensor returned by ``torchaudio.load`` (presumably the first
      channel), converted to a numpy array.
    """
    waveform, sample_rate = torchaudio.load(filename)
    first_row = waveform[0].contiguous().numpy()
    return first_row, sample_rate
  148 +
  149 +
def compute_speaker_embedding(
    filenames: List[str],
    extractor: sherpa_onnx.SpeakerEmbeddingExtractor,
) -> np.ndarray:
    """Return the mean embedding over all the given wave files.

    Args:
      filenames: Wave files that belong to one speaker. Must not be empty.
      extractor: The extractor used to create streams and compute embeddings.

    Returns:
      The element-wise average of the per-file embeddings.
    """
    assert len(filenames) > 0, "filenames is empty"

    total = None
    for wave_filename in filenames:
        print(f"processing {wave_filename}")
        samples, sample_rate = load_audio(wave_filename)

        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
        stream.input_finished()

        assert extractor.is_ready(stream)
        current = np.array(extractor.compute(stream))
        total = current if total is None else total + current

    return total / len(filenames)
  173 +
  174 +
  175 +g_buffer = queue.Queue()
  176 +g_stop = False
  177 +g_sample_rate = 16000
  178 +g_read_mic_thread = None
  179 +
  180 +
def read_mic():
    """Read microphone audio into ``g_buffer`` until ``g_stop`` becomes true.

    Intended to run in a background thread (``main`` starts it that way).
    Audio is read in blocks of 0.1 second at ``g_sample_rate``.
    """
    print("Please speak!")
    block_size = int(0.1 * g_sample_rate)  # 0.1 second = 100 ms
    with sd.InputStream(
        channels=1, dtype="float32", samplerate=g_sample_rate
    ) as mic:
        while True:
            if g_stop:
                break
            block, _ = mic.read(block_size)  # a blocking read
            g_buffer.put(block)
  188 +
  189 +
def main():
    """Register the speakers and identify speakers from the microphone.

    The speakers listed in ``--speaker-file`` are registered with a
    ``sherpa_onnx.SpeakerEmbeddingManager``. The function then loops:
    record between two presses of enter, compute the embedding of the
    recording, and print the best-matching registered name, or "unknown".
    Entering ``q`` or ``quit`` at the start prompt leaves the loop.

    Raises:
      RuntimeError: If a speaker cannot be added to the manager.
    """
    args = get_args()
    print(args)
    extractor = load_speaker_embedding_model(args)
    speaker_file = load_speaker_file(args)

    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, filename_list in speaker_file.items():
        embedding = compute_speaker_embedding(
            filenames=filename_list,
            extractor=extractor,
        )
        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    devices = sd.query_devices()
    if len(devices) == 0:
        print("No microphone devices found")
        sys.exit(0)

    print(devices)
    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    global g_stop
    global g_read_mic_thread
    while True:
        key = input("Press enter to start recording")
        if key.lower() in ("q", "quit"):
            g_stop = True
            break

        g_stop = False
        # No reader thread is alive here, so clearing the underlying deque
        # cannot race with a producer.
        g_buffer.queue.clear()
        g_read_mic_thread = threading.Thread(target=read_mic)
        g_read_mic_thread.start()
        input("Press enter to stop recording")
        g_stop = True
        g_read_mic_thread.join()
        print("Compute embedding")
        stream = extractor.create_stream()
        while not g_buffer.empty():
            samples = g_buffer.get()
            stream.accept_waveform(sample_rate=g_sample_rate, waveform=samples)
        stream.input_finished()

        # If nothing usable was recorded there are no frames to process.
        # Skip the search rather than pass an empty embedding to the manager.
        if not extractor.is_ready(stream):
            print("No audio was captured. Please try again.")
            continue

        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        name = manager.search(embedding, threshold=args.threshold)
        if not name:
            name = "unknown"
        print(f"Predicted name: {name}")
  243 +
  244 +
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
        g_stop = True
        # The thread is still None if Ctrl + C arrives before the first
        # recording starts.
        if g_read_mic_thread is not None and g_read_mic_thread.is_alive():
            g_read_mic_thread.join()
@@ -96,6 +96,14 @@ set(sources @@ -96,6 +96,14 @@ set(sources
96 wave-reader.cc 96 wave-reader.cc
97 ) 97 )
98 98
  99 +# speaker embedding extractor
  100 +list(APPEND sources
  101 + speaker-embedding-extractor-impl.cc
  102 + speaker-embedding-extractor-wespeaker-model.cc
  103 + speaker-embedding-extractor.cc
  104 + speaker-embedding-manager.cc
  105 +)
  106 +
99 list(APPEND sources 107 list(APPEND sources
100 lexicon.cc 108 lexicon.cc
101 offline-tts-impl.cc 109 offline-tts-impl.cc
@@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS) @@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS)
387 utfcpp-test.cc 395 utfcpp-test.cc
388 ) 396 )
389 397
  398 + list(APPEND sherpa_onnx_test_srcs
  399 + speaker-embedding-manager-test.cc
  400 + )
  401 +
390 function(sherpa_onnx_add_test source) 402 function(sherpa_onnx_add_test source)
391 get_filename_component(name ${source} NAME_WE) 403 get_filename_component(name ${source} NAME_WE)
392 set(target_name ${name}) 404 set(target_name ${name})
@@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) { @@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) {
64 auto stop = std::chrono::high_resolution_clock::now(); 64 auto stop = std::chrono::high_resolution_clock::now();
65 auto duration = 65 auto duration =
66 std::chrono::duration_cast<std::chrono::microseconds>(stop - start); 66 std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
67 - SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,  
68 - duration.count()); 67 + SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num,
  68 + static_cast<int32_t>(duration.count()));
69 } 69 }
70 } 70 }
71 71
@@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { @@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
91 return GetSessionOptionsImpl(config.num_threads, config.provider); 91 return GetSessionOptionsImpl(config.num_threads, config.provider);
92 } 92 }
93 93
  94 +Ort::SessionOptions GetSessionOptions(
  95 + const SpeakerEmbeddingExtractorConfig &config) {
  96 + return GetSessionOptionsImpl(config.num_threads, config.provider);
  97 +}
  98 +
94 } // namespace sherpa_onnx 99 } // namespace sherpa_onnx
@@ -11,6 +11,7 @@ @@ -11,6 +11,7 @@
11 #include "sherpa-onnx/csrc/offline-tts-model-config.h" 11 #include "sherpa-onnx/csrc/offline-tts-model-config.h"
12 #include "sherpa-onnx/csrc/online-lm-config.h" 12 #include "sherpa-onnx/csrc/online-lm-config.h"
13 #include "sherpa-onnx/csrc/online-model-config.h" 13 #include "sherpa-onnx/csrc/online-model-config.h"
  14 +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
14 #include "sherpa-onnx/csrc/vad-model-config.h" 15 #include "sherpa-onnx/csrc/vad-model-config.h"
15 16
16 namespace sherpa_onnx { 17 namespace sherpa_onnx {
@@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); @@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
26 Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); 27 Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
27 28
28 Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); 29 Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
  30 +
  31 +Ort::SessionOptions GetSessionOptions(
  32 + const SpeakerEmbeddingExtractorConfig &config);
29 } // namespace sherpa_onnx 33 } // namespace sherpa_onnx
30 34
31 #endif // SHERPA_ONNX_CSRC_SESSION_H_ 35 #endif // SHERPA_ONNX_CSRC_SESSION_H_
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
  5 +
  6 +#include "sherpa-onnx/csrc/macros.h"
  7 +#include "sherpa-onnx/csrc/onnx-utils.h"
  8 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +namespace {
  13 +
  // Kind of speaker embedding model, as returned by GetModelType() below.
  // kUnkown is returned when the metadata entry is missing or has an
  // unsupported value.
  14 +enum class ModelType {
  15 + kWeSpeaker,
  16 + kUnkown,
  17 +};
  18 +
  19 +} // namespace
  20 +
  21 +static ModelType GetModelType(char *model_data, size_t model_data_length,
  22 + bool debug) {
  23 + Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  24 + Ort::SessionOptions sess_opts;
  25 +
  26 + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
  27 + sess_opts);
  28 +
  29 + Ort::ModelMetadata meta_data = sess->GetModelMetadata();
  30 + if (debug) {
  31 + std::ostringstream os;
  32 + PrintModelMetadata(os, meta_data);
  33 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  34 + }
  35 +
  36 + Ort::AllocatorWithDefaultOptions allocator;
  37 + auto model_type =
  38 + meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
  39 + if (!model_type) {
  40 + SHERPA_ONNX_LOGE(
  41 + "No model_type in the metadata!\n"
  42 + "Please make sure you have added metadata to the model.\n\n"
  43 + "For instance, you can use\n"
  44 + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
  45 + "add_meta_data.py"
  46 + "to add metadata to models from WeSpeaker\n");
  47 + return ModelType::kUnkown;
  48 + }
  49 +
  50 + if (model_type.get() == std::string("wespeaker")) {
  51 + return ModelType::kWeSpeaker;
  52 + } else {
  53 + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
  54 + return ModelType::kUnkown;
  55 + }
  56 +}
  57 +
  58 +std::unique_ptr<SpeakerEmbeddingExtractorImpl>
  59 +SpeakerEmbeddingExtractorImpl::Create(
  60 + const SpeakerEmbeddingExtractorConfig &config) {
  61 + ModelType model_type = ModelType::kUnkown;
  62 +
  63 + {
  64 + auto buffer = ReadFile(config.model);
  65 +
  66 + model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  67 + }
  68 +
  69 + switch (model_type) {
  70 + case ModelType::kWeSpeaker:
  71 + return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config);
  72 + case ModelType::kUnkown:
  73 + SHERPA_ONNX_LOGE(
  74 + "Unknown model type in for speaker embedding extractor!");
  75 + return nullptr;
  76 + }
  77 +
  78 + // unreachable code
  79 + return nullptr;
  80 +}
  81 +
  82 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  // Abstract interface implemented by each supported kind of speaker
  // embedding model.
  16 +class SpeakerEmbeddingExtractorImpl {
  17 + public:
  18 + virtual ~SpeakerEmbeddingExtractorImpl() = default;
  19 +
  // Factory. Declared here and defined in
  // speaker-embedding-extractor-impl.cc.
  20 + static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
  21 + const SpeakerEmbeddingExtractorConfig &config);
  22 +
  // Dimension of the embedding.
  23 + virtual int32_t Dim() const = 0;
  24 +
  // Create a stream that accepts audio samples.
  25 + virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
  26 +
  // Whether `s` has frames from which an embedding can be computed.
  27 + virtual bool IsReady(OnlineStream *s) const = 0;
  28 +
  // Compute the embedding from the frames available in `s`.
  29 + virtual std::vector<float> Compute(OnlineStream *s) const = 0;
  30 +};
  31 +
  32 +} // namespace sherpa_onnx
  33 +
  34 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
  7 +#include <algorithm>
  8 +#include <memory>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
  13 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class SpeakerEmbeddingExtractorWeSpeakerImpl
  18 + : public SpeakerEmbeddingExtractorImpl {
  19 + public:
  20 + explicit SpeakerEmbeddingExtractorWeSpeakerImpl(
  21 + const SpeakerEmbeddingExtractorConfig &config)
  22 + : model_(config) {}
  23 +
  24 + int32_t Dim() const override { return model_.GetMetaData().output_dim; }
  25 +
  26 + std::unique_ptr<OnlineStream> CreateStream() const override {
  27 + FeatureExtractorConfig feat_config;
  28 + auto meta_data = model_.GetMetaData();
  29 + feat_config.sampling_rate = meta_data.sample_rate;
  30 + feat_config.normalize_samples = meta_data.normalize_features;
  31 +
  32 + return std::make_unique<OnlineStream>(feat_config);
  33 + }
  34 +
  35 + bool IsReady(OnlineStream *s) const override {
  36 + return s->GetNumProcessedFrames() < s->NumFramesReady();
  37 + }
  38 +
  39 + std::vector<float> Compute(OnlineStream *s) const override {
  40 + int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
  41 + if (num_frames <= 0) {
  42 + SHERPA_ONNX_LOGE(
  43 + "Please make sure IsReady(s) returns true. num_frames: %d",
  44 + num_frames);
  45 + return {};
  46 + }
  47 +
  48 + std::vector<float> features =
  49 + s->GetFrames(s->GetNumProcessedFrames(), num_frames);
  50 +
  51 + s->GetNumProcessedFrames() += num_frames;
  52 +
  53 + int32_t feat_dim = features.size() / num_frames;
  54 +
  55 + auto memory_info =
  56 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  57 +
  58 + std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
  59 + Ort::Value x =
  60 + Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
  61 + x_shape.data(), x_shape.size());
  62 + Ort::Value embedding = model_.Compute(std::move(x));
  63 + std::vector<int64_t> embedding_shape =
  64 + embedding.GetTensorTypeAndShapeInfo().GetShape();
  65 +
  66 + std::vector<float> ans(embedding_shape[1]);
  67 + std::copy(embedding.GetTensorData<float>(),
  68 + embedding.GetTensorData<float>() + ans.size(), ans.begin());
  69 +
  70 + return ans;
  71 + }
  72 +
  73 + private:
  74 + SpeakerEmbeddingExtractorWeSpeakerModel model_;
  75 +};
  76 +
  77 +} // namespace sherpa_onnx
  78 +
  79 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
  5 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
  6 +
  7 +#include <cstdint>
  8 +#include <string>
  9 +
  10 +namespace sherpa_onnx {
  11 +
  // Metadata describing a WeSpeaker model. Numeric fields default to 0 and
  // `language` defaults to an empty string.
  12 +struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
  // Dimension of the embedding produced by the model.
  13 + int32_t output_dim = 0;
  // Sample rate of the audio; presumably in Hz -- confirm against the export
  // script.
  14 + int32_t sample_rate = 0;
  // Integer flag; NOTE(review): its exact meaning is not visible here.
  15 + int32_t normalize_features = 0;
  16 + std::string language;
  17 +};
  18 +
  19 +} // namespace sherpa_onnx
  20 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
  6 +
  7 +#include <string>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +#include "sherpa-onnx/csrc/onnx-utils.h"
  13 +#include "sherpa-onnx/csrc/session.h"
  14 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
  // Private implementation of SpeakerEmbeddingExtractorWeSpeakerModel. It
  // holds the ONNX Runtime session, the model's input/output names, and the
  // metadata read from the model.
  18 +class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
  19 + public:
  // Read the model file named by config.model and initialize the session
  // from it. The file buffer lives only for the duration of Init().
  20 + explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
  21 + : config_(config),
  22 + env_(ORT_LOGGING_LEVEL_ERROR),
  23 + sess_opts_(GetSessionOptions(config)),
  24 + allocator_{} {
  25 + {
  26 + auto buf = ReadFile(config.model);
  27 + Init(buf.data(), buf.size());
  28 + }
  29 + }
  30 +
  // Run the session with `x` as its only input and return the first output.
  31 + Ort::Value Compute(Ort::Value x) const {
  32 + std::array<Ort::Value, 1> inputs = {std::move(x)};
  33 +
  34 + auto outputs =
  35 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  36 + output_names_ptr_.data(), output_names_ptr_.size());
  37 + return std::move(outputs[0]);
  38 + }
  39 +
  40 + const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const {
  41 + return meta_data_;
  42 + }
  43 +
  44 + private:
  // Create the session from the in-memory model, collect input/output
  // names, and read the metadata entries. Calls exit(-1) if the "framework"
  // entry is not "wespeaker".
  45 + void Init(void *model_data, size_t model_data_length) {
  46 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  47 + sess_opts_);
  48 +
  49 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  50 +
  51 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  52 +
  53 + // get meta data
  54 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  55 + if (config_.debug) {
  56 + std::ostringstream os;
  57 + PrintModelMetadata(os, meta_data);
  58 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  59 + }
  60 +
  // NOTE(review): the macros below appear to rely on the local names
  // `meta_data` and `allocator`; do not rename them without checking the
  // macro definitions.
  61 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  62 + SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
  63 + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
  64 + SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features,
  65 + "normalize_features");
  66 + SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
  67 +
  68 + std::string framework;
  69 + SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
  70 + if (framework != "wespeaker") {
  71 + SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s",
  72 + framework.c_str());
  73 + exit(-1);
  74 + }
  75 + }
  76 +
  77 + private:
  78 + SpeakerEmbeddingExtractorConfig config_;
  79 + Ort::Env env_;
  80 + Ort::SessionOptions sess_opts_;
  // NOTE(review): allocator_ is not referenced by any code visible in this
  // class; Init() uses a local allocator instead.
  81 + Ort::AllocatorWithDefaultOptions allocator_;
  82 +
  83 + std::unique_ptr<Ort::Session> sess_;
  84 +
  // The *_ptr_ vectors hold the const char * form of the names that is
  // passed to Run().
  85 + std::vector<std::string> input_names_;
  86 + std::vector<const char *> input_names_ptr_;
  87 +
  88 + std::vector<std::string> output_names_;
  89 + std::vector<const char *> output_names_ptr_;
  90 +
  91 + SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_;
  92 +};
  93 +
  // Public members of SpeakerEmbeddingExtractorWeSpeakerModel. Each one
  // forwards to the Impl object created in the constructor.
  94 +SpeakerEmbeddingExtractorWeSpeakerModel::
  95 + SpeakerEmbeddingExtractorWeSpeakerModel(
  96 + const SpeakerEmbeddingExtractorConfig &config)
  97 + : impl_(std::make_unique<Impl>(config)) {}
  98 +
  // Defaulted here, in the .cc file where Impl is defined.
  99 +SpeakerEmbeddingExtractorWeSpeakerModel::
  100 + ~SpeakerEmbeddingExtractorWeSpeakerModel() = default;
  101 +
  102 +const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &
  103 +SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const {
  104 + return impl_->GetMetaData();
  105 +}
  106 +
  107 +Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(
  108 + Ort::Value x) const {
  109 + return impl_->Compute(std::move(x));
  110 +}
  111 +
  112 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
  6 +
  7 +#include <memory>
  8 +
  9 +#include "onnxruntime_cxx_api.h" // NOLINT
  10 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
  11 +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  // Wrapper around a WeSpeaker ONNX model. The implementation is hidden
  // behind a pointer to the forward-declared class Impl.
  15 +class SpeakerEmbeddingExtractorWeSpeakerModel {
  16 + public:
  17 + explicit SpeakerEmbeddingExtractorWeSpeakerModel(
  18 + const SpeakerEmbeddingExtractorConfig &config);
  19 +
  20 + ~SpeakerEmbeddingExtractorWeSpeakerModel();
  21 +
  // Metadata of the loaded model.
  22 + const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const;
  23 +
  24 + /**
  25 + * @param x A float32 tensor of shape (N, T, C)
  26 + * @return A float32 tensor of shape (N, C)
  27 + */
  28 + Ort::Value Compute(Ort::Value x) const;
  29 +
  30 + private:
  31 + class Impl;
  32 + std::unique_ptr<Impl> impl_;
  33 +};
  34 +
  35 +} // namespace sherpa_onnx
  36 +
  37 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
  6 +
  7 +#include <vector>
  8 +
  9 +#include "sherpa-onnx/csrc/file-utils.h"
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  // Register the options "model", "num-threads", "debug" and "provider" with
  // `po`, binding each one to the matching member of this config.
  15 +void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
  16 + po->Register("model", &model, "Path to the speaker embedding model.");
  17 + po->Register("num-threads", &num_threads,
  18 + "Number of threads to run the neural network");
  19 +
  20 + po->Register("debug", &debug,
  21 + "true to print model information while loading it.");
  22 +
  23 + po->Register("provider", &provider,
  24 + "Specify a provider to use: cpu, cuda, coreml");
  25 +}
  26 +
  27 +bool SpeakerEmbeddingExtractorConfig::Validate() const {
  28 + if (model.empty()) {
  29 + SHERPA_ONNX_LOGE("Please provide --speaker-embedding-model");
  30 + return false;
  31 + }
  32 +
  33 + if (!FileExists(model)) {
  34 + SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist",
  35 + model.c_str());
  36 + return false;
  37 + }
  38 +
  39 + return true;
  40 +}
  41 +
  42 +std::string SpeakerEmbeddingExtractorConfig::ToString() const {
  43 + std::ostringstream os;
  44 +
  45 + os << "SpeakerEmbeddingExtractorConfig(";
  46 + os << "model=\"" << model << "\", ";
  47 + os << "num_threads=" << num_threads << ", ";
  48 + os << "debug=" << (debug ? "True" : "False") << ", ";
  49 + os << "provider=\"" << provider << "\")";
  50 +
  51 + return os.str();
  52 +}
  53 +
  // Members of SpeakerEmbeddingExtractor. Each one forwards to the
  // implementation created by SpeakerEmbeddingExtractorImpl::Create().
  //
  // NOTE(review): Create() returns nullptr for an unknown model type; none
  // of the members below check impl_ before dereferencing it.
  54 +SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor(
  55 + const SpeakerEmbeddingExtractorConfig &config)
  56 + : impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {}
  57 +
  // Defaulted here rather than in the header, which only forward-declares
  // SpeakerEmbeddingExtractorImpl.
  58 +SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default;
  59 +
  60 +int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); }
  61 +
  62 +std::unique_ptr<OnlineStream> SpeakerEmbeddingExtractor::CreateStream() const {
  63 + return impl_->CreateStream();
  64 +}
  65 +
  66 +bool SpeakerEmbeddingExtractor::IsReady(OnlineStream *s) const {
  67 + return impl_->IsReady(s);
  68 +}
  69 +
  70 +std::vector<float> SpeakerEmbeddingExtractor::Compute(OnlineStream *s) const {
  71 + return impl_->Compute(s);
  72 +}
  73 +
  74 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/speaker-embedding-extractor.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
  6 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/online-stream.h"
  13 +#include "sherpa-onnx/csrc/parse-options.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  // Configuration of a speaker embedding extractor.
  17 +struct SpeakerEmbeddingExtractorConfig {
  // Path to the speaker embedding model file.
  18 + std::string model;
  // Number of threads used to run the neural network.
  19 + int32_t num_threads = 1;
  // True to print model information while loading the model.
  20 + bool debug = false;
  // Execution provider; the option help text lists cpu, cuda and coreml.
  21 + std::string provider = "cpu";
  22 +
  23 + SpeakerEmbeddingExtractorConfig() = default;
  24 + SpeakerEmbeddingExtractorConfig(const std::string &model, int32_t num_threads,
  25 + bool debug, const std::string &provider)
  26 + : model(model),
  27 + num_threads(num_threads),
  28 + debug(debug),
  29 + provider(provider) {}
  30 +
  // Register the command-line options of this config with `po`.
  31 + void Register(ParseOptions *po);
  // Return true if the config is valid.
  32 + bool Validate() const;
  // Return a human-readable description of this config.
  33 + std::string ToString() const;
  34 +};
  35 +
  36 +class SpeakerEmbeddingExtractorImpl;
  37 +
  // Computes speaker embeddings from audio fed through an OnlineStream. The
  // actual work is delegated to a SpeakerEmbeddingExtractorImpl.
  38 +class SpeakerEmbeddingExtractor {
  39 + public:
  40 + explicit SpeakerEmbeddingExtractor(
  41 + const SpeakerEmbeddingExtractorConfig &config);
  42 +
  43 + ~SpeakerEmbeddingExtractor();
  44 +
  45 + // Return the dimension of the embedding
  46 + int32_t Dim() const;
  47 +
  48 + // Create a stream to accept audio samples and compute features
  49 + std::unique_ptr<OnlineStream> CreateStream() const;
  50 +
  51 + // Return true if there are feature frames in OnlineStream that
  52 + // can be used to compute embeddings.
  53 + bool IsReady(OnlineStream *s) const;
  54 +
  55 + // Compute the speaker embedding from the available unprocessed features
  56 + // of the given stream
  57 + //
  58 + // You have to ensure IsReady(s) returns true before you call this method.
  59 + std::vector<float> Compute(OnlineStream *s) const;
  60 +
  61 + private:
  62 + std::unique_ptr<SpeakerEmbeddingExtractorImpl> impl_;
  63 +};
  64 +
  65 +} // namespace sherpa_onnx
  66 +
  67 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
  1 +// sherpa-onnx/csrc/speaker-embedding-manager-test.cc
  2 +//
  3 +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
  4 +
  5 +#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
  6 +
  7 +#include "gtest/gtest.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  // Check that Add() rejects duplicate names, Remove() rejects unknown names,
  // and NumSpeakers() tracks every successful Add()/Remove().
  11 +TEST(SpeakerEmbeddingManager, AddAndRemove) {
  12 + int32_t dim = 2;
  13 + SpeakerEmbeddingManager manager(dim);
  14 + std::vector<float> v = {0.1, 0.1};
  15 + bool status = manager.Add("first", v.data());
  16 + ASSERT_TRUE(status);
  17 + ASSERT_EQ(manager.NumSpeakers(), 1);
  18 +
  19 + // duplicate
  20 + status = manager.Add("first", v.data());
  21 + ASSERT_FALSE(status);
  22 + ASSERT_EQ(manager.NumSpeakers(), 1);
  23 +
  24 + // non-duplicate
  25 + v = {0.1, 0.9};
  26 + status = manager.Add("second", v.data());
  27 + ASSERT_TRUE(status);
  28 + ASSERT_EQ(manager.NumSpeakers(), 2);
  29 +
  30 + // do not exist
  31 + status = manager.Remove("third");
  32 + ASSERT_FALSE(status);
  33 +
  34 + status = manager.Remove("first");
  35 + ASSERT_TRUE(status);
  36 + ASSERT_EQ(manager.NumSpeakers(), 1);
  37 +
  // A removed name can be added again.
  38 + v = {0.1, 0.1};
  39 + status = manager.Add("first", v.data());
  40 + ASSERT_TRUE(status);
  41 + ASSERT_EQ(manager.NumSpeakers(), 2);
  42 +
  43 + status = manager.Remove("first");
  44 + ASSERT_TRUE(status);
  45 + ASSERT_EQ(manager.NumSpeakers(), 1);
  46 +
  47 + status = manager.Remove("second");
  48 + ASSERT_TRUE(status);
  49 + ASSERT_EQ(manager.NumSpeakers(), 0);
  50 +}
  51 +
  // Check that Search() returns the registered speaker whose embedding
  // matches the query, and an empty string once that speaker is removed.
  52 +TEST(SpeakerEmbeddingManager, Search) {
  53 + int32_t dim = 2;
  54 + SpeakerEmbeddingManager manager(dim);
  55 + std::vector<float> v1 = {0.1, 0.1};
  56 + std::vector<float> v2 = {0.1, 0.9};
  57 + std::vector<float> v3 = {0.9, 0.1};
  58 + bool status = manager.Add("first", v1.data());
  59 + ASSERT_TRUE(status);
  60 +
  61 + status = manager.Add("second", v2.data());
  62 + ASSERT_TRUE(status);
  63 +
  64 + status = manager.Add("third", v3.data());
  65 + ASSERT_TRUE(status);
  66 +
  67 + ASSERT_EQ(manager.NumSpeakers(), 3);
  68 +
  // The queries have a much larger magnitude than the registered vectors but
  // point in nearly the same direction; the expectations below therefore
  // assume a scale-invariant score (presumably cosine similarity -- the
  // Search() implementation is not visible here).
  69 + std::vector<float> v = {15, 16};
  70 + float threshold = 0.9;
  71 +
  72 + std::string name = manager.Search(v.data(), threshold);
  73 + EXPECT_EQ(name, "first");
  74 +
  75 + v = {2, 17};
  76 + name = manager.Search(v.data(), threshold);
  77 + EXPECT_EQ(name, "second");
  78 +
  79 + v = {17, 2};
  80 + name = manager.Search(v.data(), threshold);
  81 + EXPECT_EQ(name, "third");
  82 +
  // After the matching speaker is removed, no remaining speaker is expected
  // to reach the threshold for the same query.
  83 + threshold = 0.9;
  84 + v = {15, 16};
  85 + status = manager.Remove("first");
  86 + ASSERT_TRUE(status);
  87 + name = manager.Search(v.data(), threshold);
  88 + EXPECT_EQ(name, "");
  89 +
  90 + v = {17, 2};
  91 + status = manager.Remove("third");
  92 + ASSERT_TRUE(status);
  93 + name = manager.Search(v.data(), threshold);
  94 + EXPECT_EQ(name, "");
  95 +
  96 + v = {2, 17};
  97 + status = manager.Remove("second");
  98 + ASSERT_TRUE(status);
  99 + name = manager.Search(v.data(), threshold);
  100 + EXPECT_EQ(name, "");
  101 +
  102 + ASSERT_EQ(manager.NumSpeakers(), 0);
  103 +}
  104 +
  // Check that Verify() accepts a query only for the speaker whose registered
  // embedding matches it, and rejects a name that was never registered.
  105 +TEST(SpeakerEmbeddingManager, Verify) {
  106 + int32_t dim = 2;
  107 + SpeakerEmbeddingManager manager(dim);
  108 + std::vector<float> v1 = {0.1, 0.1};
  109 + std::vector<float> v2 = {0.1, 0.9};
  110 + std::vector<float> v3 = {0.9, 0.1};
  111 + bool status = manager.Add("first", v1.data());
  112 + ASSERT_TRUE(status);
  113 +
  114 + status = manager.Add("second", v2.data());
  115 + ASSERT_TRUE(status);
  116 +
  117 + status = manager.Add("third", v3.data());
  118 + ASSERT_TRUE(status);
  119 +
  // {15, 16} points in nearly the same direction as the "first" embedding.
  120 + std::vector<float> v = {15, 16};
  121 + float threshold = 0.9;
  122 +
  123 + status = manager.Verify("first", v.data(), threshold);
  124 + ASSERT_TRUE(status);
  125 +
  // {2, 17} is close to "second" only.
  126 + v = {2, 17};
  127 + status = manager.Verify("first", v.data(), threshold);
  128 + ASSERT_FALSE(status);
  129 +
  130 + status = manager.Verify("second", v.data(), threshold);
  131 + ASSERT_TRUE(status);
  132 +
  // {17, 2} is close to "third" only.
  133 + v = {17, 2};
  134 + status = manager.Verify("first", v.data(), threshold);
  135 + ASSERT_FALSE(status);
  136 +
  137 + status = manager.Verify("second", v.data(), threshold);
  138 + ASSERT_FALSE(status);
  139 +
  140 + status = manager.Verify("third", v.data(), threshold);
  141 + ASSERT_TRUE(status);
  142 +
  // "fourth" was never registered.
  143 + status = manager.Verify("fourth", v.data(), threshold);
  144 + ASSERT_FALSE(status);
  145 +}
  146 +
  147 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/speaker-embedding-manager.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
  6 +
  7 +#include <algorithm>
  8 +#include <unordered_map>
  9 +
  10 +#include "Eigen/Dense"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +using FloatMatrix =
  15 + Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  16 +
  17 +class SpeakerEmbeddingManager::Impl {
  18 + public:
  19 + explicit Impl(int32_t dim) : dim_(dim) {}
  20 +
  21 + bool Add(const std::string &name, const float *p) {
  22 + if (name2row_.count(name)) {
  23 + // a speaker with the same name already exists
  24 + return false;
  25 + }
  26 +
  27 + embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_);
  28 +
  29 + std::copy(p, p + dim_, &embedding_matrix_.bottomRows(1)(0, 0));
  30 +
  31 + embedding_matrix_.bottomRows(1).normalize(); // inplace
  32 +
  33 + name2row_[name] = embedding_matrix_.rows() - 1;
  34 + row2name_[embedding_matrix_.rows() - 1] = name;
  35 +
  36 + return true;
  37 + }
  38 +
  39 + bool Remove(const std::string &name) {
  40 + if (!name2row_.count(name)) {
  41 + return false;
  42 + }
  43 +
  44 + int32_t row_idx = name2row_.at(name);
  45 +
  46 + int32_t num_rows = embedding_matrix_.rows();
  47 +
  48 + if (row_idx < num_rows - 1) {
  49 + embedding_matrix_.block(row_idx, 0, num_rows - -1 - row_idx, dim_) =
  50 + embedding_matrix_.bottomRows(num_rows - 1 - row_idx);
  51 + }
  52 +
  53 + embedding_matrix_.conservativeResize(num_rows - 1, dim_);
  54 + for (auto &p : name2row_) {
  55 + if (p.second > row_idx) {
  56 + p.second -= 1;
  57 + row2name_[p.second] = p.first;
  58 + }
  59 + }
  60 +
  61 + name2row_.erase(name);
  62 + row2name_.erase(num_rows - 1);
  63 +
  64 + return true;
  65 + }
  66 +
  67 + std::string Search(const float *p, float threshold) {
  68 + if (embedding_matrix_.rows() == 0) {
  69 + return {};
  70 + }
  71 +
  72 + Eigen::VectorXf v =
  73 + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
  74 + v.normalize();
  75 +
  76 + Eigen::VectorXf scores = embedding_matrix_ * v;
  77 +
  78 + Eigen::VectorXf::Index max_index;
  79 + float max_score = scores.maxCoeff(&max_index);
  80 + if (max_score < threshold) {
  81 + return {};
  82 + }
  83 +
  84 + return row2name_.at(max_index);
  85 + }
  86 +
  87 + bool Verify(const std::string &name, const float *p, float threshold) {
  88 + if (!name2row_.count(name)) {
  89 + return false;
  90 + }
  91 +
  92 + int32_t row_idx = name2row_.at(name);
  93 +
  94 + Eigen::VectorXf v =
  95 + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
  96 + v.normalize();
  97 +
  98 + float score = embedding_matrix_.row(row_idx) * v;
  99 +
  100 + if (score < threshold) {
  101 + return false;
  102 + }
  103 +
  104 + return true;
  105 + }
  106 +
  107 + int32_t NumSpeakers() const { return embedding_matrix_.rows(); }
  108 +
  109 + private:
  110 + int32_t dim_;
  111 + FloatMatrix embedding_matrix_;
  112 + std::unordered_map<std::string, int32_t> name2row_;
  113 + std::unordered_map<int32_t, std::string> row2name_;
  114 +};
  115 +
  116 +SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim)
  117 + : impl_(std::make_unique<Impl>(dim)) {}
  118 +
  119 +SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default;
  120 +
  121 +bool SpeakerEmbeddingManager::Add(const std::string &name,
  122 + const float *p) const {
  123 + return impl_->Add(name, p);
  124 +}
  125 +
  126 +bool SpeakerEmbeddingManager::Remove(const std::string &name) const {
  127 + return impl_->Remove(name);
  128 +}
  129 +
  130 +std::string SpeakerEmbeddingManager::Search(const float *p,
  131 + float threshold) const {
  132 + return impl_->Search(p, threshold);
  133 +}
  134 +
  135 +bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
  136 + float threshold) const {
  137 + return impl_->Verify(name, p, threshold);
  138 +}
  139 +
  140 +int32_t SpeakerEmbeddingManager::NumSpeakers() const {
  141 + return impl_->NumSpeakers();
  142 +}
  143 +
  144 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/speaker-embedding-manager.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
  6 +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +
  11 +namespace sherpa_onnx {
  12 +
// Keeps a set of named speaker embeddings and supports speaker
// identification (Search) and speaker verification (Verify) on them.
// The actual work is done by a private Impl class (pimpl idiom).
class SpeakerEmbeddingManager {
 public:
  // @param dim Embedding dimension.
  explicit SpeakerEmbeddingManager(int32_t dim);
  ~SpeakerEmbeddingManager();

  /* Add the embedding and name of a speaker to the manager.
   *
   * @param name Name of the speaker
   * @param p Pointer to the embedding. Its length is `dim`.
   * @return Return true if added successfully. Return false if it failed.
   *         At present, the only reason for a failure is that there is already
   *         a speaker with the same `name`.
   */
  bool Add(const std::string &name, const float *p) const;

  /* Remove a speaker by its name.
   *
   * @param name Name of the speaker to remove.
   * @return Return true if it is removed successfully. Return false
   *         if there is no such speaker.
   */
  bool Remove(const std::string &name) const;

  /** It is for speaker identification.
   *
   * It computes the cosine similarity between the given embedding and all
   * stored embeddings and finds the embedding that has the largest score,
   * provided that the score is above or equal to threshold. Return the speaker
   * name for the embedding if found; otherwise, it returns an empty string.
   *
   * @param p The input embedding.
   * @param threshold A value between 0 and 1.
   * @return If found, return the name of the speaker. Otherwise, return an
   *         empty string.
   */
  std::string Search(const float *p, float threshold) const;

  /* Check whether the input embedding matches the embedding of the input
   * speaker.
   *
   * It is for speaker verification.
   *
   * @param name The target speaker name.
   * @param p The input embedding to check.
   * @param threshold A value between 0 and 1.
   * @return Return true if it matches. Otherwise, it returns false.
   */
  bool Verify(const std::string &name, const float *p, float threshold) const;

  // Return the number of speakers currently stored in the manager.
  int32_t NumSpeakers() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
  69 +
  70 +} // namespace sherpa_onnx
  71 +
  72 +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
@@ -40,7 +40,9 @@ class VoiceActivityDetector::Impl { @@ -40,7 +40,9 @@ class VoiceActivityDetector::Impl {
40 40
41 for (int32_t i = 0; i != k; ++i, p += window_size) { 41 for (int32_t i = 0; i != k; ++i, p += window_size) {
42 buffer_.Push(p, window_size); 42 buffer_.Push(p, window_size);
43 - is_speech = is_speech || model_->IsSpeech(p, window_size); 43 + // NOTE(fangjun): Please don't use a very large n.
  44 + bool this_window_is_speech = model_->IsSpeech(p, window_size);
  45 + is_speech = is_speech || this_window_is_speech;
44 } 46 }
45 47
46 last_ = std::vector<float>( 48 last_ = std::vector<float>(
@@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl { @@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl {
102 104
103 bool IsSpeechDetected() const { return start_ != -1; } 105 bool IsSpeechDetected() const { return start_ != -1; }
104 106
  107 + const VadModelConfig &GetConfig() const { return config_; }
  108 +
105 private: 109 private:
106 std::queue<SpeechSegment> segments_; 110 std::queue<SpeechSegment> segments_;
107 111
@@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const { @@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const {
146 return impl_->IsSpeechDetected(); 150 return impl_->IsSpeechDetected();
147 } 151 }
148 152
  153 +const VadModelConfig &VoiceActivityDetector::GetConfig() const {
  154 + return impl_->GetConfig();
  155 +}
  156 +
149 } // namespace sherpa_onnx 157 } // namespace sherpa_onnx
@@ -43,6 +43,8 @@ class VoiceActivityDetector { @@ -43,6 +43,8 @@ class VoiceActivityDetector {
43 43
44 void Reset(); 44 void Reset();
45 45
  46 + const VadModelConfig &GetConfig() const;
  47 +
46 private: 48 private:
47 class Impl; 49 class Impl;
48 std::unique_ptr<Impl> impl_; 50 std::unique_ptr<Impl> impl_;
@@ -30,6 +30,8 @@ pybind11_add_module(_sherpa_onnx @@ -30,6 +30,8 @@ pybind11_add_module(_sherpa_onnx
30 online-zipformer2-ctc-model-config.cc 30 online-zipformer2-ctc-model-config.cc
31 sherpa-onnx.cc 31 sherpa-onnx.cc
32 silero-vad-model-config.cc 32 silero-vad-model-config.cc
  33 + speaker-embedding-extractor.cc
  34 + speaker-embedding-manager.cc
33 vad-model-config.cc 35 vad-model-config.cc
34 vad-model.cc 36 vad-model.cc
35 voice-activity-detector.cc 37 voice-activity-detector.cc
1 -// sherpa-onnx/python/csrc/online-recongizer.h 1 +// sherpa-onnx/python/csrc/online-recognizer.h
2 // 2 //
3 // Copyright (c) 2023 Xiaomi Corporation 3 // Copyright (c) 2023 Xiaomi Corporation
4 4
@@ -18,6 +18,8 @@ @@ -18,6 +18,8 @@
18 #include "sherpa-onnx/python/csrc/online-model-config.h" 18 #include "sherpa-onnx/python/csrc/online-model-config.h"
19 #include "sherpa-onnx/python/csrc/online-recognizer.h" 19 #include "sherpa-onnx/python/csrc/online-recognizer.h"
20 #include "sherpa-onnx/python/csrc/online-stream.h" 20 #include "sherpa-onnx/python/csrc/online-stream.h"
  21 +#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
  22 +#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
21 #include "sherpa-onnx/python/csrc/vad-model-config.h" 23 #include "sherpa-onnx/python/csrc/vad-model-config.h"
22 #include "sherpa-onnx/python/csrc/vad-model.h" 24 #include "sherpa-onnx/python/csrc/vad-model.h"
23 #include "sherpa-onnx/python/csrc/voice-activity-detector.h" 25 #include "sherpa-onnx/python/csrc/voice-activity-detector.h"
@@ -48,6 +50,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -48,6 +50,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
48 PybindVoiceActivityDetector(&m); 50 PybindVoiceActivityDetector(&m);
49 51
50 PybindOfflineTts(&m); 52 PybindOfflineTts(&m);
  53 + PybindSpeakerEmbeddingExtractor(&m);
  54 + PybindSpeakerEmbeddingManager(&m);
51 } 55 }
52 56
53 } // namespace sherpa_onnx 57 } // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
  14 + using PyClass = SpeakerEmbeddingExtractorConfig;
  15 + py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
  16 + .def(py::init<>())
  17 + .def(py::init<const std::string &, int32_t, bool, const std::string>(),
  18 + py::arg("model"), py::arg("num_threads") = 1,
  19 + py::arg("debug") = false, py::arg("provider") = "cpu")
  20 + .def_readwrite("model", &PyClass::model)
  21 + .def_readwrite("num_threads", &PyClass::num_threads)
  22 + .def_readwrite("debug", &PyClass::debug)
  23 + .def_readwrite("provider", &PyClass::provider)
  24 + .def("validate", &PyClass::Validate)
  25 + .def("__str__", &PyClass::ToString);
  26 +}
  27 +
  28 +void PybindSpeakerEmbeddingExtractor(py::module *m) {
  29 + PybindSpeakerEmbeddingExtractorConfig(m);
  30 +
  31 + using PyClass = SpeakerEmbeddingExtractor;
  32 + py::class_<PyClass>(*m, "SpeakerEmbeddingExtractor")
  33 + .def(py::init<const SpeakerEmbeddingExtractorConfig &>(),
  34 + py::arg("config"), py::call_guard<py::gil_scoped_release>())
  35 + .def_property_readonly("dim", &PyClass::Dim)
  36 + .def("create_stream", &PyClass::CreateStream,
  37 + py::call_guard<py::gil_scoped_release>())
  38 + .def("compute", &PyClass::Compute,
  39 + py::call_guard<py::gil_scoped_release>())
  40 + .def("is_ready", &PyClass::IsReady,
  41 + py::call_guard<py::gil_scoped_release>());
  42 +}
  43 +
  44 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/speaker-embedding-extractor.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
namespace sherpa_onnx {

// Registers the Python classes SpeakerEmbeddingExtractorConfig and
// SpeakerEmbeddingExtractor on the module `m`.
void PybindSpeakerEmbeddingExtractor(py::module *m);

}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
  1 +// sherpa-onnx/python/csrc/speaker-embedding-manager.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindSpeakerEmbeddingManager(py::module *m) {
  15 + using PyClass = SpeakerEmbeddingManager;
  16 + py::class_<PyClass>(*m, "SpeakerEmbeddingManager")
  17 + .def(py::init<int32_t>(), py::arg("dim"),
  18 + py::call_guard<py::gil_scoped_release>())
  19 + .def_property_readonly("num_speakers", &PyClass::NumSpeakers)
  20 + .def(
  21 + "add",
  22 + [](const PyClass &self, const std::string &name,
  23 + const std::vector<float> &v) -> bool {
  24 + return self.Add(name, v.data());
  25 + },
  26 + py::arg("name"), py::arg("v"),
  27 + py::call_guard<py::gil_scoped_release>())
  28 + .def(
  29 + "remove",
  30 + [](const PyClass &self, const std::string &name) -> bool {
  31 + return self.Remove(name);
  32 + },
  33 + py::arg("name"), py::call_guard<py::gil_scoped_release>())
  34 + .def(
  35 + "search",
  36 + [](const PyClass &self, const std::vector<float> &v, float threshold)
  37 + -> std::string { return self.Search(v.data(), threshold); },
  38 + py::arg("v"), py::arg("threshold"),
  39 + py::call_guard<py::gil_scoped_release>())
  40 + .def(
  41 + "verify",
  42 + [](const PyClass &self, const std::string &name,
  43 + const std::vector<float> &v, float threshold) -> bool {
  44 + return self.Verify(name, v.data(), threshold);
  45 + },
  46 + py::arg("name"), py::arg("v"), py::arg("threshold"),
  47 + py::call_guard<py::gil_scoped_release>());
  48 +}
  49 +
  50 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/speaker-embedding-manager.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindSpeakerEmbeddingManager(py::module *m);
  13 +
  14 +} // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
@@ -32,6 +32,7 @@ void PybindVoiceActivityDetector(py::module *m) { @@ -32,6 +32,7 @@ void PybindVoiceActivityDetector(py::module *m) {
32 self.AcceptWaveform(samples.data(), samples.size()); 32 self.AcceptWaveform(samples.data(), samples.size());
33 }, 33 },
34 py::arg("samples"), py::call_guard<py::gil_scoped_release>()) 34 py::arg("samples"), py::call_guard<py::gil_scoped_release>())
  35 + .def_property_readonly("config", &PyClass::GetConfig)
35 .def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>()) 36 .def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
36 .def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>()) 37 .def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
37 .def("is_speech_detected", &PyClass::IsSpeechDetected, 38 .def("is_speech_detected", &PyClass::IsSpeechDetected,
@@ -8,6 +8,9 @@ from _sherpa_onnx import ( @@ -8,6 +8,9 @@ from _sherpa_onnx import (
8 OfflineTtsVitsModelConfig, 8 OfflineTtsVitsModelConfig,
9 OnlineStream, 9 OnlineStream,
10 SileroVadModelConfig, 10 SileroVadModelConfig,
  11 + SpeakerEmbeddingExtractor,
  12 + SpeakerEmbeddingExtractorConfig,
  13 + SpeakerEmbeddingManager,
11 SpeechSegment, 14 SpeechSegment,
12 VadModel, 15 VadModel,
13 VadModelConfig, 16 VadModelConfig,