Fangjun Kuang
Committed by GitHub

Handle NaN embeddings in speaker diarization. (#1461)

See also https://github.com/thewh1teagle/sherpa-rs/issues/33
@@ -19,7 +19,7 @@ @@ -19,7 +19,7 @@
19 #include "sherpa-onnx/c-api/cxx-api.h" 19 #include "sherpa-onnx/c-api/cxx-api.h"
20 20
21 int32_t main() { 21 int32_t main() {
22 - using namespace sherpa_onnx::cxx; 22 + using namespace sherpa_onnx::cxx; // NOLINT
23 OfflineRecognizerConfig config; 23 OfflineRecognizerConfig config;
24 24
25 config.model_config.sense_voice.model = 25 config.model_config.sense_voice.model =
@@ -20,7 +20,7 @@ @@ -20,7 +20,7 @@
20 #include "sherpa-onnx/c-api/cxx-api.h" 20 #include "sherpa-onnx/c-api/cxx-api.h"
21 21
22 int32_t main() { 22 int32_t main() {
23 - using namespace sherpa_onnx::cxx; 23 + using namespace sherpa_onnx::cxx; // NOLINT
24 OnlineRecognizerConfig config; 24 OnlineRecognizerConfig config;
25 25
26 // please see 26 // please see
@@ -19,7 +19,7 @@ @@ -19,7 +19,7 @@
19 #include "sherpa-onnx/c-api/cxx-api.h" 19 #include "sherpa-onnx/c-api/cxx-api.h"
20 20
21 int32_t main() { 21 int32_t main() {
22 - using namespace sherpa_onnx::cxx; 22 + using namespace sherpa_onnx::cxx; // NOLINT
23 OfflineRecognizerConfig config; 23 OfflineRecognizerConfig config;
24 24
25 config.model_config.whisper.encoder = 25 config.model_config.whisper.encoder =
@@ -71,6 +71,9 @@ function is_source_code_file() { @@ -71,6 +71,9 @@ function is_source_code_file() {
71 } 71 }
72 72
73 function check_style() { 73 function check_style() {
  74 + if [[ $1 == mfc-example* ]]; then
  75 + return
  76 + fi
74 python3 $cpplint_src $1 || abort $1 77 python3 $cpplint_src $1 || abort $1
75 } 78 }
76 79
@@ -99,7 +102,7 @@ function do_check() { @@ -99,7 +102,7 @@ function do_check() {
99 ;; 102 ;;
100 2) 103 2)
101 echo "Check all files" 104 echo "Check all files"
102 - files=$(find $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc") 105 + files=$(find $sherpa_onnx_dir/cxx-api-examples $sherpa_onnx_dir/c-api-examples $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc")
103 ;; 106 ;;
104 *) 107 *)
105 echo "Check last commit" 108 echo "Check last commit"
@@ -5,6 +5,7 @@ @@ -5,6 +5,7 @@
5 #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ 5 #define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
6 6
7 #include <algorithm> 7 #include <algorithm>
  8 +#include <cmath>
8 #include <memory> 9 #include <memory>
9 #include <unordered_map> 10 #include <unordered_map>
10 #include <utility> 11 #include <utility>
@@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl @@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl
135 } 136 }
136 137
137 auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); 138 auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
  139 +
  140 + // The embedding model may output NaN. valid_indexes contains indexes
  141 + // in chunk_speaker_samples_list_pair.second that don't lead to
  142 + // NaN embeddings.
  143 + std::vector<int32_t> valid_indexes;
  144 + valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size());
  145 +
138 Matrix2D embeddings = 146 Matrix2D embeddings =
139 ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, 147 ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
140 - std::move(callback), callback_arg); 148 + &valid_indexes, std::move(callback), callback_arg);
  149 +
  150 + if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) {
  151 + std::vector<Int32Pair> chunk_speaker_pair;
  152 + std::vector<std::vector<Int32Pair>> sample_indexes;
  153 +
  154 + chunk_speaker_pair.reserve(valid_indexes.size());
  155 + sample_indexes.reserve(valid_indexes.size());
  156 + for (auto i : valid_indexes) {
  157 + chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]);
  158 + sample_indexes.push_back(
  159 + std::move(chunk_speaker_samples_list_pair.second[i]));
  160 + }
  161 +
  162 + chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair);
  163 + chunk_speaker_samples_list_pair.second = std::move(sample_indexes);
  164 + }
141 165
142 std::vector<int32_t> cluster_labels = clustering_->Cluster( 166 std::vector<int32_t> cluster_labels = clustering_->Cluster(
143 &embeddings(0, 0), embeddings.rows(), embeddings.cols()); 167 &embeddings(0, 0), embeddings.rows(), embeddings.cols());
@@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl @@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
431 Matrix2D ComputeEmbeddings( 455 Matrix2D ComputeEmbeddings(
432 const float *audio, int32_t n, 456 const float *audio, int32_t n,
433 const std::vector<std::vector<Int32Pair>> &sample_indexes, 457 const std::vector<std::vector<Int32Pair>> &sample_indexes,
  458 + std::vector<int32_t> *valid_indexes,
434 OfflineSpeakerDiarizationProgressCallback callback, 459 OfflineSpeakerDiarizationProgressCallback callback,
435 void *callback_arg) const { 460 void *callback_arg) const {
436 const auto &meta_data = segmentation_model_.GetModelMetaData(); 461 const auto &meta_data = segmentation_model_.GetModelMetaData();
437 int32_t sample_rate = meta_data.sample_rate; 462 int32_t sample_rate = meta_data.sample_rate;
438 Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim()); 463 Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());
439 464
  465 + auto IsNaNWrapper = [](float f) -> bool { return std::isnan(f); };
  466 +
440 int32_t k = 0; 467 int32_t k = 0;
  468 + int32_t cur_row_index = 0;
441 for (const auto &v : sample_indexes) { 469 for (const auto &v : sample_indexes) {
442 auto stream = embedding_extractor_.CreateStream(); 470 auto stream = embedding_extractor_.CreateStream();
443 for (const auto &p : v) { 471 for (const auto &p : v) {
@@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl @@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl
459 487
460 std::vector<float> embedding = embedding_extractor_.Compute(stream.get()); 488 std::vector<float> embedding = embedding_extractor_.Compute(stream.get());
461 489
462 - std::copy(embedding.begin(), embedding.end(), &ans(k, 0)); 490 + if (std::none_of(embedding.begin(), embedding.end(), IsNaNWrapper)) {
  491 + // a valid embedding
  492 + std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0));
  493 + cur_row_index += 1;
  494 + valid_indexes->push_back(k);
  495 + }
463 496
464 k += 1; 497 k += 1;
465 498
@@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl @@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl
468 } 501 }
469 } 502 }
470 503
  504 + if (k != cur_row_index) {
  505 + auto seq = Eigen::seqN(0, cur_row_index);
  506 + ans = ans(seq, Eigen::all);
  507 + }
  508 +
471 return ans; 509 return ans;
472 } 510 }
473 511
@@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { @@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
122 auto variance = EX2 - EX.array().pow(2); 122 auto variance = EX2 - EX.array().pow(2);
123 auto stddev = variance.array().sqrt(); 123 auto stddev = variance.array().sqrt();
124 124
125 - m = (m.rowwise() - EX).array().rowwise() / stddev.array(); 125 + m = (m.rowwise() - EX).array().rowwise() / (stddev.array() + 1e-5);
126 } 126 }
127 127
128 private: 128 private: