Fangjun Kuang
Committed by GitHub

Support Agglomerative clustering. (#1384)

We use the open-source implementation from
https://github.com/cdalitz/hclust-cpp
@@ -40,6 +40,7 @@ option(SHERPA_ONNX_ENABLE_WASM_VAD_ASR "Whether to enable WASM for VAD+ASR" OFF) @@ -40,6 +40,7 @@ option(SHERPA_ONNX_ENABLE_WASM_VAD_ASR "Whether to enable WASM for VAD+ASR" OFF)
40 option(SHERPA_ONNX_ENABLE_WASM_NODEJS "Whether to enable WASM for NodeJS" OFF) 40 option(SHERPA_ONNX_ENABLE_WASM_NODEJS "Whether to enable WASM for NodeJS" OFF)
41 option(SHERPA_ONNX_ENABLE_BINARY "Whether to build binaries" ON) 41 option(SHERPA_ONNX_ENABLE_BINARY "Whether to build binaries" ON)
42 option(SHERPA_ONNX_ENABLE_TTS "Whether to build TTS related code" ON) 42 option(SHERPA_ONNX_ENABLE_TTS "Whether to build TTS related code" ON)
  43 +option(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION "Whether to build speaker diarization related code" ON)
43 option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON) 44 option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON)
44 option(SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE "True to use pre-installed onnxruntime if available" ON) 45 option(SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE "True to use pre-installed onnxruntime if available" ON)
45 option(SHERPA_ONNX_ENABLE_SANITIZER "Whether to enable ubsan and asan" OFF) 46 option(SHERPA_ONNX_ENABLE_SANITIZER "Whether to enable ubsan and asan" OFF)
@@ -142,6 +143,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_WASM_VAD_ASR ${SHERPA_ONNX_ENABLE_WASM_VAD_AS @@ -142,6 +143,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_WASM_VAD_ASR ${SHERPA_ONNX_ENABLE_WASM_VAD_AS
142 message(STATUS "SHERPA_ONNX_ENABLE_WASM_NODEJS ${SHERPA_ONNX_ENABLE_WASM_NODEJS}") 143 message(STATUS "SHERPA_ONNX_ENABLE_WASM_NODEJS ${SHERPA_ONNX_ENABLE_WASM_NODEJS}")
143 message(STATUS "SHERPA_ONNX_ENABLE_BINARY ${SHERPA_ONNX_ENABLE_BINARY}") 144 message(STATUS "SHERPA_ONNX_ENABLE_BINARY ${SHERPA_ONNX_ENABLE_BINARY}")
144 message(STATUS "SHERPA_ONNX_ENABLE_TTS ${SHERPA_ONNX_ENABLE_TTS}") 145 message(STATUS "SHERPA_ONNX_ENABLE_TTS ${SHERPA_ONNX_ENABLE_TTS}")
  146 +message(STATUS "SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION ${SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION}")
145 message(STATUS "SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY ${SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY}") 147 message(STATUS "SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY ${SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY}")
146 message(STATUS "SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE ${SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE}") 148 message(STATUS "SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE ${SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE}")
147 message(STATUS "SHERPA_ONNX_ENABLE_SANITIZER: ${SHERPA_ONNX_ENABLE_SANITIZER}") 149 message(STATUS "SHERPA_ONNX_ENABLE_SANITIZER: ${SHERPA_ONNX_ENABLE_SANITIZER}")
@@ -341,6 +343,10 @@ if(SHERPA_ONNX_ENABLE_TTS) @@ -341,6 +343,10 @@ if(SHERPA_ONNX_ENABLE_TTS)
341 include(cppjieba) # For Chinese TTS. It is a header-only C++ library 343 include(cppjieba) # For Chinese TTS. It is a header-only C++ library
342 endif() 344 endif()
343 345
  346 +if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  347 + include(hclust-cpp)
  348 +endif()
  349 +
344 # if(NOT MSVC AND CMAKE_BUILD_TYPE STREQUAL Debug AND (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")) 350 # if(NOT MSVC AND CMAKE_BUILD_TYPE STREQUAL Debug AND (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"))
345 if(SHERPA_ONNX_ENABLE_SANITIZER) 351 if(SHERPA_ONNX_ENABLE_SANITIZER)
346 message(WARNING "enable ubsan and asan") 352 message(WARNING "enable ubsan and asan")
# Fetch the hclust-cpp sources (used for agglomerative clustering) and add
# the directory they are unpacked into to the include path.
#
# If a pre-downloaded tarball exists in one of the candidate locations listed
# below, it is used instead of the remote URL. In both cases the archive is
# checked against hclust_cpp_HASH.
#
# Takes no arguments. Side effect: calls include_directories() with the
# unpacked source directory.
function(download_hclust_cpp)
  include(FetchContent)

  # The latest commit as of 2024.09.29
  set(hclust_cpp_URL "https://github.com/csukuangfj/hclust-cpp/archive/refs/tags/2024-09-29.tar.gz")
  set(hclust_cpp_HASH "SHA256=abab51448a3cb54272aae07522970306e0b2cc6479d59d7b19e7aee4d6cedd33")

  # If you don't have access to the Internet,
  # please pre-download hclust-cpp and place it in one of these locations.
  set(possible_file_locations
    "$ENV{HOME}/Downloads/hclust-cpp-2024-09-29.tar.gz"
    "${CMAKE_SOURCE_DIR}/hclust-cpp-2024-09-29.tar.gz"
    "${CMAKE_BINARY_DIR}/hclust-cpp-2024-09-29.tar.gz"
    "/tmp/hclust-cpp-2024-09-29.tar.gz"
    "/star-fj/fangjun/download/github/hclust-cpp-2024-09-29.tar.gz"
  )

  # The first existing candidate wins and replaces the remote URL.
  foreach(f IN LISTS possible_file_locations)
    if(EXISTS "${f}")
      set(hclust_cpp_URL "${f}")
      file(TO_CMAKE_PATH "${hclust_cpp_URL}" hclust_cpp_URL)
      message(STATUS "Found local downloaded hclust_cpp: ${hclust_cpp_URL}")
      break()
    endif()
  endforeach()

  # Only one URL is passed: no fallback mirror is defined for this package.
  FetchContent_Declare(hclust_cpp
    URL "${hclust_cpp_URL}"
    URL_HASH "${hclust_cpp_HASH}"
  )

  FetchContent_GetProperties(hclust_cpp)
  if(NOT hclust_cpp_POPULATED)
    message(STATUS "Downloading hclust_cpp from ${hclust_cpp_URL}")
    # Populate only; add_subdirectory() is never called on the fetched
    # sources here, so nothing from them is configured as a sub-build.
    FetchContent_Populate(hclust_cpp)
  endif()

  message(STATUS "hclust_cpp is downloaded to ${hclust_cpp_SOURCE_DIR}")
  message(STATUS "hclust_cpp's binary dir is ${hclust_cpp_BINARY_DIR}")
  include_directories("${hclust_cpp_SOURCE_DIR}")
endfunction()

download_hclust_cpp()
@@ -160,6 +160,13 @@ if(SHERPA_ONNX_ENABLE_TTS) @@ -160,6 +160,13 @@ if(SHERPA_ONNX_ENABLE_TTS)
160 ) 160 )
161 endif() 161 endif()
162 162
  163 +if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  164 + list(APPEND sources
  165 + fast-clustering-config.cc
  166 + fast-clustering.cc
  167 + )
  168 +endif()
  169 +
163 if(SHERPA_ONNX_ENABLE_CHECK) 170 if(SHERPA_ONNX_ENABLE_CHECK)
164 list(APPEND sources log.cc) 171 list(APPEND sources log.cc)
165 endif() 172 endif()
@@ -523,6 +530,12 @@ if(SHERPA_ONNX_ENABLE_TESTS) @@ -523,6 +530,12 @@ if(SHERPA_ONNX_ENABLE_TESTS)
523 ) 530 )
524 endif() 531 endif()
525 532
  533 + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  534 + list(APPEND sherpa_onnx_test_srcs
  535 + fast-clustering-test.cc
  536 + )
  537 + endif()
  538 +
526 list(APPEND sherpa_onnx_test_srcs 539 list(APPEND sherpa_onnx_test_srcs
527 speaker-embedding-manager-test.cc 540 speaker-embedding-manager-test.cc
528 ) 541 )
  1 +// sherpa-onnx/csrc/fast-clustering-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/fast-clustering-config.h"
  6 +
  7 +#include <sstream>
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +std::string FastClusteringConfig::ToString() const {
  14 + std::ostringstream os;
  15 +
  16 + os << "FastClusteringConfig(";
  17 + os << "num_clusters=" << num_clusters << ", ";
  18 + os << "threshold=" << threshold << ")";
  19 +
  20 + return os.str();
  21 +}
  22 +
  23 +void FastClusteringConfig::Register(ParseOptions *po) {
  24 + std::string prefix = "ctc";
  25 + ParseOptions p(prefix, po);
  26 +
  27 + p.Register("num-clusters", &num_clusters,
  28 + "Number of cluster. If greater than 0, then --cluster-thresold is "
  29 + "ignored");
  30 +
  31 + p.Register("cluster-threshold", &threshold,
  32 + "If --num-clusters is not specified, then it specifies the "
  33 + "distance threshold for clustering.");
  34 +}
  35 +
  36 +bool FastClusteringConfig::Validate() const {
  37 + if (num_clusters < 1 && threshold < 0) {
  38 + SHERPA_ONNX_LOGE("Please provide either num_clusters or threshold");
  39 + return false;
  40 + }
  41 +
  42 + return true;
  43 +}
  44 +
  45 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/fast-clustering-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_
  6 +#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_
  7 +
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/parse-options.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct FastClusteringConfig {
  15 + // If greater than 0, then threshold is ignored
  16 + int32_t num_clusters = -1;
  17 +
  18 + // distance threshold
  19 + float threshold = 0.5;
  20 +
  21 + std::string ToString() const;
  22 +
  23 + void Register(ParseOptions *po);
  24 + bool Validate() const;
  25 +};
  26 +
  27 +} // namespace sherpa_onnx
  28 +#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_
  1 +// sherpa-onnx/csrc/fast-clustering-test.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/fast-clustering.h"
  6 +
  7 +#include <vector>
  8 +
  9 +#include "gtest/gtest.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +TEST(FastClustering, TestTwoClusters) {
  14 + std::vector<float> features = {
  15 + // point 0
  16 + 0.1,
  17 + 0.1,
  18 + // point 2
  19 + 0.4,
  20 + -0.5,
  21 + // point 3
  22 + 0.6,
  23 + -0.7,
  24 + // point 1
  25 + 0.2,
  26 + 0.3,
  27 + };
  28 +
  29 + FastClusteringConfig config;
  30 + config.num_clusters = 2;
  31 +
  32 + FastClustering clustering(config);
  33 + auto labels = clustering.Cluster(features.data(), 4, 2);
  34 + int32_t k = 0;
  35 + for (auto i : labels) {
  36 + std::cout << "point " << k << ": label " << i << "\n";
  37 + ++k;
  38 + }
  39 +}
  40 +
  41 +TEST(FastClustering, TestClusteringWithThreshold) {
  42 + std::vector<float> features = {
  43 + // point 0
  44 + 0.1,
  45 + 0.1,
  46 + // point 2
  47 + 0.4,
  48 + -0.5,
  49 + // point 3
  50 + 0.6,
  51 + -0.7,
  52 + // point 1
  53 + 0.2,
  54 + 0.3,
  55 + };
  56 +
  57 + FastClusteringConfig config;
  58 + config.threshold = 0.5;
  59 +
  60 + FastClustering clustering(config);
  61 + auto labels = clustering.Cluster(features.data(), 4, 2);
  62 + int32_t k = 0;
  63 + for (auto i : labels) {
  64 + std::cout << "point " << k << ": label " << i << "\n";
  65 + ++k;
  66 + }
  67 +}
  68 +
  69 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/fast-clustering.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/fast-clustering.h"
  6 +
  7 +#include <vector>
  8 +
  9 +#include "Eigen/Dense"
  10 +#include "fastcluster-all-in-one.h" // NOLINT
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class FastClustering::Impl {
  15 + public:
  16 + explicit Impl(const FastClusteringConfig &config) : config_(config) {}
  17 +
  18 + std::vector<int32_t> Cluster(float *features, int32_t num_rows,
  19 + int32_t num_cols) {
  20 + if (num_rows <= 0) {
  21 + return {};
  22 + }
  23 +
  24 + if (num_rows == 1) {
  25 + return {0};
  26 + }
  27 +
  28 + Eigen::Map<
  29 + Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
  30 + m(features, num_rows, num_cols);
  31 + m.rowwise().normalize();
  32 +
  33 + std::vector<double> distance((num_rows * (num_rows - 1)) / 2);
  34 +
  35 + int32_t k = 0;
  36 + for (int32_t i = 0; i != num_rows; ++i) {
  37 + auto v = m.row(i);
  38 + for (int32_t j = i + 1; j != num_rows; ++j) {
  39 + double cosine_similarity = v.dot(m.row(j));
  40 + double consine_dissimilarity = 1 - cosine_similarity;
  41 +
  42 + if (consine_dissimilarity < 0) {
  43 + consine_dissimilarity = 0;
  44 + }
  45 +
  46 + distance[k] = consine_dissimilarity;
  47 + ++k;
  48 + }
  49 + }
  50 +
  51 + std::vector<int32_t> merge(2 * (num_rows - 1));
  52 + std::vector<double> height(num_rows - 1);
  53 +
  54 + fastclustercpp::hclust_fast(num_rows, distance.data(),
  55 + fastclustercpp::HCLUST_METHOD_SINGLE,
  56 + merge.data(), height.data());
  57 +
  58 + std::vector<int32_t> labels(num_rows);
  59 + if (config_.num_clusters > 0) {
  60 + fastclustercpp::cutree_k(num_rows, merge.data(), config_.num_clusters,
  61 + labels.data());
  62 + } else {
  63 + fastclustercpp::cutree_cdist(num_rows, merge.data(), height.data(),
  64 + config_.threshold, labels.data());
  65 + }
  66 +
  67 + return labels;
  68 + }
  69 +
  70 + private:
  71 + FastClusteringConfig config_;
  72 +};
  73 +
  74 +FastClustering::FastClustering(const FastClusteringConfig &config)
  75 + : impl_(std::make_unique<Impl>(config)) {}
  76 +
  77 +FastClustering::~FastClustering() = default;
  78 +
  79 +std::vector<int32_t> FastClustering::Cluster(float *features, int32_t num_rows,
  80 + int32_t num_cols) {
  81 + return impl_->Cluster(features, num_rows, num_cols);
  82 +}
  83 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/fast-clustering.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_
  6 +#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_
  7 +
  8 +#include <memory>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/fast-clustering-config.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class FastClustering {
  16 + public:
  17 + explicit FastClustering(const FastClusteringConfig &config);
  18 + ~FastClustering();
  19 +
  20 + /**
  21 + * @param features Pointer to a 2-D feature matrix in row major. Each row
  22 + * is a feature frame. It is changed in-place. We will
  23 + * convert each feature frame to a normalized vector.
  24 + * That is, the L2-norm of each vector will be equal to 1.
  25 + * It uses cosine dissimilarity,
  26 + * which is 1 - (cosine similarity)
  27 + * @param num_rows Number of feature frames
  28 + * @param num-cols The feature dimension.
  29 + *
  30 + * @return Return a vector of size num_rows. ans[i] contains the label
  31 + * for the i-th feature frame, i.e., the i-th row of the feature
  32 + * matrix.
  33 + */
  34 + std::vector<int32_t> Cluster(float *features, int32_t num_rows,
  35 + int32_t num_cols);
  36 +
  37 + private:
  38 + class Impl;
  39 + std::unique_ptr<Impl> impl_;
  40 +};
  41 +
  42 +} // namespace sherpa_onnx
  43 +#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_
@@ -4,8 +4,6 @@ @@ -4,8 +4,6 @@
4 4
5 #include "sherpa-onnx/csrc/offline-stream.h" 5 #include "sherpa-onnx/csrc/offline-stream.h"
6 6
7 -#include <math.h>  
8 -  
9 #include <algorithm> 7 #include <algorithm>
10 #include <cassert> 8 #include <cassert>
11 #include <cmath> 9 #include <cmath>
@@ -8,16 +8,16 @@ @@ -8,16 +8,16 @@
8 namespace sherpa_onnx { 8 namespace sherpa_onnx {
9 9
10 struct OnlineCNNBiLSTMModelMetaData { 10 struct OnlineCNNBiLSTMModelMetaData {
11 - int32_t comma_id;  
12 - int32_t period_id;  
13 - int32_t quest_id; 11 + int32_t comma_id = -1;
  12 + int32_t period_id = -1;
  13 + int32_t quest_id = -1;
14 14
15 - int32_t upper_id;  
16 - int32_t cap_id;  
17 - int32_t mix_case_id; 15 + int32_t upper_id = -1;
  16 + int32_t cap_id = -1;
  17 + int32_t mix_case_id = -1;
18 18
19 - int32_t num_cases;  
20 - int32_t num_punctuations; 19 + int32_t num_cases = -1;
  20 + int32_t num_punctuations = -1;
21 }; 21 };
22 22
23 } // namespace sherpa_onnx 23 } // namespace sherpa_onnx
@@ -169,7 +169,7 @@ static std::vector<int64_t> CoquiPhonemesToIds( @@ -169,7 +169,7 @@ static std::vector<int64_t> CoquiPhonemesToIds(
169 return ans; 169 return ans;
170 } 170 }
171 171
172 -void InitEspeak(const std::string &data_dir) { 172 +static void InitEspeak(const std::string &data_dir) {
173 static std::once_flag init_flag; 173 static std::once_flag init_flag;
174 std::call_once(init_flag, [data_dir]() { 174 std::call_once(init_flag, [data_dir]() {
175 int32_t result = 175 int32_t result =
@@ -41,7 +41,7 @@ @@ -41,7 +41,7 @@
41 namespace sherpa_onnx { 41 namespace sherpa_onnx {
42 42
43 template <class I> 43 template <class I>
44 -I Gcd(I m, I n) { 44 +static I Gcd(I m, I n) {
45 // this function is copied from kaldi/src/base/kaldi-math.h 45 // this function is copied from kaldi/src/base/kaldi-math.h
46 if (m == 0 || n == 0) { 46 if (m == 0 || n == 0) {
47 if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. 47 if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
@@ -65,7 +65,7 @@ I Gcd(I m, I n) { @@ -65,7 +65,7 @@ I Gcd(I m, I n) {
65 /// Returns the least common multiple of two integers. Will 65 /// Returns the least common multiple of two integers. Will
66 /// crash unless the inputs are positive. 66 /// crash unless the inputs are positive.
67 template <class I> 67 template <class I>
68 -I Lcm(I m, I n) { 68 +static I Lcm(I m, I n) {
69 // This function is copied from kaldi/src/base/kaldi-math.h 69 // This function is copied from kaldi/src/base/kaldi-math.h
70 assert(m > 0 && n > 0); 70 assert(m > 0 && n > 0);
71 I gcd = Gcd(m, n); 71 I gcd = Gcd(m, n);