Committed by
GitHub
Add runtime support for wespeaker models (#516)
正在显示
27 个修改的文件
包含
1291 行增加
和
4 行删除
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This script shows how to use Python APIs for speaker identification. | ||
| 5 | + | ||
| 6 | +Usage: | ||
| 7 | + | ||
| 8 | +(1) Prepare a text file containing speaker related files. | ||
| 9 | + | ||
| 10 | +Each line in the text file contains two columns. The first column is the | ||
| 11 | +speaker name, while the second column contains the wave file of the speaker. | ||
| 12 | + | ||
| 13 | +If the text file contains multiple wave files for the same speaker, then the | ||
| 14 | +embeddings of these files are averaged. | ||
| 15 | + | ||
| 16 | +An example text file is given below: | ||
| 17 | + | ||
| 18 | + foo /path/to/a.wav | ||
| 19 | + bar /path/to/b.wav | ||
| 20 | + foo /path/to/c.wav | ||
| 21 | + foobar /path/to/d.wav | ||
| 22 | + | ||
| 23 | +Each wave file should contain only a single channel; the sample format | ||
| 24 | +should be int16_t; the sample rate can be arbitrary. | ||
| 25 | + | ||
| 26 | +(2) Download a model for computing speaker embeddings | ||
| 27 | + | ||
| 28 | +Please visit | ||
| 29 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models | ||
| 30 | +to download a model. An example is given below: | ||
| 31 | + | ||
| 32 | + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/zh_cnceleb_resnet34.onnx | ||
| 33 | + | ||
| 34 | +Note that `zh` means Chinese, while `en` means English. | ||
| 35 | + | ||
| 36 | +(3) Run this script | ||
| 37 | + | ||
| 38 | +Assume the filename of the text file is speaker.txt. | ||
| 39 | + | ||
| 40 | +python3 ./python-api-examples/speaker-identification.py \ | ||
| 41 | + --speaker-file ./speaker.txt \ | ||
| 42 | + --model ./zh_cnceleb_resnet34.onnx | ||
| 43 | +""" | ||
| 44 | +import argparse | ||
| 45 | +import queue | ||
| 46 | +import threading | ||
| 47 | +from collections import defaultdict | ||
| 48 | +from pathlib import Path | ||
| 49 | +from typing import Dict, List, Tuple | ||
| 50 | + | ||
| 51 | +import numpy as np | ||
| 52 | +import sherpa_onnx | ||
| 53 | +import torchaudio | ||
| 54 | + | ||
# sounddevice is an optional third-party package needed to read the microphone.
# Print installation instructions and exit if it is not available.
try:
    import sounddevice as sd
except ImportError:
    # `sys` is not imported at module level; without this local import the
    # sys.exit() call below would raise NameError instead of exiting.
    import sys

    print("Please install sounddevice first. You can use")
    print()
    print(" pip install sounddevice")
    print()
    print("to install it")
    sys.exit(-1)
| 64 | + | ||
| 65 | + | ||
def get_args():
    """Parse the command-line arguments of this script.

    Returns:
      An argparse.Namespace with the attributes speaker_file, model,
      threshold, num_threads, debug and provider.
    """

    def str2bool(value: str) -> bool:
        # argparse's `type=bool` treats every non-empty string (including
        # "False") as True, so parse the textual value explicitly.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"Expected a boolean value, given: {value}"
        )

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--speaker-file",
        type=str,
        required=True,
        help="""Path to the speaker file. Read the help doc at the beginning of this
        file for the format.""",
    )

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the model file.",
    )

    parser.add_argument(
        "--threshold",
        type=float,
        default=0.6,
        help="Threshold passed to the speaker embedding manager when "
        "searching for a registered speaker",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="True to show debug messages",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    return parser.parse_args()
| 110 | + | ||
| 111 | + | ||
def load_speaker_embedding_model(args):
    """Create a speaker embedding extractor from the parsed arguments.

    Args:
      args: The namespace returned by get_args(). Its model, num_threads,
        debug and provider attributes are used.

    Returns:
      A sherpa_onnx.SpeakerEmbeddingExtractor built from the config.

    Raises:
      ValueError: If the config fails its own validate() check.
    """
    extractor_config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=args.model,
        num_threads=args.num_threads,
        debug=args.debug,
        provider=args.provider,
    )

    if extractor_config.validate():
        return sherpa_onnx.SpeakerEmbeddingExtractor(extractor_config)

    raise ValueError(f"Invalid config. {extractor_config}")
| 123 | + | ||
| 124 | + | ||
def load_speaker_file(args) -> Dict[str, List[str]]:
    """Read the speaker file and group wave filenames by speaker name.

    Each non-empty line must contain a speaker name followed by a wave
    filename, separated by whitespace. The filename is everything after the
    first run of whitespace, so it may itself contain spaces.

    Args:
      args: An object with a speaker_file attribute holding the path of the
        text file to read.

    Returns:
      A dict mapping each speaker name to the list of its wave filenames, in
      the order they appear in the file.

    Raises:
      ValueError: If the speaker file does not exist or a non-empty line does
        not contain two fields.
    """
    if not Path(args.speaker_file).is_file():
        raise ValueError(f"--speaker-file {args.speaker_file} does not exist")

    ans = defaultdict(list)
    # Use an explicit encoding so that non-ASCII speaker names and paths are
    # read the same way regardless of the platform's locale.
    with open(args.speaker_file, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            # Split only once so that a filename containing spaces stays intact
            fields = line.split(maxsplit=1)
            if len(fields) != 2:
                raise ValueError(f"Invalid line: {line}. Fields: {fields}")

            speaker_name, filename = fields
            ans[speaker_name].append(filename)
    return ans
| 143 | + | ||
| 144 | + | ||
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Load a wave file with torchaudio.

    Args:
      filename: Path of the audio file to read.

    Returns:
      A tuple (samples, sample_rate), where samples is the entry at index 0
      of the tensor returned by torchaudio.load(), converted to a numpy
      array.
    """
    waveform, rate = torchaudio.load(filename)
    # Only index 0 is kept (presumably the first channel -- the module
    # docstring asks for single-channel files).
    first = waveform[0]
    return first.contiguous().numpy(), rate
| 148 | + | ||
| 149 | + | ||
def compute_speaker_embedding(
    filenames: List[str],
    extractor: sherpa_onnx.SpeakerEmbeddingExtractor,
) -> np.ndarray:
    """Compute the average embedding over several wave files of one speaker.

    Args:
      filenames: Wave files belonging to the same speaker. Must not be empty.
      extractor: The extractor used to compute an embedding for each file.

    Returns:
      The element-wise mean of the embeddings of all the given files.

    Raises:
      ValueError: If filenames is empty.
      RuntimeError: If the extractor is not ready after a file has been fed
        to its stream.
    """
    # Explicit exceptions instead of `assert`, which is removed under -O.
    if not filenames:
        raise ValueError("filenames is empty")

    total = None
    for filename in filenames:
        print(f"processing {filename}")
        samples, sample_rate = load_audio(filename)
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
        stream.input_finished()

        if not extractor.is_ready(stream):
            raise RuntimeError(
                f"The extractor is not ready to compute an embedding for {filename}"
            )

        embedding = np.array(extractor.compute(stream))
        # Build a new array rather than using `+=`, so the first embedding is
        # never modified in place.
        total = embedding if total is None else total + embedding

    return total / len(filenames)
| 173 | + | ||
| 174 | + | ||
# Chunks of samples read from the microphone by read_mic()
g_buffer = queue.Queue()

# read_mic() leaves its loop once this becomes True
g_stop = False

# Sample rate used when opening the microphone input stream
g_sample_rate = 16000

# The thread that runs read_mic(); None until one has been created
g_read_mic_thread = None


def read_mic():
    """Read audio from the microphone until g_stop is set.

    Opens a single-channel float32 input stream at g_sample_rate and keeps
    putting 100 ms chunks of samples into g_buffer.
    """
    print("Please speak!")
    chunk_size = int(0.1 * g_sample_rate)  # 0.1 second = 100 ms
    with sd.InputStream(
        channels=1,
        dtype="float32",
        samplerate=g_sample_rate,
    ) as mic:
        while True:
            if g_stop:
                break
            chunk, _ = mic.read(chunk_size)  # a blocking read
            g_buffer.put(chunk)
| 188 | + | ||
| 189 | + | ||
def main():
    """Register the speakers from the speaker file, then identify live audio.

    After every speaker in --speaker-file has been registered, this function
    loops forever: it records from the microphone between two presses of
    enter, computes an embedding of the recording, and prints the name of the
    matching registered speaker (or "unknown"). Entering "q" or "quit" at the
    start prompt leaves the loop.

    Raises:
      RuntimeError: If a speaker cannot be added to the embedding manager.
    """
    # Imported here because this script does not import `sys` at module
    # level; without it the sys.exit() call below raises NameError.
    import sys

    global g_stop
    global g_read_mic_thread

    args = get_args()
    print(args)
    extractor = load_speaker_embedding_model(args)
    speaker_file = load_speaker_file(args)

    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, filename_list in speaker_file.items():
        embedding = compute_speaker_embedding(
            filenames=filename_list,
            extractor=extractor,
        )
        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    devices = sd.query_devices()
    if len(devices) == 0:
        print("No microphone devices found")
        sys.exit(0)

    print(devices)
    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    while True:
        key = input("Press enter to start recording")
        if key.lower() in ("q", "quit"):
            g_stop = True
            break

        g_stop = False
        # Drop chunks left over from a previous recording. Hold the queue's
        # own lock while touching its underlying deque.
        with g_buffer.mutex:
            g_buffer.queue.clear()
        g_read_mic_thread = threading.Thread(target=read_mic)
        g_read_mic_thread.start()
        input("Press enter to stop recording")
        g_stop = True
        g_read_mic_thread.join()
        print("Compute embedding")
        stream = extractor.create_stream()
        while not g_buffer.empty():
            samples = g_buffer.get()
            stream.accept_waveform(sample_rate=g_sample_rate, waveform=samples)
        stream.input_finished()

        # compute() must only be called on a ready stream; a recording that
        # is too short may leave the stream not ready.
        if not extractor.is_ready(stream):
            print("Not enough audio was recorded. Please try again.")
            continue

        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        name = manager.search(embedding, threshold=args.threshold)
        if not name:
            name = "unknown"
        print(f"Predicted name: {name}")
| 243 | + | ||
| 244 | + | ||
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
        # Ask the recording thread, if any, to leave its loop
        g_stop = True
        # g_read_mic_thread is still None when Ctrl + C arrives before the
        # first recording has started.
        if g_read_mic_thread is not None and g_read_mic_thread.is_alive():
            g_read_mic_thread.join()
| @@ -96,6 +96,14 @@ set(sources | @@ -96,6 +96,14 @@ set(sources | ||
| 96 | wave-reader.cc | 96 | wave-reader.cc |
| 97 | ) | 97 | ) |
| 98 | 98 | ||
| 99 | +# speaker embedding extractor | ||
| 100 | +list(APPEND sources | ||
| 101 | + speaker-embedding-extractor-impl.cc | ||
| 102 | + speaker-embedding-extractor-wespeaker-model.cc | ||
| 103 | + speaker-embedding-extractor.cc | ||
| 104 | + speaker-embedding-manager.cc | ||
| 105 | +) | ||
| 106 | + | ||
| 99 | list(APPEND sources | 107 | list(APPEND sources |
| 100 | lexicon.cc | 108 | lexicon.cc |
| 101 | offline-tts-impl.cc | 109 | offline-tts-impl.cc |
| @@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS) | @@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS) | ||
| 387 | utfcpp-test.cc | 395 | utfcpp-test.cc |
| 388 | ) | 396 | ) |
| 389 | 397 | ||
| 398 | + list(APPEND sherpa_onnx_test_srcs | ||
| 399 | + speaker-embedding-manager-test.cc | ||
| 400 | + ) | ||
| 401 | + | ||
| 390 | function(sherpa_onnx_add_test source) | 402 | function(sherpa_onnx_add_test source) |
| 391 | get_filename_component(name ${source} NAME_WE) | 403 | get_filename_component(name ${source} NAME_WE) |
| 392 | set(target_name ${name}) | 404 | set(target_name ${name}) |
| @@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) { | @@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) { | ||
| 64 | auto stop = std::chrono::high_resolution_clock::now(); | 64 | auto stop = std::chrono::high_resolution_clock::now(); |
| 65 | auto duration = | 65 | auto duration = |
| 66 | std::chrono::duration_cast<std::chrono::microseconds>(stop - start); | 66 | std::chrono::duration_cast<std::chrono::microseconds>(stop - start); |
| 67 | - SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num, | ||
| 68 | - duration.count()); | 67 | + SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num, |
| 68 | + static_cast<int32_t>(duration.count())); | ||
| 69 | } | 69 | } |
| 70 | } | 70 | } |
| 71 | 71 |
| @@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { | @@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { | ||
| 91 | return GetSessionOptionsImpl(config.num_threads, config.provider); | 91 | return GetSessionOptionsImpl(config.num_threads, config.provider); |
| 92 | } | 92 | } |
| 93 | 93 | ||
// Return the session options for a speaker embedding extractor, derived from
// the number of threads and the provider stored in `config`.
Ort::SessionOptions GetSessionOptions(
    const SpeakerEmbeddingExtractorConfig &config) {
  return GetSessionOptionsImpl(config.num_threads, config.provider);
}
| 98 | + | ||
| 94 | } // namespace sherpa_onnx | 99 | } // namespace sherpa_onnx |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "sherpa-onnx/csrc/offline-tts-model-config.h" | 11 | #include "sherpa-onnx/csrc/offline-tts-model-config.h" |
| 12 | #include "sherpa-onnx/csrc/online-lm-config.h" | 12 | #include "sherpa-onnx/csrc/online-lm-config.h" |
| 13 | #include "sherpa-onnx/csrc/online-model-config.h" | 13 | #include "sherpa-onnx/csrc/online-model-config.h" |
| 14 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | ||
| 14 | #include "sherpa-onnx/csrc/vad-model-config.h" | 15 | #include "sherpa-onnx/csrc/vad-model-config.h" |
| 15 | 16 | ||
| 16 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| @@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); | @@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); | ||
| 26 | Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); | 27 | Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); |
| 27 | 28 | ||
| 28 | Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); | 29 | Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); |
| 30 | + | ||
| 31 | +Ort::SessionOptions GetSessionOptions( | ||
| 32 | + const SpeakerEmbeddingExtractorConfig &config); | ||
| 29 | } // namespace sherpa_onnx | 33 | } // namespace sherpa_onnx |
| 30 | 34 | ||
| 31 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ | 35 | #endif // SHERPA_ONNX_CSRC_SESSION_H_ |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" | ||
| 5 | + | ||
| 6 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 7 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +namespace { | ||
| 13 | + | ||
| 14 | +enum class ModelType { | ||
| 15 | + kWeSpeaker, | ||
| 16 | + kUnkown, | ||
| 17 | +}; | ||
| 18 | + | ||
| 19 | +} // namespace | ||
| 20 | + | ||
| 21 | +static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 22 | + bool debug) { | ||
| 23 | + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | ||
| 24 | + Ort::SessionOptions sess_opts; | ||
| 25 | + | ||
| 26 | + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length, | ||
| 27 | + sess_opts); | ||
| 28 | + | ||
| 29 | + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); | ||
| 30 | + if (debug) { | ||
| 31 | + std::ostringstream os; | ||
| 32 | + PrintModelMetadata(os, meta_data); | ||
| 33 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 37 | + auto model_type = | ||
| 38 | + meta_data.LookupCustomMetadataMapAllocated("framework", allocator); | ||
| 39 | + if (!model_type) { | ||
| 40 | + SHERPA_ONNX_LOGE( | ||
| 41 | + "No model_type in the metadata!\n" | ||
| 42 | + "Please make sure you have added metadata to the model.\n\n" | ||
| 43 | + "For instance, you can use\n" | ||
| 44 | + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/" | ||
| 45 | + "add_meta_data.py" | ||
| 46 | + "to add metadata to models from WeSpeaker\n"); | ||
| 47 | + return ModelType::kUnkown; | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + if (model_type.get() == std::string("wespeaker")) { | ||
| 51 | + return ModelType::kWeSpeaker; | ||
| 52 | + } else { | ||
| 53 | + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | ||
| 54 | + return ModelType::kUnkown; | ||
| 55 | + } | ||
| 56 | +} | ||
| 57 | + | ||
| 58 | +std::unique_ptr<SpeakerEmbeddingExtractorImpl> | ||
| 59 | +SpeakerEmbeddingExtractorImpl::Create( | ||
| 60 | + const SpeakerEmbeddingExtractorConfig &config) { | ||
| 61 | + ModelType model_type = ModelType::kUnkown; | ||
| 62 | + | ||
| 63 | + { | ||
| 64 | + auto buffer = ReadFile(config.model); | ||
| 65 | + | ||
| 66 | + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + switch (model_type) { | ||
| 70 | + case ModelType::kWeSpeaker: | ||
| 71 | + return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config); | ||
| 72 | + case ModelType::kUnkown: | ||
| 73 | + SHERPA_ONNX_LOGE( | ||
| 74 | + "Unknown model type in for speaker embedding extractor!"); | ||
| 75 | + return nullptr; | ||
| 76 | + } | ||
| 77 | + | ||
| 78 | + // unreachable code | ||
| 79 | + return nullptr; | ||
| 80 | +} | ||
| 81 | + | ||
| 82 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
// Abstract interface implemented by model-specific speaker embedding
// extractors.
class SpeakerEmbeddingExtractorImpl {
 public:
  virtual ~SpeakerEmbeddingExtractorImpl() = default;

  // Create an implementation for the model described by `config`.
  static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
      const SpeakerEmbeddingExtractorConfig &config);

  // Return the dimension of the embedding.
  virtual int32_t Dim() const = 0;

  // Create a stream that is later passed to IsReady() and Compute().
  virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;

  // Return true if Compute() can be called on the given stream.
  virtual bool IsReady(OnlineStream *s) const = 0;

  // Compute the embedding for the given stream.
  virtual std::vector<float> Compute(OnlineStream *s) const = 0;
};
| 31 | + | ||
| 32 | +} // namespace sherpa_onnx | ||
| 33 | + | ||
| 34 | +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <memory> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" | ||
| 13 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class SpeakerEmbeddingExtractorWeSpeakerImpl | ||
| 18 | + : public SpeakerEmbeddingExtractorImpl { | ||
| 19 | + public: | ||
| 20 | + explicit SpeakerEmbeddingExtractorWeSpeakerImpl( | ||
| 21 | + const SpeakerEmbeddingExtractorConfig &config) | ||
| 22 | + : model_(config) {} | ||
| 23 | + | ||
| 24 | + int32_t Dim() const override { return model_.GetMetaData().output_dim; } | ||
| 25 | + | ||
| 26 | + std::unique_ptr<OnlineStream> CreateStream() const override { | ||
| 27 | + FeatureExtractorConfig feat_config; | ||
| 28 | + auto meta_data = model_.GetMetaData(); | ||
| 29 | + feat_config.sampling_rate = meta_data.sample_rate; | ||
| 30 | + feat_config.normalize_samples = meta_data.normalize_features; | ||
| 31 | + | ||
| 32 | + return std::make_unique<OnlineStream>(feat_config); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + bool IsReady(OnlineStream *s) const override { | ||
| 36 | + return s->GetNumProcessedFrames() < s->NumFramesReady(); | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + std::vector<float> Compute(OnlineStream *s) const override { | ||
| 40 | + int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames(); | ||
| 41 | + if (num_frames <= 0) { | ||
| 42 | + SHERPA_ONNX_LOGE( | ||
| 43 | + "Please make sure IsReady(s) returns true. num_frames: %d", | ||
| 44 | + num_frames); | ||
| 45 | + return {}; | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + std::vector<float> features = | ||
| 49 | + s->GetFrames(s->GetNumProcessedFrames(), num_frames); | ||
| 50 | + | ||
| 51 | + s->GetNumProcessedFrames() += num_frames; | ||
| 52 | + | ||
| 53 | + int32_t feat_dim = features.size() / num_frames; | ||
| 54 | + | ||
| 55 | + auto memory_info = | ||
| 56 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 57 | + | ||
| 58 | + std::array<int64_t, 3> x_shape{1, num_frames, feat_dim}; | ||
| 59 | + Ort::Value x = | ||
| 60 | + Ort::Value::CreateTensor(memory_info, features.data(), features.size(), | ||
| 61 | + x_shape.data(), x_shape.size()); | ||
| 62 | + Ort::Value embedding = model_.Compute(std::move(x)); | ||
| 63 | + std::vector<int64_t> embedding_shape = | ||
| 64 | + embedding.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 65 | + | ||
| 66 | + std::vector<float> ans(embedding_shape[1]); | ||
| 67 | + std::copy(embedding.GetTensorData<float>(), | ||
| 68 | + embedding.GetTensorData<float>() + ans.size(), ans.begin()); | ||
| 69 | + | ||
| 70 | + return ans; | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + private: | ||
| 74 | + SpeakerEmbeddingExtractorWeSpeakerModel model_; | ||
| 75 | +}; | ||
| 76 | + | ||
| 77 | +} // namespace sherpa_onnx | ||
| 78 | + | ||
| 79 | +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ | ||
| 6 | + | ||
| 7 | +#include <cstdint> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
// Values read from the custom metadata of a WeSpeaker ONNX model.
struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
  // Metadata entry "output_dim"; reported as the embedding dimension.
  int32_t output_dim = 0;
  // Metadata entry "sample_rate"; used as the sampling rate of the feature
  // extractor.
  int32_t sample_rate = 0;
  // Metadata entry "normalize_features"; assigned to the feature extractor's
  // normalize_samples option.
  int32_t normalize_features = 0;
  // Metadata entry "language".
  std::string language;
};
| 18 | + | ||
| 19 | +} // namespace sherpa_onnx | ||
| 20 | +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 13 | +#include "sherpa-onnx/csrc/session.h" | ||
| 14 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
// Private implementation of SpeakerEmbeddingExtractorWeSpeakerModel. It owns
// the ONNX Runtime session of the model and the metadata read from it.
class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
 public:
  // Read the model file given by config.model and initialize the session.
  explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      // The file content is only needed while the session is being created
      auto buf = ReadFile(config.model);
      Init(buf.data(), buf.size());
    }
  }

  // Run the model on the single input `x` and return its first output.
  Ort::Value Compute(Ort::Value x) const {
    std::array<Ort::Value, 1> inputs = {std::move(x)};

    auto outputs =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());
    return std::move(outputs[0]);
  }

  // Return the metadata that was read from the model in Init().
  const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const {
    return meta_data_;
  }

 private:
  // Create the session from the in-memory model, collect the input/output
  // names and read the model metadata.
  //
  // The process exits with -1 if the "framework" metadata entry is not
  // "wespeaker".
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
    SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
    SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features,
                               "normalize_features");
    SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");

    std::string framework;
    SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
    if (framework != "wespeaker") {
      SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s",
                       framework.c_str());
      exit(-1);
    }
  }

 private:
  SpeakerEmbeddingExtractorConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  // NOTE(review): not referenced by any method visible in this class; Init()
  // uses its own local allocator.
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  // Input names of the model and raw pointers to them, as passed to Run()
  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  // Output names of the model and raw pointers to them, as passed to Run()
  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_;
};
| 93 | + | ||
// Construct the model by creating its private implementation from `config`.
SpeakerEmbeddingExtractorWeSpeakerModel::
    SpeakerEmbeddingExtractorWeSpeakerModel(
        const SpeakerEmbeddingExtractorConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

// Defined here, where Impl is a complete type, so that
// std::unique_ptr<Impl> can be destroyed.
SpeakerEmbeddingExtractorWeSpeakerModel::
    ~SpeakerEmbeddingExtractorWeSpeakerModel() = default;

// Forward to the implementation.
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &
SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const {
  return impl_->GetMetaData();
}

// Forward to the implementation.
Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(
    Ort::Value x) const {
  return impl_->Compute(std::move(x));
}
| 111 | + | ||
| 112 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | + | ||
| 9 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 10 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h" | ||
| 11 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
// Wrapper around a WeSpeaker ONNX model for computing speaker embeddings.
// The implementation details are hidden behind a private Impl class.
class SpeakerEmbeddingExtractorWeSpeakerModel {
 public:
  explicit SpeakerEmbeddingExtractorWeSpeakerModel(
      const SpeakerEmbeddingExtractorConfig &config);

  ~SpeakerEmbeddingExtractorWeSpeakerModel();

  // Return the metadata of the model.
  const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const;

  /**
   * Run the model on a batch of features.
   *
   * @param x A float32 tensor of shape (N, T, C)
   * @return A float32 tensor of shape (N, C). NOTE(review): C here is
   *         presumably the embedding dimension rather than the feature
   *         dimension of the input -- confirm.
   */
  Ort::Value Compute(Ort::Value x) const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
| 34 | + | ||
| 35 | +} // namespace sherpa_onnx | ||
| 36 | + | ||
| 37 | +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"

#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
// Register the command-line options --model, --num-threads, --debug and
// --provider with `po`, each bound to the member of the same name.
void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
  po->Register("model", &model, "Path to the speaker embedding model.");
  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}
| 26 | + | ||
| 27 | +bool SpeakerEmbeddingExtractorConfig::Validate() const { | ||
| 28 | + if (model.empty()) { | ||
| 29 | + SHERPA_ONNX_LOGE("Please provide --speaker-embedding-model"); | ||
| 30 | + return false; | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + if (!FileExists(model)) { | ||
| 34 | + SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist", | ||
| 35 | + model.c_str()); | ||
| 36 | + return false; | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + return true; | ||
| 40 | +} | ||
| 41 | + | ||
| 42 | +std::string SpeakerEmbeddingExtractorConfig::ToString() const { | ||
| 43 | + std::ostringstream os; | ||
| 44 | + | ||
| 45 | + os << "SpeakerEmbeddingExtractorConfig("; | ||
| 46 | + os << "model=\"" << model << "\", "; | ||
| 47 | + os << "num_threads=" << num_threads << ", "; | ||
| 48 | + os << "debug=" << (debug ? "True" : "False") << ", "; | ||
| 49 | + os << "provider=\"" << provider << "\")"; | ||
| 50 | + | ||
| 51 | + return os.str(); | ||
| 52 | +} | ||
| 53 | + | ||
| 54 | +SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor( | ||
| 55 | + const SpeakerEmbeddingExtractorConfig &config) | ||
| 56 | + : impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {} | ||
| 57 | + | ||
| 58 | +SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default; | ||
| 59 | + | ||
| 60 | +int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); } | ||
| 61 | + | ||
| 62 | +std::unique_ptr<OnlineStream> SpeakerEmbeddingExtractor::CreateStream() const { | ||
| 63 | + return impl_->CreateStream(); | ||
| 64 | +} | ||
| 65 | + | ||
| 66 | +bool SpeakerEmbeddingExtractor::IsReady(OnlineStream *s) const { | ||
| 67 | + return impl_->IsReady(s); | ||
| 68 | +} | ||
| 69 | + | ||
| 70 | +std::vector<float> SpeakerEmbeddingExtractor::Compute(OnlineStream *s) const { | ||
| 71 | + return impl_->Compute(s); | ||
| 72 | +} | ||
| 73 | + | ||
| 74 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-extractor.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/online-stream.h" | ||
| 13 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +struct SpeakerEmbeddingExtractorConfig { | ||
| 18 | + std::string model; | ||
| 19 | + int32_t num_threads = 1; | ||
| 20 | + bool debug = false; | ||
| 21 | + std::string provider = "cpu"; | ||
| 22 | + | ||
| 23 | + SpeakerEmbeddingExtractorConfig() = default; | ||
| 24 | + SpeakerEmbeddingExtractorConfig(const std::string &model, int32_t num_threads, | ||
| 25 | + bool debug, const std::string &provider) | ||
| 26 | + : model(model), | ||
| 27 | + num_threads(num_threads), | ||
| 28 | + debug(debug), | ||
| 29 | + provider(provider) {} | ||
| 30 | + | ||
| 31 | + void Register(ParseOptions *po); | ||
| 32 | + bool Validate() const; | ||
| 33 | + std::string ToString() const; | ||
| 34 | +}; | ||
| 35 | + | ||
| 36 | +class SpeakerEmbeddingExtractorImpl; | ||
| 37 | + | ||
| 38 | +class SpeakerEmbeddingExtractor { | ||
| 39 | + public: | ||
| 40 | + explicit SpeakerEmbeddingExtractor( | ||
| 41 | + const SpeakerEmbeddingExtractorConfig &config); | ||
| 42 | + | ||
| 43 | + ~SpeakerEmbeddingExtractor(); | ||
| 44 | + | ||
| 45 | + // Return the dimension of the embedding | ||
| 46 | + int32_t Dim() const; | ||
| 47 | + | ||
| 48 | + // Create a stream to accept audio samples and compute features | ||
| 49 | + std::unique_ptr<OnlineStream> CreateStream() const; | ||
| 50 | + | ||
| 51 | + // Return true if there are feature frames in OnlineStream that | ||
| 52 | + // can be used to compute embeddings. | ||
| 53 | + bool IsReady(OnlineStream *s) const; | ||
| 54 | + | ||
| 55 | + // Compute the speaker embedding from the available unprocessed features | ||
| 56 | + // of the given stream | ||
| 57 | + // | ||
| 58 | + // You have to ensure IsReady(s) returns true before you call this method. | ||
| 59 | + std::vector<float> Compute(OnlineStream *s) const; | ||
| 60 | + | ||
| 61 | + private: | ||
| 62 | + std::unique_ptr<SpeakerEmbeddingExtractorImpl> impl_; | ||
| 63 | +}; | ||
| 64 | + | ||
| 65 | +} // namespace sherpa_onnx | ||
| 66 | + | ||
| 67 | +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-manager-test.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/speaker-embedding-manager.h" | ||
| 6 | + | ||
| 7 | +#include "gtest/gtest.h" | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +TEST(SpeakerEmbeddingManager, AddAndRemove) { | ||
| 12 | + int32_t dim = 2; | ||
| 13 | + SpeakerEmbeddingManager manager(dim); | ||
| 14 | + std::vector<float> v = {0.1, 0.1}; | ||
| 15 | + bool status = manager.Add("first", v.data()); | ||
| 16 | + ASSERT_TRUE(status); | ||
| 17 | + ASSERT_EQ(manager.NumSpeakers(), 1); | ||
| 18 | + | ||
| 19 | + // duplicate | ||
| 20 | + status = manager.Add("first", v.data()); | ||
| 21 | + ASSERT_FALSE(status); | ||
| 22 | + ASSERT_EQ(manager.NumSpeakers(), 1); | ||
| 23 | + | ||
| 24 | + // non-duplicate | ||
| 25 | + v = {0.1, 0.9}; | ||
| 26 | + status = manager.Add("second", v.data()); | ||
| 27 | + ASSERT_TRUE(status); | ||
| 28 | + ASSERT_EQ(manager.NumSpeakers(), 2); | ||
| 29 | + | ||
| 30 | + // do not exist | ||
| 31 | + status = manager.Remove("third"); | ||
| 32 | + ASSERT_FALSE(status); | ||
| 33 | + | ||
| 34 | + status = manager.Remove("first"); | ||
| 35 | + ASSERT_TRUE(status); | ||
| 36 | + ASSERT_EQ(manager.NumSpeakers(), 1); | ||
| 37 | + | ||
| 38 | + v = {0.1, 0.1}; | ||
| 39 | + status = manager.Add("first", v.data()); | ||
| 40 | + ASSERT_TRUE(status); | ||
| 41 | + ASSERT_EQ(manager.NumSpeakers(), 2); | ||
| 42 | + | ||
| 43 | + status = manager.Remove("first"); | ||
| 44 | + ASSERT_TRUE(status); | ||
| 45 | + ASSERT_EQ(manager.NumSpeakers(), 1); | ||
| 46 | + | ||
| 47 | + status = manager.Remove("second"); | ||
| 48 | + ASSERT_TRUE(status); | ||
| 49 | + ASSERT_EQ(manager.NumSpeakers(), 0); | ||
| 50 | +} | ||
| 51 | + | ||
| 52 | +TEST(SpeakerEmbeddingManager, Search) { | ||
| 53 | + int32_t dim = 2; | ||
| 54 | + SpeakerEmbeddingManager manager(dim); | ||
| 55 | + std::vector<float> v1 = {0.1, 0.1}; | ||
| 56 | + std::vector<float> v2 = {0.1, 0.9}; | ||
| 57 | + std::vector<float> v3 = {0.9, 0.1}; | ||
| 58 | + bool status = manager.Add("first", v1.data()); | ||
| 59 | + ASSERT_TRUE(status); | ||
| 60 | + | ||
| 61 | + status = manager.Add("second", v2.data()); | ||
| 62 | + ASSERT_TRUE(status); | ||
| 63 | + | ||
| 64 | + status = manager.Add("third", v3.data()); | ||
| 65 | + ASSERT_TRUE(status); | ||
| 66 | + | ||
| 67 | + ASSERT_EQ(manager.NumSpeakers(), 3); | ||
| 68 | + | ||
| 69 | + std::vector<float> v = {15, 16}; | ||
| 70 | + float threshold = 0.9; | ||
| 71 | + | ||
| 72 | + std::string name = manager.Search(v.data(), threshold); | ||
| 73 | + EXPECT_EQ(name, "first"); | ||
| 74 | + | ||
| 75 | + v = {2, 17}; | ||
| 76 | + name = manager.Search(v.data(), threshold); | ||
| 77 | + EXPECT_EQ(name, "second"); | ||
| 78 | + | ||
| 79 | + v = {17, 2}; | ||
| 80 | + name = manager.Search(v.data(), threshold); | ||
| 81 | + EXPECT_EQ(name, "third"); | ||
| 82 | + | ||
| 83 | + threshold = 0.9; | ||
| 84 | + v = {15, 16}; | ||
| 85 | + status = manager.Remove("first"); | ||
| 86 | + ASSERT_TRUE(status); | ||
| 87 | + name = manager.Search(v.data(), threshold); | ||
| 88 | + EXPECT_EQ(name, ""); | ||
| 89 | + | ||
| 90 | + v = {17, 2}; | ||
| 91 | + status = manager.Remove("third"); | ||
| 92 | + ASSERT_TRUE(status); | ||
| 93 | + name = manager.Search(v.data(), threshold); | ||
| 94 | + EXPECT_EQ(name, ""); | ||
| 95 | + | ||
| 96 | + v = {2, 17}; | ||
| 97 | + status = manager.Remove("second"); | ||
| 98 | + ASSERT_TRUE(status); | ||
| 99 | + name = manager.Search(v.data(), threshold); | ||
| 100 | + EXPECT_EQ(name, ""); | ||
| 101 | + | ||
| 102 | + ASSERT_EQ(manager.NumSpeakers(), 0); | ||
| 103 | +} | ||
| 104 | + | ||
| 105 | +TEST(SpeakerEmbeddingManager, Verify) { | ||
| 106 | + int32_t dim = 2; | ||
| 107 | + SpeakerEmbeddingManager manager(dim); | ||
| 108 | + std::vector<float> v1 = {0.1, 0.1}; | ||
| 109 | + std::vector<float> v2 = {0.1, 0.9}; | ||
| 110 | + std::vector<float> v3 = {0.9, 0.1}; | ||
| 111 | + bool status = manager.Add("first", v1.data()); | ||
| 112 | + ASSERT_TRUE(status); | ||
| 113 | + | ||
| 114 | + status = manager.Add("second", v2.data()); | ||
| 115 | + ASSERT_TRUE(status); | ||
| 116 | + | ||
| 117 | + status = manager.Add("third", v3.data()); | ||
| 118 | + ASSERT_TRUE(status); | ||
| 119 | + | ||
| 120 | + std::vector<float> v = {15, 16}; | ||
| 121 | + float threshold = 0.9; | ||
| 122 | + | ||
| 123 | + status = manager.Verify("first", v.data(), threshold); | ||
| 124 | + ASSERT_TRUE(status); | ||
| 125 | + | ||
| 126 | + v = {2, 17}; | ||
| 127 | + status = manager.Verify("first", v.data(), threshold); | ||
| 128 | + ASSERT_FALSE(status); | ||
| 129 | + | ||
| 130 | + status = manager.Verify("second", v.data(), threshold); | ||
| 131 | + ASSERT_TRUE(status); | ||
| 132 | + | ||
| 133 | + v = {17, 2}; | ||
| 134 | + status = manager.Verify("first", v.data(), threshold); | ||
| 135 | + ASSERT_FALSE(status); | ||
| 136 | + | ||
| 137 | + status = manager.Verify("second", v.data(), threshold); | ||
| 138 | + ASSERT_FALSE(status); | ||
| 139 | + | ||
| 140 | + status = manager.Verify("third", v.data(), threshold); | ||
| 141 | + ASSERT_TRUE(status); | ||
| 142 | + | ||
| 143 | + status = manager.Verify("fourth", v.data(), threshold); | ||
| 144 | + ASSERT_FALSE(status); | ||
| 145 | +} | ||
| 146 | + | ||
| 147 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/speaker-embedding-manager.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/speaker-embedding-manager.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <unordered_map> | ||
| 9 | + | ||
| 10 | +#include "Eigen/Dense" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +using FloatMatrix = | ||
| 15 | + Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; | ||
| 16 | + | ||
| 17 | +class SpeakerEmbeddingManager::Impl { | ||
| 18 | + public: | ||
| 19 | + explicit Impl(int32_t dim) : dim_(dim) {} | ||
| 20 | + | ||
| 21 | + bool Add(const std::string &name, const float *p) { | ||
| 22 | + if (name2row_.count(name)) { | ||
| 23 | + // a speaker with the same name already exists | ||
| 24 | + return false; | ||
| 25 | + } | ||
| 26 | + | ||
| 27 | + embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_); | ||
| 28 | + | ||
| 29 | + std::copy(p, p + dim_, &embedding_matrix_.bottomRows(1)(0, 0)); | ||
| 30 | + | ||
| 31 | + embedding_matrix_.bottomRows(1).normalize(); // inplace | ||
| 32 | + | ||
| 33 | + name2row_[name] = embedding_matrix_.rows() - 1; | ||
| 34 | + row2name_[embedding_matrix_.rows() - 1] = name; | ||
| 35 | + | ||
| 36 | + return true; | ||
| 37 | + } | ||
| 38 | + | ||
| 39 | + bool Remove(const std::string &name) { | ||
| 40 | + if (!name2row_.count(name)) { | ||
| 41 | + return false; | ||
| 42 | + } | ||
| 43 | + | ||
| 44 | + int32_t row_idx = name2row_.at(name); | ||
| 45 | + | ||
| 46 | + int32_t num_rows = embedding_matrix_.rows(); | ||
| 47 | + | ||
| 48 | + if (row_idx < num_rows - 1) { | ||
| 49 | + embedding_matrix_.block(row_idx, 0, num_rows - -1 - row_idx, dim_) = | ||
| 50 | + embedding_matrix_.bottomRows(num_rows - 1 - row_idx); | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + embedding_matrix_.conservativeResize(num_rows - 1, dim_); | ||
| 54 | + for (auto &p : name2row_) { | ||
| 55 | + if (p.second > row_idx) { | ||
| 56 | + p.second -= 1; | ||
| 57 | + row2name_[p.second] = p.first; | ||
| 58 | + } | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | + name2row_.erase(name); | ||
| 62 | + row2name_.erase(num_rows - 1); | ||
| 63 | + | ||
| 64 | + return true; | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + std::string Search(const float *p, float threshold) { | ||
| 68 | + if (embedding_matrix_.rows() == 0) { | ||
| 69 | + return {}; | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + Eigen::VectorXf v = | ||
| 73 | + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_); | ||
| 74 | + v.normalize(); | ||
| 75 | + | ||
| 76 | + Eigen::VectorXf scores = embedding_matrix_ * v; | ||
| 77 | + | ||
| 78 | + Eigen::VectorXf::Index max_index; | ||
| 79 | + float max_score = scores.maxCoeff(&max_index); | ||
| 80 | + if (max_score < threshold) { | ||
| 81 | + return {}; | ||
| 82 | + } | ||
| 83 | + | ||
| 84 | + return row2name_.at(max_index); | ||
| 85 | + } | ||
| 86 | + | ||
| 87 | + bool Verify(const std::string &name, const float *p, float threshold) { | ||
| 88 | + if (!name2row_.count(name)) { | ||
| 89 | + return false; | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + int32_t row_idx = name2row_.at(name); | ||
| 93 | + | ||
| 94 | + Eigen::VectorXf v = | ||
| 95 | + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_); | ||
| 96 | + v.normalize(); | ||
| 97 | + | ||
| 98 | + float score = embedding_matrix_.row(row_idx) * v; | ||
| 99 | + | ||
| 100 | + if (score < threshold) { | ||
| 101 | + return false; | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + return true; | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + int32_t NumSpeakers() const { return embedding_matrix_.rows(); } | ||
| 108 | + | ||
| 109 | + private: | ||
| 110 | + int32_t dim_; | ||
| 111 | + FloatMatrix embedding_matrix_; | ||
| 112 | + std::unordered_map<std::string, int32_t> name2row_; | ||
| 113 | + std::unordered_map<int32_t, std::string> row2name_; | ||
| 114 | +}; | ||
| 115 | + | ||
| 116 | +SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim) | ||
| 117 | + : impl_(std::make_unique<Impl>(dim)) {} | ||
| 118 | + | ||
| 119 | +SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default; | ||
| 120 | + | ||
| 121 | +bool SpeakerEmbeddingManager::Add(const std::string &name, | ||
| 122 | + const float *p) const { | ||
| 123 | + return impl_->Add(name, p); | ||
| 124 | +} | ||
| 125 | + | ||
| 126 | +bool SpeakerEmbeddingManager::Remove(const std::string &name) const { | ||
| 127 | + return impl_->Remove(name); | ||
| 128 | +} | ||
| 129 | + | ||
| 130 | +std::string SpeakerEmbeddingManager::Search(const float *p, | ||
| 131 | + float threshold) const { | ||
| 132 | + return impl_->Search(p, threshold); | ||
| 133 | +} | ||
| 134 | + | ||
| 135 | +bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p, | ||
| 136 | + float threshold) const { | ||
| 137 | + return impl_->Verify(name, p, threshold); | ||
| 138 | +} | ||
| 139 | + | ||
| 140 | +int32_t SpeakerEmbeddingManager::NumSpeakers() const { | ||
| 141 | + return impl_->NumSpeakers(); | ||
| 142 | +} | ||
| 143 | + | ||
| 144 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/speaker-embedding-manager.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/speaker-embedding-manager.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
// Keeps a collection of named speaker embeddings and answers identification
// (Search) and verification (Verify) queries against them.
//
// Note that every method is declared const, including Add() and Remove();
// the state they modify is held by the Impl object behind impl_.
class SpeakerEmbeddingManager {
 public:
  // @param dim Embedding dimension.
  explicit SpeakerEmbeddingManager(int32_t dim);
  ~SpeakerEmbeddingManager();

  /* Add the embedding and name of a speaker to the manager.
   *
   * @param name Name of the speaker.
   * @param p Pointer to the embedding. Its length is `dim`.
   * @return Return true if added successfully. Return false if it failed.
   *         At present, the only reason for a failure is that there is
   *         already a speaker with the same `name`.
   */
  bool Add(const std::string &name, const float *p) const;

  /* Remove a speaker by its name.
   *
   * @param name Name of the speaker to remove.
   * @return Return true if it is removed successfully. Return false
   *         if there is no such speaker.
   */
  bool Remove(const std::string &name) const;

  /* It is for speaker identification.
   *
   * It computes the cosine similarity between the given embedding and all
   * stored embeddings and finds the stored embedding that has the largest
   * score, provided that score is above or equal to the threshold.
   *
   * @param p The input embedding. Its length is `dim`.
   * @param threshold A value between 0 and 1.
   * @return If found, return the name of the speaker. Otherwise, return an
   *         empty string.
   */
  std::string Search(const float *p, float threshold) const;

  /* Check whether the input embedding matches the embedding of the input
   * speaker.
   *
   * It is for speaker verification.
   *
   * @param name The target speaker name.
   * @param p The input embedding to check. Its length is `dim`.
   * @param threshold A value between 0 and 1.
   * @return Return true if it matches. Otherwise, it returns false.
   */
  bool Verify(const std::string &name, const float *p, float threshold) const;

  // Number of speakers currently stored in the manager.
  int32_t NumSpeakers() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
| 69 | + | ||
| 70 | +} // namespace sherpa_onnx | ||
| 71 | + | ||
| 72 | +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ |
| @@ -40,7 +40,9 @@ class VoiceActivityDetector::Impl { | @@ -40,7 +40,9 @@ class VoiceActivityDetector::Impl { | ||
| 40 | 40 | ||
| 41 | for (int32_t i = 0; i != k; ++i, p += window_size) { | 41 | for (int32_t i = 0; i != k; ++i, p += window_size) { |
| 42 | buffer_.Push(p, window_size); | 42 | buffer_.Push(p, window_size); |
| 43 | - is_speech = is_speech || model_->IsSpeech(p, window_size); | 43 | + // NOTE(fangjun): Please don't use a very large n. |
| 44 | + bool this_window_is_speech = model_->IsSpeech(p, window_size); | ||
| 45 | + is_speech = is_speech || this_window_is_speech; | ||
| 44 | } | 46 | } |
| 45 | 47 | ||
| 46 | last_ = std::vector<float>( | 48 | last_ = std::vector<float>( |
| @@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl { | @@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl { | ||
| 102 | 104 | ||
| 103 | bool IsSpeechDetected() const { return start_ != -1; } | 105 | bool IsSpeechDetected() const { return start_ != -1; } |
| 104 | 106 | ||
| 107 | + const VadModelConfig &GetConfig() const { return config_; } | ||
| 108 | + | ||
| 105 | private: | 109 | private: |
| 106 | std::queue<SpeechSegment> segments_; | 110 | std::queue<SpeechSegment> segments_; |
| 107 | 111 | ||
| @@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const { | @@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const { | ||
| 146 | return impl_->IsSpeechDetected(); | 150 | return impl_->IsSpeechDetected(); |
| 147 | } | 151 | } |
| 148 | 152 | ||
| 153 | +const VadModelConfig &VoiceActivityDetector::GetConfig() const { | ||
| 154 | + return impl_->GetConfig(); | ||
| 155 | +} | ||
| 156 | + | ||
| 149 | } // namespace sherpa_onnx | 157 | } // namespace sherpa_onnx |
| @@ -43,6 +43,8 @@ class VoiceActivityDetector { | @@ -43,6 +43,8 @@ class VoiceActivityDetector { | ||
| 43 | 43 | ||
| 44 | void Reset(); | 44 | void Reset(); |
| 45 | 45 | ||
| 46 | + const VadModelConfig &GetConfig() const; | ||
| 47 | + | ||
| 46 | private: | 48 | private: |
| 47 | class Impl; | 49 | class Impl; |
| 48 | std::unique_ptr<Impl> impl_; | 50 | std::unique_ptr<Impl> impl_; |
| @@ -30,6 +30,8 @@ pybind11_add_module(_sherpa_onnx | @@ -30,6 +30,8 @@ pybind11_add_module(_sherpa_onnx | ||
| 30 | online-zipformer2-ctc-model-config.cc | 30 | online-zipformer2-ctc-model-config.cc |
| 31 | sherpa-onnx.cc | 31 | sherpa-onnx.cc |
| 32 | silero-vad-model-config.cc | 32 | silero-vad-model-config.cc |
| 33 | + speaker-embedding-extractor.cc | ||
| 34 | + speaker-embedding-manager.cc | ||
| 33 | vad-model-config.cc | 35 | vad-model-config.cc |
| 34 | vad-model.cc | 36 | vad-model.cc |
| 35 | voice-activity-detector.cc | 37 | voice-activity-detector.cc |
| @@ -18,6 +18,8 @@ | @@ -18,6 +18,8 @@ | ||
| 18 | #include "sherpa-onnx/python/csrc/online-model-config.h" | 18 | #include "sherpa-onnx/python/csrc/online-model-config.h" |
| 19 | #include "sherpa-onnx/python/csrc/online-recognizer.h" | 19 | #include "sherpa-onnx/python/csrc/online-recognizer.h" |
| 20 | #include "sherpa-onnx/python/csrc/online-stream.h" | 20 | #include "sherpa-onnx/python/csrc/online-stream.h" |
| 21 | +#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" | ||
| 22 | +#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" | ||
| 21 | #include "sherpa-onnx/python/csrc/vad-model-config.h" | 23 | #include "sherpa-onnx/python/csrc/vad-model-config.h" |
| 22 | #include "sherpa-onnx/python/csrc/vad-model.h" | 24 | #include "sherpa-onnx/python/csrc/vad-model.h" |
| 23 | #include "sherpa-onnx/python/csrc/voice-activity-detector.h" | 25 | #include "sherpa-onnx/python/csrc/voice-activity-detector.h" |
| @@ -48,6 +50,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -48,6 +50,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 48 | PybindVoiceActivityDetector(&m); | 50 | PybindVoiceActivityDetector(&m); |
| 49 | 51 | ||
| 50 | PybindOfflineTts(&m); | 52 | PybindOfflineTts(&m); |
| 53 | + PybindSpeakerEmbeddingExtractor(&m); | ||
| 54 | + PybindSpeakerEmbeddingManager(&m); | ||
| 51 | } | 55 | } |
| 52 | 56 | ||
| 53 | } // namespace sherpa_onnx | 57 | } // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/speaker-embedding-extractor.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) { | ||
| 14 | + using PyClass = SpeakerEmbeddingExtractorConfig; | ||
| 15 | + py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig") | ||
| 16 | + .def(py::init<>()) | ||
| 17 | + .def(py::init<const std::string &, int32_t, bool, const std::string>(), | ||
| 18 | + py::arg("model"), py::arg("num_threads") = 1, | ||
| 19 | + py::arg("debug") = false, py::arg("provider") = "cpu") | ||
| 20 | + .def_readwrite("model", &PyClass::model) | ||
| 21 | + .def_readwrite("num_threads", &PyClass::num_threads) | ||
| 22 | + .def_readwrite("debug", &PyClass::debug) | ||
| 23 | + .def_readwrite("provider", &PyClass::provider) | ||
| 24 | + .def("validate", &PyClass::Validate) | ||
| 25 | + .def("__str__", &PyClass::ToString); | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +void PybindSpeakerEmbeddingExtractor(py::module *m) { | ||
| 29 | + PybindSpeakerEmbeddingExtractorConfig(m); | ||
| 30 | + | ||
| 31 | + using PyClass = SpeakerEmbeddingExtractor; | ||
| 32 | + py::class_<PyClass>(*m, "SpeakerEmbeddingExtractor") | ||
| 33 | + .def(py::init<const SpeakerEmbeddingExtractorConfig &>(), | ||
| 34 | + py::arg("config"), py::call_guard<py::gil_scoped_release>()) | ||
| 35 | + .def_property_readonly("dim", &PyClass::Dim) | ||
| 36 | + .def("create_stream", &PyClass::CreateStream, | ||
| 37 | + py::call_guard<py::gil_scoped_release>()) | ||
| 38 | + .def("compute", &PyClass::Compute, | ||
| 39 | + py::call_guard<py::gil_scoped_release>()) | ||
| 40 | + .def("is_ready", &PyClass::IsReady, | ||
| 41 | + py::call_guard<py::gil_scoped_release>()); | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/speaker-embedding-extractor.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindSpeakerEmbeddingExtractor(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ |
| 1 | +// sherpa-onnx/python/csrc/speaker-embedding-manager.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/speaker-embedding-manager.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindSpeakerEmbeddingManager(py::module *m) { | ||
| 15 | + using PyClass = SpeakerEmbeddingManager; | ||
| 16 | + py::class_<PyClass>(*m, "SpeakerEmbeddingManager") | ||
| 17 | + .def(py::init<int32_t>(), py::arg("dim"), | ||
| 18 | + py::call_guard<py::gil_scoped_release>()) | ||
| 19 | + .def_property_readonly("num_speakers", &PyClass::NumSpeakers) | ||
| 20 | + .def( | ||
| 21 | + "add", | ||
| 22 | + [](const PyClass &self, const std::string &name, | ||
| 23 | + const std::vector<float> &v) -> bool { | ||
| 24 | + return self.Add(name, v.data()); | ||
| 25 | + }, | ||
| 26 | + py::arg("name"), py::arg("v"), | ||
| 27 | + py::call_guard<py::gil_scoped_release>()) | ||
| 28 | + .def( | ||
| 29 | + "remove", | ||
| 30 | + [](const PyClass &self, const std::string &name) -> bool { | ||
| 31 | + return self.Remove(name); | ||
| 32 | + }, | ||
| 33 | + py::arg("name"), py::call_guard<py::gil_scoped_release>()) | ||
| 34 | + .def( | ||
| 35 | + "search", | ||
| 36 | + [](const PyClass &self, const std::vector<float> &v, float threshold) | ||
| 37 | + -> std::string { return self.Search(v.data(), threshold); }, | ||
| 38 | + py::arg("v"), py::arg("threshold"), | ||
| 39 | + py::call_guard<py::gil_scoped_release>()) | ||
| 40 | + .def( | ||
| 41 | + "verify", | ||
| 42 | + [](const PyClass &self, const std::string &name, | ||
| 43 | + const std::vector<float> &v, float threshold) -> bool { | ||
| 44 | + return self.Verify(name, v.data(), threshold); | ||
| 45 | + }, | ||
| 46 | + py::arg("name"), py::arg("v"), py::arg("threshold"), | ||
| 47 | + py::call_guard<py::gil_scoped_release>()); | ||
| 48 | +} | ||
| 49 | + | ||
| 50 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/speaker-embedding-manager.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindSpeakerEmbeddingManager(py::module *m); | ||
| 13 | + | ||
| 14 | +} // namespace sherpa_onnx | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ |
| @@ -32,6 +32,7 @@ void PybindVoiceActivityDetector(py::module *m) { | @@ -32,6 +32,7 @@ void PybindVoiceActivityDetector(py::module *m) { | ||
| 32 | self.AcceptWaveform(samples.data(), samples.size()); | 32 | self.AcceptWaveform(samples.data(), samples.size()); |
| 33 | }, | 33 | }, |
| 34 | py::arg("samples"), py::call_guard<py::gil_scoped_release>()) | 34 | py::arg("samples"), py::call_guard<py::gil_scoped_release>()) |
| 35 | + .def_property_readonly("config", &PyClass::GetConfig) | ||
| 35 | .def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>()) | 36 | .def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>()) |
| 36 | .def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>()) | 37 | .def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>()) |
| 37 | .def("is_speech_detected", &PyClass::IsSpeechDetected, | 38 | .def("is_speech_detected", &PyClass::IsSpeechDetected, |
| @@ -8,6 +8,9 @@ from _sherpa_onnx import ( | @@ -8,6 +8,9 @@ from _sherpa_onnx import ( | ||
| 8 | OfflineTtsVitsModelConfig, | 8 | OfflineTtsVitsModelConfig, |
| 9 | OnlineStream, | 9 | OnlineStream, |
| 10 | SileroVadModelConfig, | 10 | SileroVadModelConfig, |
| 11 | + SpeakerEmbeddingExtractor, | ||
| 12 | + SpeakerEmbeddingExtractorConfig, | ||
| 13 | + SpeakerEmbeddingManager, | ||
| 11 | SpeechSegment, | 14 | SpeechSegment, |
| 12 | VadModel, | 15 | VadModel, |
| 13 | VadModelConfig, | 16 | VadModelConfig, |
-
请 注册 或 登录 后发表评论