Fangjun Kuang
Committed by GitHub

Add runtime support for wespeaker models (#516)

#!/usr/bin/env python3
"""
This script shows how to use Python APIs for speaker identification.
Usage:
(1) Prepare a text file containing speaker related files.
Each line in the text file contains two columns. The first column is the
speaker name, while the second column contains the wave file of the speaker.
If the text file contains multiple wave files for the same speaker, then the
embeddings of these files are averaged.
An example text file is given below:
foo /path/to/a.wav
bar /path/to/b.wav
foo /path/to/c.wav
foobar /path/to/d.wav
Each wave file should contain only a single channel; the sample format
should be int16_t; the sample rate can be arbitrary.
(2) Download a model for computing speaker embeddings
Please visit
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
to download a model. An example is given below:
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/zh_cnceleb_resnet34.onnx
Note that `zh` means Chinese, while `en` means English.
(3) Run this script
Assume the filename of the text file is speaker.txt.
python3 ./python-api-examples/speaker-identification.py \
--speaker-file ./speaker.txt \
--model ./zh_cnceleb_resnet34.onnx
"""
import argparse
import queue
import threading
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import sherpa_onnx
import torchaudio
try:
import sounddevice as sd
except ImportError:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--speaker-file",
type=str,
required=True,
help="""Path to the speaker file. Read the help doc at the beginning of this
file for the format.""",
)
parser.add_argument(
"--model",
type=str,
required=True,
help="Path to the model file.",
)
parser.add_argument("--threshold", type=float, default=0.6)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--debug",
type=bool,
default=False,
help="True to show debug messages",
)
parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)
return parser.parse_args()
def load_speaker_embedding_model(args):
config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
model=args.model,
num_threads=args.num_threads,
debug=args.debug,
provider=args.provider,
)
if not config.validate():
raise ValueError(f"Invalid config. {config}")
extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
return extractor
def load_speaker_file(args) -> Dict[str, List[str]]:
if not Path(args.speaker_file).is_file():
raise ValueError(f"--speaker-file {args.speaker_file} does not exist")
ans = defaultdict(list)
with open(args.speaker_file) as f:
for line in f:
line = line.strip()
if not line:
continue
fields = line.split()
if len(fields) != 2:
raise ValueError(f"Invalid line: {line}. Fields: {fields}")
speaker_name, filename = fields
ans[speaker_name].append(filename)
return ans
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
samples, sample_rate = torchaudio.load(filename)
return samples[0].contiguous().numpy(), sample_rate
def compute_speaker_embedding(
filenames: List[str],
extractor: sherpa_onnx.SpeakerEmbeddingExtractor,
) -> np.ndarray:
assert len(filenames) > 0, f"filenames is empty"
ans = None
for filename in filenames:
print(f"processing {filename}")
samples, sample_rate = load_audio(filename)
stream = extractor.create_stream()
stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
stream.input_finished()
assert extractor.is_ready(stream)
embedding = extractor.compute(stream)
embedding = np.array(embedding)
if ans is None:
ans = embedding
else:
ans += embedding
return ans / len(filenames)
g_buffer = queue.Queue()
g_stop = False
g_sample_rate = 16000
g_read_mic_thread = None
def read_mic():
print("Please speak!")
samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms
with sd.InputStream(channels=1, dtype="float32", samplerate=g_sample_rate) as s:
while not g_stop:
samples, _ = s.read(samples_per_read) # a blocking read
g_buffer.put(samples)
def main():
args = get_args()
print(args)
extractor = load_speaker_embedding_model(args)
speaker_file = load_speaker_file(args)
manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
for name, filename_list in speaker_file.items():
embedding = compute_speaker_embedding(
filenames=filename_list,
extractor=extractor,
)
status = manager.add(name, embedding)
if not status:
raise RuntimeError(f"Failed to register speaker {name}")
devices = sd.query_devices()
if len(devices) == 0:
print("No microphone devices found")
sys.exit(0)
print(devices)
default_input_device_idx = sd.default.device[0]
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
global g_stop
global g_read_mic_thread
while True:
key = input("Press enter to start recording")
if key.lower() in ("q", "quit"):
g_stop = True
break
g_stop = False
g_buffer.queue.clear()
g_read_mic_thread = threading.Thread(target=read_mic)
g_read_mic_thread.start()
input("Press enter to stop recording")
g_stop = True
g_read_mic_thread.join()
print("Compute embedding")
stream = extractor.create_stream()
while not g_buffer.empty():
samples = g_buffer.get()
stream.accept_waveform(sample_rate=g_sample_rate, waveform=samples)
stream.input_finished()
embedding = extractor.compute(stream)
embedding = np.array(embedding)
name = manager.search(embedding, threshold=args.threshold)
if not name:
name = "unknown"
print(f"Predicted name: {name}")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")
g_stop = True
if g_read_mic_thread.is_alive():
g_read_mic_thread.join()
... ...
... ... @@ -96,6 +96,14 @@ set(sources
wave-reader.cc
)
# speaker embedding extractor
list(APPEND sources
speaker-embedding-extractor-impl.cc
speaker-embedding-extractor-wespeaker-model.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
)
list(APPEND sources
lexicon.cc
offline-tts-impl.cc
... ... @@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS)
utfcpp-test.cc
)
list(APPEND sherpa_onnx_test_srcs
speaker-embedding-manager-test.cc
)
function(sherpa_onnx_add_test source)
get_filename_component(name ${source} NAME_WE)
set(target_name ${name})
... ...
... ... @@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) {
auto stop = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
duration.count());
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num,
static_cast<int32_t>(duration.count()));
}
}
... ...
... ... @@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
} // namespace sherpa_onnx
... ...
... ... @@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/vad-model-config.h"
namespace sherpa_onnx {
... ... @@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SESSION_H_
... ...
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h"
namespace sherpa_onnx {
namespace {
enum class ModelType {
kWeSpeaker,
kUnkown,
};
} // namespace
static ModelType GetModelType(char *model_data, size_t model_data_length,
bool debug) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
Ort::SessionOptions sess_opts;
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
sess_opts);
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
if (debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
if (!model_type) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you have added metadata to the model.\n\n"
"For instance, you can use\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
"add_meta_data.py"
"to add metadata to models from WeSpeaker\n");
return ModelType::kUnkown;
}
if (model_type.get() == std::string("wespeaker")) {
return ModelType::kWeSpeaker;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
}
}
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
const SpeakerEmbeddingExtractorConfig &config) {
ModelType model_type = ModelType::kUnkown;
{
auto buffer = ReadFile(config.model);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kWeSpeaker:
return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
return nullptr;
}
// unreachable code
return nullptr;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace sherpa_onnx {
class SpeakerEmbeddingExtractorImpl {
public:
virtual ~SpeakerEmbeddingExtractorImpl() = default;
static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
const SpeakerEmbeddingExtractorConfig &config);
virtual int32_t Dim() const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
virtual bool IsReady(OnlineStream *s) const = 0;
virtual std::vector<float> Compute(OnlineStream *s) const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
... ...
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
namespace sherpa_onnx {
class SpeakerEmbeddingExtractorWeSpeakerImpl
: public SpeakerEmbeddingExtractorImpl {
public:
explicit SpeakerEmbeddingExtractorWeSpeakerImpl(
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {}
int32_t Dim() const override { return model_.GetMetaData().output_dim; }
std::unique_ptr<OnlineStream> CreateStream() const override {
FeatureExtractorConfig feat_config;
auto meta_data = model_.GetMetaData();
feat_config.sampling_rate = meta_data.sample_rate;
feat_config.normalize_samples = meta_data.normalize_features;
return std::make_unique<OnlineStream>(feat_config);
}
bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() < s->NumFramesReady();
}
std::vector<float> Compute(OnlineStream *s) const override {
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
if (num_frames <= 0) {
SHERPA_ONNX_LOGE(
"Please make sure IsReady(s) returns true. num_frames: %d",
num_frames);
return {};
}
std::vector<float> features =
s->GetFrames(s->GetNumProcessedFrames(), num_frames);
s->GetNumProcessedFrames() += num_frames;
int32_t feat_dim = features.size() / num_frames;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
Ort::Value x =
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
x_shape.data(), x_shape.size());
Ort::Value embedding = model_.Compute(std::move(x));
std::vector<int64_t> embedding_shape =
embedding.GetTensorTypeAndShapeInfo().GetShape();
std::vector<float> ans(embedding_shape[1]);
std::copy(embedding.GetTensorData<float>(),
embedding.GetTensorData<float>() + ans.size(), ans.begin());
return ans;
}
private:
SpeakerEmbeddingExtractorWeSpeakerModel model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
... ...
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
#include <cstdint>
#include <string>
namespace sherpa_onnx {
struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
int32_t output_dim = 0;
int32_t sample_rate = 0;
int32_t normalize_features = 0;
std::string language;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
... ...
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
namespace sherpa_onnx {
class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
public:
explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.model);
Init(buf.data(), buf.size());
}
}
Ort::Value Compute(Ort::Value x) const {
std::array<Ort::Value, 1> inputs = {std::move(x)};
auto outputs =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(outputs[0]);
}
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const {
return meta_data_;
}
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features,
"normalize_features");
SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
std::string framework;
SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
if (framework != "wespeaker") {
SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s",
framework.c_str());
exit(-1);
}
}
private:
SpeakerEmbeddingExtractorConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_;
};
SpeakerEmbeddingExtractorWeSpeakerModel::
SpeakerEmbeddingExtractorWeSpeakerModel(
const SpeakerEmbeddingExtractorConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
SpeakerEmbeddingExtractorWeSpeakerModel::
~SpeakerEmbeddingExtractorWeSpeakerModel() = default;
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &
SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const {
return impl_->GetMetaData();
}
Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(
Ort::Value x) const {
return impl_->Compute(std::move(x));
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
#include <memory>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace sherpa_onnx {
class SpeakerEmbeddingExtractorWeSpeakerModel {
public:
explicit SpeakerEmbeddingExtractorWeSpeakerModel(
const SpeakerEmbeddingExtractorConfig &config);
~SpeakerEmbeddingExtractorWeSpeakerModel();
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const;
/**
* @param x A float32 tensor of shape (N, T, C)
* @return A float32 tensor of shape (N, C)
*/
Ort::Value Compute(Ort::Value x) const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
... ...
// sherpa-onnx/csrc/speaker-embedding-extractor.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
namespace sherpa_onnx {
void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
po->Register("model", &model, "Path to the speaker embedding model.");
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
po->Register("debug", &debug,
"true to print model information while loading it.");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
}
bool SpeakerEmbeddingExtractorConfig::Validate() const {
if (model.empty()) {
SHERPA_ONNX_LOGE("Please provide --speaker-embedding-model");
return false;
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist",
model.c_str());
return false;
}
return true;
}
std::string SpeakerEmbeddingExtractorConfig::ToString() const {
std::ostringstream os;
os << "SpeakerEmbeddingExtractorConfig(";
os << "model=\"" << model << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";
return os.str();
}
SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor(
const SpeakerEmbeddingExtractorConfig &config)
: impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {}
SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default;
int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); }
std::unique_ptr<OnlineStream> SpeakerEmbeddingExtractor::CreateStream() const {
return impl_->CreateStream();
}
bool SpeakerEmbeddingExtractor::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
std::vector<float> SpeakerEmbeddingExtractor::Compute(OnlineStream *s) const {
return impl_->Compute(s);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/speaker-embedding-extractor.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct SpeakerEmbeddingExtractorConfig {
std::string model;
int32_t num_threads = 1;
bool debug = false;
std::string provider = "cpu";
SpeakerEmbeddingExtractorConfig() = default;
SpeakerEmbeddingExtractorConfig(const std::string &model, int32_t num_threads,
bool debug, const std::string &provider)
: model(model),
num_threads(num_threads),
debug(debug),
provider(provider) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
class SpeakerEmbeddingExtractorImpl;
class SpeakerEmbeddingExtractor {
public:
explicit SpeakerEmbeddingExtractor(
const SpeakerEmbeddingExtractorConfig &config);
~SpeakerEmbeddingExtractor();
// Return the dimension of the embedding
int32_t Dim() const;
// Create a stream to accept audio samples and compute features
std::unique_ptr<OnlineStream> CreateStream() const;
// Return true if there are feature frames in OnlineStream that
// can be used to compute embeddings.
bool IsReady(OnlineStream *s) const;
// Compute the speaker embedding from the available unprocessed features
// of the given stream
//
// You have to ensure IsReady(s) returns true before you call this method.
std::vector<float> Compute(OnlineStream *s) const;
private:
std::unique_ptr<SpeakerEmbeddingExtractorImpl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
... ...
// sherpa-onnx/csrc/speaker-embedding-manager-test.cc
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "gtest/gtest.h"
namespace sherpa_onnx {
TEST(SpeakerEmbeddingManager, AddAndRemove) {
int32_t dim = 2;
SpeakerEmbeddingManager manager(dim);
std::vector<float> v = {0.1, 0.1};
bool status = manager.Add("first", v.data());
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 1);
// duplicate
status = manager.Add("first", v.data());
ASSERT_FALSE(status);
ASSERT_EQ(manager.NumSpeakers(), 1);
// non-duplicate
v = {0.1, 0.9};
status = manager.Add("second", v.data());
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 2);
// do not exist
status = manager.Remove("third");
ASSERT_FALSE(status);
status = manager.Remove("first");
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 1);
v = {0.1, 0.1};
status = manager.Add("first", v.data());
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 2);
status = manager.Remove("first");
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 1);
status = manager.Remove("second");
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 0);
}
TEST(SpeakerEmbeddingManager, Search) {
int32_t dim = 2;
SpeakerEmbeddingManager manager(dim);
std::vector<float> v1 = {0.1, 0.1};
std::vector<float> v2 = {0.1, 0.9};
std::vector<float> v3 = {0.9, 0.1};
bool status = manager.Add("first", v1.data());
ASSERT_TRUE(status);
status = manager.Add("second", v2.data());
ASSERT_TRUE(status);
status = manager.Add("third", v3.data());
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 3);
std::vector<float> v = {15, 16};
float threshold = 0.9;
std::string name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "first");
v = {2, 17};
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "second");
v = {17, 2};
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "third");
threshold = 0.9;
v = {15, 16};
status = manager.Remove("first");
ASSERT_TRUE(status);
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "");
v = {17, 2};
status = manager.Remove("third");
ASSERT_TRUE(status);
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "");
v = {2, 17};
status = manager.Remove("second");
ASSERT_TRUE(status);
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "");
ASSERT_EQ(manager.NumSpeakers(), 0);
}
TEST(SpeakerEmbeddingManager, Verify) {
int32_t dim = 2;
SpeakerEmbeddingManager manager(dim);
std::vector<float> v1 = {0.1, 0.1};
std::vector<float> v2 = {0.1, 0.9};
std::vector<float> v3 = {0.9, 0.1};
bool status = manager.Add("first", v1.data());
ASSERT_TRUE(status);
status = manager.Add("second", v2.data());
ASSERT_TRUE(status);
status = manager.Add("third", v3.data());
ASSERT_TRUE(status);
std::vector<float> v = {15, 16};
float threshold = 0.9;
status = manager.Verify("first", v.data(), threshold);
ASSERT_TRUE(status);
v = {2, 17};
status = manager.Verify("first", v.data(), threshold);
ASSERT_FALSE(status);
status = manager.Verify("second", v.data(), threshold);
ASSERT_TRUE(status);
v = {17, 2};
status = manager.Verify("first", v.data(), threshold);
ASSERT_FALSE(status);
status = manager.Verify("second", v.data(), threshold);
ASSERT_FALSE(status);
status = manager.Verify("third", v.data(), threshold);
ASSERT_TRUE(status);
status = manager.Verify("fourth", v.data(), threshold);
ASSERT_FALSE(status);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/speaker-embedding-manager.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include <algorithm>
#include <unordered_map>
#include "Eigen/Dense"
namespace sherpa_onnx {
using FloatMatrix =
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
class SpeakerEmbeddingManager::Impl {
public:
explicit Impl(int32_t dim) : dim_(dim) {}
bool Add(const std::string &name, const float *p) {
if (name2row_.count(name)) {
// a speaker with the same name already exists
return false;
}
embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_);
std::copy(p, p + dim_, &embedding_matrix_.bottomRows(1)(0, 0));
embedding_matrix_.bottomRows(1).normalize(); // inplace
name2row_[name] = embedding_matrix_.rows() - 1;
row2name_[embedding_matrix_.rows() - 1] = name;
return true;
}
bool Remove(const std::string &name) {
if (!name2row_.count(name)) {
return false;
}
int32_t row_idx = name2row_.at(name);
int32_t num_rows = embedding_matrix_.rows();
if (row_idx < num_rows - 1) {
embedding_matrix_.block(row_idx, 0, num_rows - -1 - row_idx, dim_) =
embedding_matrix_.bottomRows(num_rows - 1 - row_idx);
}
embedding_matrix_.conservativeResize(num_rows - 1, dim_);
for (auto &p : name2row_) {
if (p.second > row_idx) {
p.second -= 1;
row2name_[p.second] = p.first;
}
}
name2row_.erase(name);
row2name_.erase(num_rows - 1);
return true;
}
std::string Search(const float *p, float threshold) {
if (embedding_matrix_.rows() == 0) {
return {};
}
Eigen::VectorXf v =
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
v.normalize();
Eigen::VectorXf scores = embedding_matrix_ * v;
Eigen::VectorXf::Index max_index;
float max_score = scores.maxCoeff(&max_index);
if (max_score < threshold) {
return {};
}
return row2name_.at(max_index);
}
bool Verify(const std::string &name, const float *p, float threshold) {
if (!name2row_.count(name)) {
return false;
}
int32_t row_idx = name2row_.at(name);
Eigen::VectorXf v =
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
v.normalize();
float score = embedding_matrix_.row(row_idx) * v;
if (score < threshold) {
return false;
}
return true;
}
int32_t NumSpeakers() const { return embedding_matrix_.rows(); }
private:
int32_t dim_;
FloatMatrix embedding_matrix_;
std::unordered_map<std::string, int32_t> name2row_;
std::unordered_map<int32_t, std::string> row2name_;
};
SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim)
: impl_(std::make_unique<Impl>(dim)) {}
SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default;
bool SpeakerEmbeddingManager::Add(const std::string &name,
const float *p) const {
return impl_->Add(name, p);
}
bool SpeakerEmbeddingManager::Remove(const std::string &name) const {
return impl_->Remove(name);
}
std::string SpeakerEmbeddingManager::Search(const float *p,
float threshold) const {
return impl_->Search(p, threshold);
}
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
float threshold) const {
return impl_->Verify(name, p, threshold);
}
int32_t SpeakerEmbeddingManager::NumSpeakers() const {
return impl_->NumSpeakers();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/speaker-embedding-manager.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
#include <memory>
#include <string>
namespace sherpa_onnx {
class SpeakerEmbeddingManager {
public:
// @param dim Embedding dimension.
explicit SpeakerEmbeddingManager(int32_t dim);
~SpeakerEmbeddingManager();
/* Add the embedding and name of a speaker to the manager.
*
* @param name Name of the speaker
* @param p Pointer to the embedding. Its length is `dim`.
* @return Return true if added successfully. Return false if it failed.
* At present, the only reason for a failure is that there is already
* a speaker with the same `name`.
*/
bool Add(const std::string &name, const float *p) const;
/* Remove a speaker by its name.
*
* @param name Name of the speaker to remove.
* @return Return true if it is removed successfully. Return false
* if there is no such a speaker.
*/
bool Remove(const std::string &name) const;
/** It is for speaker identification.
*
* It computes the cosine similarity between and given embedding and all
* other embeddings and find the embedding that has the largest score
* and the score is above or equal to threshold. Return the speaker
* name for the embedding if found; otherwise, it returns an empty string.
*
* @param p The input embedding.
* @param threshold A value between 0 and 1.
* @param If found, return the name of the speaker. Otherwise, return an
* empty string.
*/
std::string Search(const float *p, float threshold) const;
/* Check whether the input embedding matches the embedding of the input
* speaker.
*
* It is for speaker verification.
*
* @param name The target speaker name.
* @param p The input embedding to check.
* @param threshold A value between 0 and 1.
* @return Return true if it matches. Otherwise, it returns false.
*/
bool Verify(const std::string &name, const float *p, float threshold) const;
int32_t NumSpeakers() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
... ...
... ... @@ -40,7 +40,9 @@ class VoiceActivityDetector::Impl {
for (int32_t i = 0; i != k; ++i, p += window_size) {
buffer_.Push(p, window_size);
is_speech = is_speech || model_->IsSpeech(p, window_size);
// NOTE(fangjun): Please don't use a very large n.
bool this_window_is_speech = model_->IsSpeech(p, window_size);
is_speech = is_speech || this_window_is_speech;
}
last_ = std::vector<float>(
... ... @@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl {
bool IsSpeechDetected() const { return start_ != -1; }
const VadModelConfig &GetConfig() const { return config_; }
private:
std::queue<SpeechSegment> segments_;
... ... @@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const {
return impl_->IsSpeechDetected();
}
const VadModelConfig &VoiceActivityDetector::GetConfig() const {
return impl_->GetConfig();
}
} // namespace sherpa_onnx
... ...
... ... @@ -43,6 +43,8 @@ class VoiceActivityDetector {
void Reset();
const VadModelConfig &GetConfig() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
... ... @@ -30,6 +30,8 @@ pybind11_add_module(_sherpa_onnx
online-zipformer2-ctc-model-config.cc
sherpa-onnx.cc
silero-vad-model-config.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
vad-model-config.cc
vad-model.cc
voice-activity-detector.cc
... ...
// sherpa-onnx/python/csrc/online-recongizer.h
// sherpa-onnx/python/csrc/online-recognizer.h
//
// Copyright (c) 2023 Xiaomi Corporation
... ...
... ... @@ -18,6 +18,8 @@
#include "sherpa-onnx/python/csrc/online-model-config.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/python/csrc/vad-model-config.h"
#include "sherpa-onnx/python/csrc/vad-model.h"
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"
... ... @@ -48,6 +50,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindVoiceActivityDetector(&m);
PybindOfflineTts(&m);
PybindSpeakerEmbeddingExtractor(&m);
PybindSpeakerEmbeddingManager(&m);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
#include <string>
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace sherpa_onnx {
static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
using PyClass = SpeakerEmbeddingExtractorConfig;
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
.def(py::init<>())
.def(py::init<const std::string &, int32_t, bool, const std::string>(),
py::arg("model"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("model", &PyClass::model)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
void PybindSpeakerEmbeddingExtractor(py::module *m) {
PybindSpeakerEmbeddingExtractorConfig(m);
using PyClass = SpeakerEmbeddingExtractor;
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractor")
.def(py::init<const SpeakerEmbeddingExtractorConfig &>(),
py::arg("config"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly("dim", &PyClass::Dim)
.def("create_stream", &PyClass::CreateStream,
py::call_guard<py::gil_scoped_release>())
.def("compute", &PyClass::Compute,
py::call_guard<py::gil_scoped_release>())
.def("is_ready", &PyClass::IsReady,
py::call_guard<py::gil_scoped_release>());
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/speaker-embedding-extractor.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindSpeakerEmbeddingExtractor(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
... ...
// sherpa-onnx/python/csrc/speaker-embedding-manager.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
namespace sherpa_onnx {
void PybindSpeakerEmbeddingManager(py::module *m) {
using PyClass = SpeakerEmbeddingManager;
py::class_<PyClass>(*m, "SpeakerEmbeddingManager")
.def(py::init<int32_t>(), py::arg("dim"),
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("num_speakers", &PyClass::NumSpeakers)
.def(
"add",
[](const PyClass &self, const std::string &name,
const std::vector<float> &v) -> bool {
return self.Add(name, v.data());
},
py::arg("name"), py::arg("v"),
py::call_guard<py::gil_scoped_release>())
.def(
"remove",
[](const PyClass &self, const std::string &name) -> bool {
return self.Remove(name);
},
py::arg("name"), py::call_guard<py::gil_scoped_release>())
.def(
"search",
[](const PyClass &self, const std::vector<float> &v, float threshold)
-> std::string { return self.Search(v.data(), threshold); },
py::arg("v"), py::arg("threshold"),
py::call_guard<py::gil_scoped_release>())
.def(
"verify",
[](const PyClass &self, const std::string &name,
const std::vector<float> &v, float threshold) -> bool {
return self.Verify(name, v.data(), threshold);
},
py::arg("name"), py::arg("v"), py::arg("threshold"),
py::call_guard<py::gil_scoped_release>());
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/speaker-embedding-manager.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindSpeakerEmbeddingManager(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
... ...
... ... @@ -32,6 +32,7 @@ void PybindVoiceActivityDetector(py::module *m) {
self.AcceptWaveform(samples.data(), samples.size());
},
py::arg("samples"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly("config", &PyClass::GetConfig)
.def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
.def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
.def("is_speech_detected", &PyClass::IsSpeechDetected,
... ...
... ... @@ -8,6 +8,9 @@ from _sherpa_onnx import (
OfflineTtsVitsModelConfig,
OnlineStream,
SileroVadModelConfig,
SpeakerEmbeddingExtractor,
SpeakerEmbeddingExtractorConfig,
SpeakerEmbeddingManager,
SpeechSegment,
VadModel,
VadModelConfig,
... ...