Fangjun Kuang
Committed by GitHub

Handle NaN embeddings in speaker diarization. (#1461)

See also https://github.com/thewh1teagle/sherpa-rs/issues/33
... ... @@ -19,7 +19,7 @@
#include "sherpa-onnx/c-api/cxx-api.h"
int32_t main() {
using namespace sherpa_onnx::cxx;
using namespace sherpa_onnx::cxx; // NOLINT
OfflineRecognizerConfig config;
config.model_config.sense_voice.model =
... ...
... ... @@ -20,7 +20,7 @@
#include "sherpa-onnx/c-api/cxx-api.h"
int32_t main() {
using namespace sherpa_onnx::cxx;
using namespace sherpa_onnx::cxx; // NOLINT
OnlineRecognizerConfig config;
// please see
... ...
... ... @@ -19,7 +19,7 @@
#include "sherpa-onnx/c-api/cxx-api.h"
int32_t main() {
using namespace sherpa_onnx::cxx;
using namespace sherpa_onnx::cxx; // NOLINT
OfflineRecognizerConfig config;
config.model_config.whisper.encoder =
... ...
... ... @@ -71,6 +71,9 @@ function is_source_code_file() {
}
function check_style() {
if [[ $1 == mfc-example* ]]; then
return
fi
python3 $cpplint_src $1 || abort $1
}
... ... @@ -99,7 +102,7 @@ function do_check() {
;;
2)
echo "Check all files"
files=$(find $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc")
files=$(find $sherpa_onnx_dir/cxx-api-examples $sherpa_onnx_dir/c-api-examples $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc")
;;
*)
echo "Check last commit"
... ...
... ... @@ -5,6 +5,7 @@
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <unordered_map>
#include <utility>
... ... @@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl
}
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
// The embedding model may output NaN. valid_indexes contains indexes
// in chunk_speaker_samples_list_pair.second that don't lead to
// NaN embeddings.
std::vector<int32_t> valid_indexes;
valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size());
Matrix2D embeddings =
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
std::move(callback), callback_arg);
&valid_indexes, std::move(callback), callback_arg);
if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) {
std::vector<Int32Pair> chunk_speaker_pair;
std::vector<std::vector<Int32Pair>> sample_indexes;
chunk_speaker_pair.reserve(valid_indexes.size());
sample_indexes.reserve(valid_indexes.size());
for (auto i : valid_indexes) {
chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]);
sample_indexes.push_back(
std::move(chunk_speaker_samples_list_pair.second[i]));
}
chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair);
chunk_speaker_samples_list_pair.second = std::move(sample_indexes);
}
std::vector<int32_t> cluster_labels = clustering_->Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
... ... @@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
Matrix2D ComputeEmbeddings(
const float *audio, int32_t n,
const std::vector<std::vector<Int32Pair>> &sample_indexes,
std::vector<int32_t> *valid_indexes,
OfflineSpeakerDiarizationProgressCallback callback,
void *callback_arg) const {
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t sample_rate = meta_data.sample_rate;
Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());
auto IsNaNWrapper = [](float f) -> bool { return std::isnan(f); };
int32_t k = 0;
int32_t cur_row_index = 0;
for (const auto &v : sample_indexes) {
auto stream = embedding_extractor_.CreateStream();
for (const auto &p : v) {
... ... @@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl
std::vector<float> embedding = embedding_extractor_.Compute(stream.get());
std::copy(embedding.begin(), embedding.end(), &ans(k, 0));
if (std::none_of(embedding.begin(), embedding.end(), IsNaNWrapper)) {
// a valid embedding
std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0));
cur_row_index += 1;
valid_indexes->push_back(k);
}
k += 1;
... ... @@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl
}
}
if (k != cur_row_index) {
auto seq = Eigen::seqN(0, cur_row_index);
ans = ans(seq, Eigen::all);
}
return ans;
}
... ...
... ... @@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
auto variance = EX2 - EX.array().pow(2);
auto stddev = variance.array().sqrt();
m = (m.rowwise() - EX).array().rowwise() / stddev.array();
m = (m.rowwise() - EX).array().rowwise() / (stddev.array() + 1e-5);
}
private:
... ...