thewh1teagle
Committed by GitHub

feat: find best embedding matches (#1102)

@@ -1256,6 +1256,44 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) { @@ -1256,6 +1256,44 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) {
1256 delete[] name; 1256 delete[] name;
1257 } 1257 }
1258 1258
  1259 +const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *
  1260 +SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
  1261 + const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold,
  1262 + int32_t n) {
  1263 + auto matches = p->impl->GetBestMatches(v, threshold, n);
  1264 +
  1265 + if (matches.empty()) {
  1266 + return nullptr;
  1267 + }
  1268 +
  1269 + auto resultMatches =
  1270 + new SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch[matches.size()];
  1271 + for (int i = 0; i < matches.size(); ++i) {
  1272 + resultMatches[i].score = matches[i].score;
  1273 +
  1274 + char *name = new char[matches[i].name.size() + 1];
  1275 + std::copy(matches[i].name.begin(), matches[i].name.end(), name);
  1276 + name[matches[i].name.size()] = '\0';
  1277 +
  1278 + resultMatches[i].name = name;
  1279 + }
  1280 +
  1281 + auto *result = new SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult();
  1282 + result->count = matches.size();
  1283 + result->matches = resultMatches;
  1284 +
  1285 + return result;
  1286 +}
  1287 +
  1288 +void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
  1289 + const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r) {
  1290 + for (int32_t i = 0; i < r->count; ++i) {
  1291 + delete[] r->matches[i].name;
  1292 + }
  1293 + delete[] r->matches;
  1294 + delete r;
  1295 +};
  1296 +
1259 int32_t SherpaOnnxSpeakerEmbeddingManagerVerify( 1297 int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
1260 const SherpaOnnxSpeakerEmbeddingManager *p, const char *name, 1298 const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
1261 const float *v, float threshold) { 1299 const float *v, float threshold) {
@@ -1109,6 +1109,39 @@ SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch( @@ -1109,6 +1109,39 @@ SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch(
1109 SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch( 1109 SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(
1110 const char *name); 1110 const char *name);
1111 1111
  1112 +SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch {
  1113 + float score;
  1114 + const char *name;
  1115 +} SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch;
  1116 +
  1117 +SHERPA_ONNX_API typedef struct
  1118 + SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult {
  1119 + const SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch *matches;
  1120 + int32_t count;
  1121 +} SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult;
  1122 +
  1123 +// Get the best matching speakers whose embeddings match the given
  1124 +// embedding.
  1125 +//
  1126 +// @param p Pointer to the SherpaOnnxSpeakerEmbeddingManager instance.
  1127 +// @param v Pointer to an array containing the embedding vector.
  1128 +// @param threshold Minimum similarity score required for a match (between 0 and
  1129 +// 1).
  1130 +// @param n Number of best matches to retrieve.
  1131 +// @return Returns a pointer to
  1132 +// SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult
  1133 +// containing the best matches found. Returns NULL if no matches are
  1134 +// found. The caller is responsible for freeing the returned pointer
  1135 +// using SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches() to
  1136 +// avoid memory leaks.
  1137 +SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *
  1138 +SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
  1139 + const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold,
  1140 + int32_t n);
  1141 +
  1142 +SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
  1143 + const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r);
  1144 +
1112 // Check whether the input embedding matches the embedding of the input 1145 // Check whether the input embedding matches the embedding of the input
1113 // speaker. 1146 // speaker.
1114 // 1147 //
@@ -131,6 +131,40 @@ class SpeakerEmbeddingManager::Impl { @@ -131,6 +131,40 @@ class SpeakerEmbeddingManager::Impl {
131 return row2name_.at(max_index); 131 return row2name_.at(max_index);
132 } 132 }
133 133
  134 + std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
  135 + int32_t n) {
  136 + std::vector<SpeakerMatch> matches;
  137 +
  138 + if (embedding_matrix_.rows() == 0) {
  139 + return matches;
  140 + }
  141 +
  142 + Eigen::VectorXf v =
  143 + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
  144 + v.normalize();
  145 +
  146 + Eigen::VectorXf scores = embedding_matrix_ * v;
  147 +
  148 + std::vector<std::pair<float, int>> score_indices;
  149 + for (int i = 0; i < scores.size(); ++i) {
  150 + if (scores[i] >= threshold) {
  151 + score_indices.emplace_back(scores[i], i);
  152 + }
  153 + }
  154 +
  155 + std::sort(score_indices.rbegin(), score_indices.rend(),
  156 + [](const auto &a, const auto &b) { return a.first < b.first; });
  157 +
  158 + matches.reserve(score_indices.size());
  159 + for (int i = 0; i < std::min(n, static_cast<int32_t>(score_indices.size()));
  160 + ++i) {
  161 + const auto &pair = score_indices[i];
  162 + matches.push_back({row2name_.at(pair.second), pair.first});
  163 + }
  164 +
  165 + return matches;
  166 + }
  167 +
134 bool Verify(const std::string &name, const float *p, float threshold) { 168 bool Verify(const std::string &name, const float *p, float threshold) {
135 if (!name2row_.count(name)) { 169 if (!name2row_.count(name)) {
136 return false; 170 return false;
@@ -219,6 +253,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p, @@ -219,6 +253,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p,
219 return impl_->Search(p, threshold); 253 return impl_->Search(p, threshold);
220 } 254 }
221 255
  256 +std::vector<SpeakerMatch> SpeakerEmbeddingManager::GetBestMatches(
  257 + const float *p, float threshold, int32_t n) const {
  258 + return impl_->GetBestMatches(p, threshold, n);
  259 +}
  260 +
222 bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p, 261 bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
223 float threshold) const { 262 float threshold) const {
224 return impl_->Verify(name, p, threshold); 263 return impl_->Verify(name, p, threshold);
@@ -9,6 +9,11 @@ @@ -9,6 +9,11 @@
9 #include <string> 9 #include <string>
10 #include <vector> 10 #include <vector>
11 11
  12 +struct SpeakerMatch {
  13 + const std::string name;
  14 + float score;
  15 +};
  16 +
12 namespace sherpa_onnx { 17 namespace sherpa_onnx {
13 18
14 class SpeakerEmbeddingManager { 19 class SpeakerEmbeddingManager {
@@ -62,6 +67,25 @@ class SpeakerEmbeddingManager { @@ -62,6 +67,25 @@ class SpeakerEmbeddingManager {
62 */ 67 */
63 std::string Search(const float *p, float threshold) const; 68 std::string Search(const float *p, float threshold) const;
64 69
  70 + /**
  71 + * It is for speaker identification.
  72 + *
  73 + * It computes the cosine similarity between a given embedding and all
  74 + * other embeddings and finds the embeddings that have the largest scores
  75 + * and the scores are above or equal to the threshold. Returns a vector of
  76 + * SpeakerMatch structures containing the speaker names and scores for the
  77 + * embeddings if found; otherwise, returns an empty vector.
  78 + *
  79 + * @param p A pointer to the input embedding.
  80 + * @param threshold A value between 0 and 1.
  81 + * @param n The number of top matches to return.
  82 + * @return A vector of SpeakerMatch structures. If matches are found, the
  83 + * vector contains the names and scores of the speakers. Otherwise,
  84 + * it returns an empty vector.
  85 + */
  86 + std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
  87 + int32_t n) const;
  88 +
65 /* Check whether the input embedding matches the embedding of the input 89 /* Check whether the input embedding matches the embedding of the input
66 * speaker. 90 * speaker.
67 * 91 *