Committed by
GitHub
feat: find best embedding matches (#1102)
正在显示
4 个修改的文件
包含
134 行增加
和
0 行删除
| @@ -1256,6 +1256,44 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) { | @@ -1256,6 +1256,44 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) { | ||
| 1256 | delete[] name; | 1256 | delete[] name; |
| 1257 | } | 1257 | } |
| 1258 | 1258 | ||
| 1259 | +const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult * | ||
| 1260 | +SherpaOnnxSpeakerEmbeddingManagerGetBestMatches( | ||
| 1261 | + const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold, | ||
| 1262 | + int32_t n) { | ||
| 1263 | + auto matches = p->impl->GetBestMatches(v, threshold, n); | ||
| 1264 | + | ||
| 1265 | + if (matches.empty()) { | ||
| 1266 | + return nullptr; | ||
| 1267 | + } | ||
| 1268 | + | ||
| 1269 | + auto resultMatches = | ||
| 1270 | + new SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch[matches.size()]; | ||
| 1271 | + for (int i = 0; i < matches.size(); ++i) { | ||
| 1272 | + resultMatches[i].score = matches[i].score; | ||
| 1273 | + | ||
| 1274 | + char *name = new char[matches[i].name.size() + 1]; | ||
| 1275 | + std::copy(matches[i].name.begin(), matches[i].name.end(), name); | ||
| 1276 | + name[matches[i].name.size()] = '\0'; | ||
| 1277 | + | ||
| 1278 | + resultMatches[i].name = name; | ||
| 1279 | + } | ||
| 1280 | + | ||
| 1281 | + auto *result = new SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult(); | ||
| 1282 | + result->count = matches.size(); | ||
| 1283 | + result->matches = resultMatches; | ||
| 1284 | + | ||
| 1285 | + return result; | ||
| 1286 | +} | ||
| 1287 | + | ||
| 1288 | +void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( | ||
| 1289 | + const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r) { | ||
| 1290 | + for (int32_t i = 0; i < r->count; ++i) { | ||
| 1291 | + delete[] r->matches[i].name; | ||
| 1292 | + } | ||
| 1293 | + delete[] r->matches; | ||
| 1294 | + delete r; | ||
| 1295 | +}; | ||
| 1296 | + | ||
| 1259 | int32_t SherpaOnnxSpeakerEmbeddingManagerVerify( | 1297 | int32_t SherpaOnnxSpeakerEmbeddingManagerVerify( |
| 1260 | const SherpaOnnxSpeakerEmbeddingManager *p, const char *name, | 1298 | const SherpaOnnxSpeakerEmbeddingManager *p, const char *name, |
| 1261 | const float *v, float threshold) { | 1299 | const float *v, float threshold) { |
| @@ -1109,6 +1109,39 @@ SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch( | @@ -1109,6 +1109,39 @@ SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch( | ||
| 1109 | SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch( | 1109 | SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch( |
| 1110 | const char *name); | 1110 | const char *name); |
| 1111 | 1111 | ||
| 1112 | +SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch { | ||
| 1113 | + float score; | ||
| 1114 | + const char *name; | ||
| 1115 | +} SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch; | ||
| 1116 | + | ||
| 1117 | +SHERPA_ONNX_API typedef struct | ||
| 1118 | + SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult { | ||
| 1119 | + const SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch *matches; | ||
| 1120 | + int32_t count; | ||
| 1121 | +} SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult; | ||
| 1122 | + | ||
| 1123 | +// Get the best matching speakers whose embeddings match the given | ||
| 1124 | +// embedding. | ||
| 1125 | +// | ||
| 1126 | +// @param p Pointer to the SherpaOnnxSpeakerEmbeddingManager instance. | ||
| 1127 | +// @param v Pointer to an array containing the embedding vector. | ||
| 1128 | +// @param threshold Minimum similarity score required for a match (between 0 and | ||
| 1129 | +// 1). | ||
| 1130 | +// @param n Number of best matches to retrieve. | ||
| 1131 | +// @return Returns a pointer to | ||
| 1132 | +// SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult | ||
| 1133 | +// containing the best matches found. Returns NULL if no matches are | ||
| 1134 | +// found. The caller is responsible for freeing the returned pointer | ||
| 1135 | +// using SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches() to | ||
| 1136 | +// avoid memory leaks. | ||
| 1137 | +SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult * | ||
| 1138 | +SherpaOnnxSpeakerEmbeddingManagerGetBestMatches( | ||
| 1139 | + const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold, | ||
| 1140 | + int32_t n); | ||
| 1141 | + | ||
| 1142 | +SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( | ||
| 1143 | + const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r); | ||
| 1144 | + | ||
| 1112 | // Check whether the input embedding matches the embedding of the input | 1145 | // Check whether the input embedding matches the embedding of the input |
| 1113 | // speaker. | 1146 | // speaker. |
| 1114 | // | 1147 | // |
| @@ -131,6 +131,40 @@ class SpeakerEmbeddingManager::Impl { | @@ -131,6 +131,40 @@ class SpeakerEmbeddingManager::Impl { | ||
| 131 | return row2name_.at(max_index); | 131 | return row2name_.at(max_index); |
| 132 | } | 132 | } |
| 133 | 133 | ||
| 134 | + std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold, | ||
| 135 | + int32_t n) { | ||
| 136 | + std::vector<SpeakerMatch> matches; | ||
| 137 | + | ||
| 138 | + if (embedding_matrix_.rows() == 0) { | ||
| 139 | + return matches; | ||
| 140 | + } | ||
| 141 | + | ||
| 142 | + Eigen::VectorXf v = | ||
| 143 | + Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_); | ||
| 144 | + v.normalize(); | ||
| 145 | + | ||
| 146 | + Eigen::VectorXf scores = embedding_matrix_ * v; | ||
| 147 | + | ||
| 148 | + std::vector<std::pair<float, int>> score_indices; | ||
| 149 | + for (int i = 0; i < scores.size(); ++i) { | ||
| 150 | + if (scores[i] >= threshold) { | ||
| 151 | + score_indices.emplace_back(scores[i], i); | ||
| 152 | + } | ||
| 153 | + } | ||
| 154 | + | ||
| 155 | + std::sort(score_indices.rbegin(), score_indices.rend(), | ||
| 156 | + [](const auto &a, const auto &b) { return a.first < b.first; }); | ||
| 157 | + | ||
| 158 | + matches.reserve(score_indices.size()); | ||
| 159 | + for (int i = 0; i < std::min(n, static_cast<int32_t>(score_indices.size())); | ||
| 160 | + ++i) { | ||
| 161 | + const auto &pair = score_indices[i]; | ||
| 162 | + matches.push_back({row2name_.at(pair.second), pair.first}); | ||
| 163 | + } | ||
| 164 | + | ||
| 165 | + return matches; | ||
| 166 | + } | ||
| 167 | + | ||
| 134 | bool Verify(const std::string &name, const float *p, float threshold) { | 168 | bool Verify(const std::string &name, const float *p, float threshold) { |
| 135 | if (!name2row_.count(name)) { | 169 | if (!name2row_.count(name)) { |
| 136 | return false; | 170 | return false; |
| @@ -219,6 +253,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p, | @@ -219,6 +253,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p, | ||
| 219 | return impl_->Search(p, threshold); | 253 | return impl_->Search(p, threshold); |
| 220 | } | 254 | } |
| 221 | 255 | ||
| 256 | +std::vector<SpeakerMatch> SpeakerEmbeddingManager::GetBestMatches( | ||
| 257 | + const float *p, float threshold, int32_t n) const { | ||
| 258 | + return impl_->GetBestMatches(p, threshold, n); | ||
| 259 | +} | ||
| 260 | + | ||
| 222 | bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p, | 261 | bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p, |
| 223 | float threshold) const { | 262 | float threshold) const { |
| 224 | return impl_->Verify(name, p, threshold); | 263 | return impl_->Verify(name, p, threshold); |
| @@ -9,6 +9,11 @@ | @@ -9,6 +9,11 @@ | ||
| 9 | #include <string> | 9 | #include <string> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
// One result of a best-matches query: a speaker name and its score.
//
// The members are deliberately non-const: a const data member deletes the
// copy/move assignment operators, which would make a std::vector<SpeakerMatch>
// impossible to sort, erase from or reassign, and would force the name to be
// copied instead of moved.
struct SpeakerMatch {
  std::string name;  // Name of the matched speaker.
  float score;       // Score of that speaker against the query embedding.
};
| 16 | + | ||
| 12 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| 13 | 18 | ||
| 14 | class SpeakerEmbeddingManager { | 19 | class SpeakerEmbeddingManager { |
| @@ -62,6 +67,25 @@ class SpeakerEmbeddingManager { | @@ -62,6 +67,25 @@ class SpeakerEmbeddingManager { | ||
| 62 | */ | 67 | */ |
| 63 | std::string Search(const float *p, float threshold) const; | 68 | std::string Search(const float *p, float threshold) const; |
| 64 | 69 | ||
| 70 | + /** | ||
| 71 | + * It is for speaker identification. | ||
| 72 | + * | ||
| 73 | + * It computes the cosine similarity between a given embedding and all | ||
| 74 | + * other embeddings and finds the embeddings that have the largest scores | ||
| 75 | + * and the scores are above or equal to the threshold. Returns a vector of | ||
| 76 | + * SpeakerMatch structures containing the speaker names and scores for the | ||
| 77 | + * embeddings if found; otherwise, returns an empty vector. | ||
| 78 | + * | ||
| 79 | + * @param p A pointer to the input embedding. | ||
| 80 | + * @param threshold A value between 0 and 1. | ||
| 81 | + * @param n The number of top matches to return. | ||
| 82 | + * @return A vector of SpeakerMatch structures. If matches are found, the | ||
| 83 | + * vector contains the names and scores of the speakers. Otherwise, | ||
| 84 | + * it returns an empty vector. | ||
| 85 | + */ | ||
| 86 | + std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold, | ||
| 87 | + int32_t n) const; | ||
| 88 | + | ||
| 65 | /* Check whether the input embedding matches the embedding of the input | 89 | /* Check whether the input embedding matches the embedding of the input |
| 66 | * speaker. | 90 | * speaker. |
| 67 | * | 91 | * |
-
请 注册 或 登录 后发表评论