speaker-embedding-manager.cc
3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// sherpa-onnx/csrc/speaker-embedding-manager.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include <algorithm>
#include <unordered_map>
#include "Eigen/Dense"
namespace sherpa_onnx {
// Dynamically sized, row-major float matrix. In this file each row holds one
// registered speaker embedding, so a row is contiguous in memory.
using FloatMatrix =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
// Stores one L2-normalized embedding per registered speaker as a row of a
// row-major matrix and keeps two maps (name -> row, row -> name) in sync with
// it. Because both the stored rows and the query are normalized, the scores
// computed by Search() and Verify() are dot products of normalized vectors.
class SpeakerEmbeddingManager::Impl {
 public:
  // @param dim Number of floats in every embedding handled by this object.
  explicit Impl(int32_t dim) : dim_(dim) {}

  // Registers a new speaker.
  //
  // @param name Unique speaker name.
  // @param p Pointer to dim_ floats; they are copied and normalized
  //          internally, the caller's buffer is not modified.
  // @return false if a speaker called `name` already exists (nothing is
  //         changed in that case); true otherwise.
  bool Add(const std::string &name, const float *p) {
    if (name2row_.find(name) != name2row_.end()) {
      // a speaker with the same name already exists
      return false;
    }

    const int32_t new_row = static_cast<int32_t>(embedding_matrix_.rows());
    embedding_matrix_.conservativeResize(new_row + 1, dim_);

    // The matrix is row-major, so the new row is dim_ contiguous floats.
    std::copy(p, p + dim_, &embedding_matrix_(new_row, 0));
    embedding_matrix_.row(new_row).normalize();  // inplace

    name2row_[name] = new_row;
    row2name_[new_row] = name;

    return true;
  }

  // Unregisters a speaker and compacts the matrix so that rows stay dense.
  //
  // @param name Name of the speaker to remove.
  // @return false if no such speaker exists; true otherwise.
  bool Remove(const std::string &name) {
    auto it = name2row_.find(name);
    if (it == name2row_.end()) {
      return false;
    }

    const int32_t row_idx = it->second;
    const int32_t num_rows = static_cast<int32_t>(embedding_matrix_.rows());

    // Number of rows located below the one being removed.
    const int32_t num_below = num_rows - 1 - row_idx;

    if (num_below > 0) {
      // Shift every following row up by one. Source and destination overlap,
      // so materialize the source first to avoid aliasing.
      embedding_matrix_.block(row_idx, 0, num_below, dim_) =
          embedding_matrix_.bottomRows(num_below).eval();
    }

    embedding_matrix_.conservativeResize(num_rows - 1, dim_);

    name2row_.erase(it);

    // Every speaker that was stored below the removed row moved up by one.
    for (auto &p : name2row_) {
      if (p.second > row_idx) {
        p.second -= 1;
        row2name_[p.second] = p.first;
      }
    }

    // The former last row index no longer exists.
    row2name_.erase(num_rows - 1);

    return true;
  }

  // Finds the registered speaker whose embedding scores highest against `p`.
  //
  // @param p Pointer to dim_ floats; not modified.
  // @param threshold Minimum score required for a match.
  // @return The best matching speaker's name, or an empty string if no
  //         speaker is registered or the best score is below `threshold`.
  std::string Search(const float *p, float threshold) {
    if (embedding_matrix_.rows() == 0) {
      return {};
    }

    // Work on a normalized copy so the caller's buffer stays untouched.
    Eigen::VectorXf v = Eigen::Map<const Eigen::VectorXf>(p, dim_);
    v.normalize();

    Eigen::VectorXf scores = embedding_matrix_ * v;

    Eigen::VectorXf::Index max_index;
    float max_score = scores.maxCoeff(&max_index);
    if (max_score < threshold) {
      return {};
    }

    return row2name_.at(static_cast<int32_t>(max_index));
  }

  // Checks whether `p` matches the embedding stored for `name`.
  //
  // @param name Name of a registered speaker.
  // @param p Pointer to dim_ floats; not modified.
  // @param threshold Minimum score required for a match.
  // @return false if `name` is unknown or the score is below `threshold`;
  //         true otherwise.
  bool Verify(const std::string &name, const float *p, float threshold) {
    auto it = name2row_.find(name);
    if (it == name2row_.end()) {
      return false;
    }

    const int32_t row_idx = it->second;

    // Work on a normalized copy so the caller's buffer stays untouched.
    Eigen::VectorXf v = Eigen::Map<const Eigen::VectorXf>(p, dim_);
    v.normalize();

    float score = embedding_matrix_.row(row_idx) * v;

    if (score < threshold) {
      return false;
    }

    return true;
  }

  // @return Number of speakers currently registered.
  int32_t NumSpeakers() const {
    return static_cast<int32_t>(embedding_matrix_.rows());
  }

 private:
  int32_t dim_;                   // length of every embedding
  FloatMatrix embedding_matrix_;  // one normalized embedding per row
  std::unordered_map<std::string, int32_t> name2row_;
  std::unordered_map<int32_t, std::string> row2name_;
};
// Creates the manager and its implementation object.
//
// @param dim Passed unchanged to the Impl constructor, which uses it as the
//            number of floats per embedding.
SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim)
    : impl_(std::make_unique<Impl>(dim)) {}
// Defaulted in this translation unit, where Impl is a complete type.
// NOTE(review): presumably required because impl_ is a std::unique_ptr<Impl>
// declared in the header with Impl only forward-declared -- confirm there.
SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default;
// Registers a speaker by delegating to the implementation object.
//
// @param name Speaker name.
// @param p Pointer to the embedding values.
// @return Whatever Impl::Add() returns: false when `name` is already
//         registered, true otherwise.
bool SpeakerEmbeddingManager::Add(const std::string &name,
                                  const float *p) const {
  const bool added = impl_->Add(name, p);
  return added;
}
// Unregisters a speaker by delegating to the implementation object.
//
// @param name Speaker name.
// @return Whatever Impl::Remove() returns: false when `name` is not
//         registered, true otherwise.
bool SpeakerEmbeddingManager::Remove(const std::string &name) const {
  const bool removed = impl_->Remove(name);
  return removed;
}
// Looks up the best matching speaker by delegating to the implementation
// object.
//
// @param p Pointer to the query embedding values.
// @param threshold Minimum score required for a match.
// @return Whatever Impl::Search() returns: the matched speaker's name, or an
//         empty string when there is no match.
std::string SpeakerEmbeddingManager::Search(const float *p,
                                            float threshold) const {
  std::string matched_name = impl_->Search(p, threshold);
  return matched_name;
}
// Checks an embedding against a named speaker by delegating to the
// implementation object.
//
// @param name Speaker name.
// @param p Pointer to the query embedding values.
// @param threshold Minimum score required for a match.
// @return Whatever Impl::Verify() returns: false when `name` is unknown or
//         the score is below `threshold`, true otherwise.
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
                                     float threshold) const {
  const bool accepted = impl_->Verify(name, p, threshold);
  return accepted;
}
// @return Number of registered speakers, as reported by the implementation
//         object.
int32_t SpeakerEmbeddingManager::NumSpeakers() const {
  const int32_t count = impl_->NumSpeakers();
  return count;
}
} // namespace sherpa_onnx