speaker-embedding-manager.cc 5.7 KB
// sherpa-onnx/csrc/speaker-embedding-manager.cc
//
// Copyright (c)  2024  Xiaomi Corporation

#include "sherpa-onnx/csrc/speaker-embedding-manager.h"

#include <algorithm>
#include <unordered_map>

#include "Eigen/Dense"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

// Dynamically sized float matrix stored in row-major order.
// Impl below keeps one speaker embedding per row of such a matrix.
using FloatMatrix =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// Implementation of SpeakerEmbeddingManager.
//
// Every registered speaker occupies one row of embedding_matrix_. Rows are
// normalized when they are added, and queries are normalized before they are
// compared, so the product of a row with a query is their cosine similarity.
// name2row_ and row2name_ map between speaker names and row indexes and are
// kept in sync by Add() and Remove().
class SpeakerEmbeddingManager::Impl {
 public:
  // @param dim  Number of floats in every embedding handled by this object.
  explicit Impl(int32_t dim) : dim_(dim) {}

  // Register a speaker from a single embedding.
  //
  // @param name  Name of the speaker.
  // @param p     Pointer to dim_ floats; they are copied, not referenced.
  // @return true on success; false if `name` is already registered.
  bool Add(const std::string &name, const float *p) {
    if (name2row_.count(name)) {
      // a speaker with the same name already exists
      return false;
    }

    Eigen::RowVectorXf v = Eigen::Map<const Eigen::RowVectorXf>(p, dim_);
    v.normalize();

    AppendRow(name, v);

    return true;
  }

  // Register a speaker from several embeddings. The normalized sum of the
  // embeddings is stored.
  //
  // @param name            Name of the speaker.
  // @param embedding_list  Non-empty list of embeddings, each of size dim_.
  // @return true on success; false if `name` is already registered, if the
  //         list is empty, or if any embedding does not have dim_ entries.
  bool Add(const std::string &name,
           const std::vector<std::vector<float>> &embedding_list) {
    if (name2row_.count(name)) {
      // a speaker with the same name already exists
      return false;
    }

    if (embedding_list.empty()) {
      SHERPA_ONNX_LOGE("Empty list of embeddings");
      return false;
    }

    for (const auto &x : embedding_list) {
      if (static_cast<int32_t>(x.size()) != dim_) {
        SHERPA_ONNX_LOGE("Given dim: %d, expected dim: %d",
                         static_cast<int32_t>(x.size()), dim_);
        return false;
      }
    }

    // Sum the embeddings. There is no need to divide by the number of
    // embeddings since the result is normalized anyway.
    Eigen::RowVectorXf v = Eigen::RowVectorXf::Zero(dim_);
    for (const auto &x : embedding_list) {
      v += Eigen::Map<const Eigen::RowVectorXf>(x.data(), dim_);
    }

    v.normalize();

    AppendRow(name, v);

    return true;
  }

  // Unregister a speaker.
  //
  // @param name  Name of the speaker to remove.
  // @return true if the speaker was removed; false if it is not registered.
  bool Remove(const std::string &name) {
    auto it = name2row_.find(name);
    if (it == name2row_.end()) {
      return false;
    }

    const int32_t row_idx = it->second;
    const int32_t num_rows = static_cast<int32_t>(embedding_matrix_.rows());

    // Number of rows stored below the one being removed.
    const int32_t num_below = num_rows - 1 - row_idx;

    if (num_below > 0) {
      // Shift the rows below the removed one up by one position. Source and
      // destination overlap inside the same matrix, so evaluate the source
      // into a temporary first to avoid aliasing.
      embedding_matrix_.block(row_idx, 0, num_below, dim_) =
          embedding_matrix_.bottomRows(num_below).eval();
    }

    embedding_matrix_.conservativeResize(num_rows - 1, dim_);

    name2row_.erase(it);

    // Every speaker stored below the removed row moved up by one.
    for (auto &p : name2row_) {
      if (p.second > row_idx) {
        p.second -= 1;
        row2name_[p.second] = p.first;
      }
    }

    // After the shift the last row index is no longer in use.
    row2name_.erase(num_rows - 1);

    return true;
  }

  // Find the registered speaker that best matches an embedding.
  //
  // @param p          Pointer to dim_ floats.
  // @param threshold  Minimum score required for a match.
  // @return Name of the speaker with the highest score, or an empty string
  //         if no speaker is registered or the highest score is below
  //         `threshold`.
  std::string Search(const float *p, float threshold) {
    if (embedding_matrix_.rows() == 0) {
      return {};
    }

    Eigen::VectorXf v = Eigen::Map<const Eigen::VectorXf>(p, dim_);
    v.normalize();

    // One score per registered speaker.
    Eigen::VectorXf scores = embedding_matrix_ * v;

    Eigen::VectorXf::Index max_index;
    float max_score = scores.maxCoeff(&max_index);
    if (max_score < threshold) {
      return {};
    }

    return row2name_.at(static_cast<int32_t>(max_index));
  }

  // Check whether an embedding matches a given registered speaker.
  //
  // @param name       Name of the speaker to compare against.
  // @param p          Pointer to dim_ floats.
  // @param threshold  Minimum score required for a match.
  // @return true if `name` is registered and its score with the embedding
  //         is not below `threshold`; false otherwise.
  bool Verify(const std::string &name, const float *p, float threshold) {
    auto it = name2row_.find(name);
    if (it == name2row_.end()) {
      return false;
    }

    Eigen::VectorXf v = Eigen::Map<const Eigen::VectorXf>(p, dim_);
    v.normalize();

    float score = embedding_matrix_.row(it->second) * v;

    return !(score < threshold);
  }

  // Return true if a speaker called `name` is registered.
  bool Contains(const std::string &name) const {
    return name2row_.count(name) > 0;
  }

  // Return the number of registered speakers.
  int32_t NumSpeakers() const {
    return static_cast<int32_t>(embedding_matrix_.rows());
  }

  // Return the embedding dimension given at construction.
  int32_t Dim() const { return dim_; }

  // Return the names of all registered speakers in sorted order.
  std::vector<std::string> GetAllSpeakers() const {
    std::vector<std::string> all_speakers;
    all_speakers.reserve(name2row_.size());
    for (const auto &p : name2row_) {
      all_speakers.push_back(p.first);
    }

    std::stable_sort(all_speakers.begin(), all_speakers.end());
    return all_speakers;
  }

 private:
  // Append `v` as a new last row and record `name` for that row.
  // The caller has already checked that `name` is unused and has
  // normalized `v`.
  void AppendRow(const std::string &name, const Eigen::RowVectorXf &v) {
    const int32_t row = static_cast<int32_t>(embedding_matrix_.rows());

    embedding_matrix_.conservativeResize(row + 1, dim_);
    embedding_matrix_.row(row) = v;

    name2row_[name] = row;
    row2name_[row] = name;
  }

 private:
  int32_t dim_;                   // size of every embedding
  FloatMatrix embedding_matrix_;  // one normalized embedding per row
  std::unordered_map<std::string, int32_t> name2row_;  // name -> row index
  std::unordered_map<int32_t, std::string> row2name_;  // row index -> name
};

SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim)
    : impl_(std::make_unique<Impl>(dim)) {}

// Defined in this file, after Impl is a complete type.
// NOTE(review): presumably required because impl_ is a
// std::unique_ptr<Impl> declared in the header - confirm.
SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default;

// Register speaker `name` from a single embedding pointed to by `p`.
// Forwards to Impl::Add and returns its result.
bool SpeakerEmbeddingManager::Add(const std::string &name,
                                  const float *p) const {
  const bool added = impl_->Add(name, p);
  return added;
}

// Register speaker `name` from a list of embeddings.
// Forwards to Impl::Add and returns its result.
bool SpeakerEmbeddingManager::Add(
    const std::string &name,
    const std::vector<std::vector<float>> &embedding_list) const {
  const bool added = impl_->Add(name, embedding_list);
  return added;
}

// Unregister speaker `name`. Forwards to Impl::Remove and returns its result.
bool SpeakerEmbeddingManager::Remove(const std::string &name) const {
  const bool removed = impl_->Remove(name);
  return removed;
}

// Look up the speaker that best matches the embedding at `p`.
// Forwards to Impl::Search and returns the name it reports.
std::string SpeakerEmbeddingManager::Search(const float *p,
                                            float threshold) const {
  std::string matched = impl_->Search(p, threshold);
  return matched;
}

// Check whether the embedding at `p` matches speaker `name`.
// Forwards to Impl::Verify and returns its result.
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
                                     float threshold) const {
  const bool verified = impl_->Verify(name, p, threshold);
  return verified;
}

// Number of registered speakers, as reported by Impl::NumSpeakers.
int32_t SpeakerEmbeddingManager::NumSpeakers() const {
  const int32_t count = impl_->NumSpeakers();
  return count;
}

// Embedding dimension, as reported by Impl::Dim.
int32_t SpeakerEmbeddingManager::Dim() const {
  const int32_t dim = impl_->Dim();
  return dim;
}

// Whether a speaker called `name` is registered, per Impl::Contains.
bool SpeakerEmbeddingManager::Contains(const std::string &name) const {
  const bool found = impl_->Contains(name);
  return found;
}

// Names of all registered speakers, as returned by Impl::GetAllSpeakers.
std::vector<std::string> SpeakerEmbeddingManager::GetAllSpeakers() const {
  std::vector<std::string> names = impl_->GetAllSpeakers();
  return names;
}

}  // namespace sherpa_onnx