Committed by
GitHub
Handle NaN embeddings in speaker diarization. (#1461)
See also https://github.com/thewh1teagle/sherpa-rs/issues/33
正在显示
6 个修改的文件
包含
48 行增加
和
7 行删除
| @@ -19,7 +19,7 @@ | @@ -19,7 +19,7 @@ | ||
| 19 | #include "sherpa-onnx/c-api/cxx-api.h" | 19 | #include "sherpa-onnx/c-api/cxx-api.h" |
| 20 | 20 | ||
| 21 | int32_t main() { | 21 | int32_t main() { |
| 22 | - using namespace sherpa_onnx::cxx; | 22 | + using namespace sherpa_onnx::cxx; // NOLINT |
| 23 | OfflineRecognizerConfig config; | 23 | OfflineRecognizerConfig config; |
| 24 | 24 | ||
| 25 | config.model_config.sense_voice.model = | 25 | config.model_config.sense_voice.model = |
| @@ -20,7 +20,7 @@ | @@ -20,7 +20,7 @@ | ||
| 20 | #include "sherpa-onnx/c-api/cxx-api.h" | 20 | #include "sherpa-onnx/c-api/cxx-api.h" |
| 21 | 21 | ||
| 22 | int32_t main() { | 22 | int32_t main() { |
| 23 | - using namespace sherpa_onnx::cxx; | 23 | + using namespace sherpa_onnx::cxx; // NOLINT |
| 24 | OnlineRecognizerConfig config; | 24 | OnlineRecognizerConfig config; |
| 25 | 25 | ||
| 26 | // please see | 26 | // please see |
| @@ -19,7 +19,7 @@ | @@ -19,7 +19,7 @@ | ||
| 19 | #include "sherpa-onnx/c-api/cxx-api.h" | 19 | #include "sherpa-onnx/c-api/cxx-api.h" |
| 20 | 20 | ||
| 21 | int32_t main() { | 21 | int32_t main() { |
| 22 | - using namespace sherpa_onnx::cxx; | 22 | + using namespace sherpa_onnx::cxx; // NOLINT |
| 23 | OfflineRecognizerConfig config; | 23 | OfflineRecognizerConfig config; |
| 24 | 24 | ||
| 25 | config.model_config.whisper.encoder = | 25 | config.model_config.whisper.encoder = |
| @@ -71,6 +71,9 @@ function is_source_code_file() { | @@ -71,6 +71,9 @@ function is_source_code_file() { | ||
| 71 | } | 71 | } |
| 72 | 72 | ||
| 73 | function check_style() { | 73 | function check_style() { |
| 74 | + if [[ $1 == mfc-example* ]]; then | ||
| 75 | + return | ||
| 76 | + fi | ||
| 74 | python3 $cpplint_src $1 || abort $1 | 77 | python3 $cpplint_src $1 || abort $1 |
| 75 | } | 78 | } |
| 76 | 79 | ||
| @@ -99,7 +102,7 @@ function do_check() { | @@ -99,7 +102,7 @@ function do_check() { | ||
| 99 | ;; | 102 | ;; |
| 100 | 2) | 103 | 2) |
| 101 | echo "Check all files" | 104 | echo "Check all files" |
| 102 | - files=$(find $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc") | 105 | + files=$(find $sherpa_onnx_dir/cxx-api-examples $sherpa_onnx_dir/c-api-examples $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc") |
| 103 | ;; | 106 | ;; |
| 104 | *) | 107 | *) |
| 105 | echo "Check last commit" | 108 | echo "Check last commit" |
| @@ -5,6 +5,7 @@ | @@ -5,6 +5,7 @@ | ||
| 5 | #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ | 5 | #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ |
| 6 | 6 | ||
| 7 | #include <algorithm> | 7 | #include <algorithm> |
| 8 | +#include <cmath> | ||
| 8 | #include <memory> | 9 | #include <memory> |
| 9 | #include <unordered_map> | 10 | #include <unordered_map> |
| 10 | #include <utility> | 11 | #include <utility> |
| @@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl | @@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl | ||
| 135 | } | 136 | } |
| 136 | 137 | ||
| 137 | auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); | 138 | auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); |
| 139 | + | ||
| 140 | + // The embedding model may output NaN. valid_indexes contains indexes | ||
| 141 | + // in chunk_speaker_samples_list_pair.second that don't lead to | ||
| 142 | + // NaN embeddings. | ||
| 143 | + std::vector<int32_t> valid_indexes; | ||
| 144 | + valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size()); | ||
| 145 | + | ||
| 138 | Matrix2D embeddings = | 146 | Matrix2D embeddings = |
| 139 | ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, | 147 | ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, |
| 140 | - std::move(callback), callback_arg); | 148 | + &valid_indexes, std::move(callback), callback_arg); |
| 149 | + | ||
| 150 | + if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) { | ||
| 151 | + std::vector<Int32Pair> chunk_speaker_pair; | ||
| 152 | + std::vector<std::vector<Int32Pair>> sample_indexes; | ||
| 153 | + | ||
| 154 | + chunk_speaker_pair.reserve(valid_indexes.size()); | ||
| 155 | + sample_indexes.reserve(valid_indexes.size()); | ||
| 156 | + for (auto i : valid_indexes) { | ||
| 157 | + chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]); | ||
| 158 | + sample_indexes.push_back( | ||
| 159 | + std::move(chunk_speaker_samples_list_pair.second[i])); | ||
| 160 | + } | ||
| 161 | + | ||
| 162 | + chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair); | ||
| 163 | + chunk_speaker_samples_list_pair.second = std::move(sample_indexes); | ||
| 164 | + } | ||
| 141 | 165 | ||
| 142 | std::vector<int32_t> cluster_labels = clustering_->Cluster( | 166 | std::vector<int32_t> cluster_labels = clustering_->Cluster( |
| 143 | &embeddings(0, 0), embeddings.rows(), embeddings.cols()); | 167 | &embeddings(0, 0), embeddings.rows(), embeddings.cols()); |
| @@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl | @@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl | ||
| 431 | Matrix2D ComputeEmbeddings( | 455 | Matrix2D ComputeEmbeddings( |
| 432 | const float *audio, int32_t n, | 456 | const float *audio, int32_t n, |
| 433 | const std::vector<std::vector<Int32Pair>> &sample_indexes, | 457 | const std::vector<std::vector<Int32Pair>> &sample_indexes, |
| 458 | + std::vector<int32_t> *valid_indexes, | ||
| 434 | OfflineSpeakerDiarizationProgressCallback callback, | 459 | OfflineSpeakerDiarizationProgressCallback callback, |
| 435 | void *callback_arg) const { | 460 | void *callback_arg) const { |
| 436 | const auto &meta_data = segmentation_model_.GetModelMetaData(); | 461 | const auto &meta_data = segmentation_model_.GetModelMetaData(); |
| 437 | int32_t sample_rate = meta_data.sample_rate; | 462 | int32_t sample_rate = meta_data.sample_rate; |
| 438 | Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim()); | 463 | Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim()); |
| 439 | 464 | ||
| 465 | + auto IsNaNWrapper = [](float f) -> bool { return std::isnan(f); }; | ||
| 466 | + | ||
| 440 | int32_t k = 0; | 467 | int32_t k = 0; |
| 468 | + int32_t cur_row_index = 0; | ||
| 441 | for (const auto &v : sample_indexes) { | 469 | for (const auto &v : sample_indexes) { |
| 442 | auto stream = embedding_extractor_.CreateStream(); | 470 | auto stream = embedding_extractor_.CreateStream(); |
| 443 | for (const auto &p : v) { | 471 | for (const auto &p : v) { |
| @@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl | @@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl | ||
| 459 | 487 | ||
| 460 | std::vector<float> embedding = embedding_extractor_.Compute(stream.get()); | 488 | std::vector<float> embedding = embedding_extractor_.Compute(stream.get()); |
| 461 | 489 | ||
| 462 | - std::copy(embedding.begin(), embedding.end(), &ans(k, 0)); | 490 | + if (std::none_of(embedding.begin(), embedding.end(), IsNaNWrapper)) { |
| 491 | + // a valid embedding | ||
| 492 | + std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0)); | ||
| 493 | + cur_row_index += 1; | ||
| 494 | + valid_indexes->push_back(k); | ||
| 495 | + } | ||
| 463 | 496 | ||
| 464 | k += 1; | 497 | k += 1; |
| 465 | 498 | ||
| @@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl | @@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl | ||
| 468 | } | 501 | } |
| 469 | } | 502 | } |
| 470 | 503 | ||
| 504 | + if (k != cur_row_index) { | ||
| 505 | + auto seq = Eigen::seqN(0, cur_row_index); | ||
| 506 | + ans = ans(seq, Eigen::all); | ||
| 507 | + } | ||
| 508 | + | ||
| 471 | return ans; | 509 | return ans; |
| 472 | } | 510 | } |
| 473 | 511 |
| @@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { | @@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { | ||
| 122 | auto variance = EX2 - EX.array().pow(2); | 122 | auto variance = EX2 - EX.array().pow(2); |
| 123 | auto stddev = variance.array().sqrt(); | 123 | auto stddev = variance.array().sqrt(); |
| 124 | 124 | ||
| 125 | - m = (m.rowwise() - EX).array().rowwise() / stddev.array(); | 125 | + m = (m.rowwise() - EX).array().rowwise() / (stddev.array() + 1e-5); |
| 126 | } | 126 | } |
| 127 | 127 | ||
| 128 | private: | 128 | private: |
-
请 注册 或 登录 后发表评论