Committed by
GitHub
Support Agglomerative clustering. (#1384)
We use the open-source implementation from https://github.com/cdalitz/hclust-cpp
正在显示
12 个修改的文件
包含
343 行增加
和
13 行删除
| @@ -40,6 +40,7 @@ option(SHERPA_ONNX_ENABLE_WASM_VAD_ASR "Whether to enable WASM for VAD+ASR" OFF) | @@ -40,6 +40,7 @@ option(SHERPA_ONNX_ENABLE_WASM_VAD_ASR "Whether to enable WASM for VAD+ASR" OFF) | ||
| 40 | option(SHERPA_ONNX_ENABLE_WASM_NODEJS "Whether to enable WASM for NodeJS" OFF) | 40 | option(SHERPA_ONNX_ENABLE_WASM_NODEJS "Whether to enable WASM for NodeJS" OFF) |
| 41 | option(SHERPA_ONNX_ENABLE_BINARY "Whether to build binaries" ON) | 41 | option(SHERPA_ONNX_ENABLE_BINARY "Whether to build binaries" ON) |
| 42 | option(SHERPA_ONNX_ENABLE_TTS "Whether to build TTS related code" ON) | 42 | option(SHERPA_ONNX_ENABLE_TTS "Whether to build TTS related code" ON) |
| 43 | +option(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION "Whether to build speaker diarization related code" ON) | ||
| 43 | option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON) | 44 | option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON) |
| 44 | option(SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE "True to use pre-installed onnxruntime if available" ON) | 45 | option(SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE "True to use pre-installed onnxruntime if available" ON) |
| 45 | option(SHERPA_ONNX_ENABLE_SANITIZER "Whether to enable ubsan and asan" OFF) | 46 | option(SHERPA_ONNX_ENABLE_SANITIZER "Whether to enable ubsan and asan" OFF) |
| @@ -142,6 +143,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_WASM_VAD_ASR ${SHERPA_ONNX_ENABLE_WASM_VAD_AS | @@ -142,6 +143,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_WASM_VAD_ASR ${SHERPA_ONNX_ENABLE_WASM_VAD_AS | ||
| 142 | message(STATUS "SHERPA_ONNX_ENABLE_WASM_NODEJS ${SHERPA_ONNX_ENABLE_WASM_NODEJS}") | 143 | message(STATUS "SHERPA_ONNX_ENABLE_WASM_NODEJS ${SHERPA_ONNX_ENABLE_WASM_NODEJS}") |
| 143 | message(STATUS "SHERPA_ONNX_ENABLE_BINARY ${SHERPA_ONNX_ENABLE_BINARY}") | 144 | message(STATUS "SHERPA_ONNX_ENABLE_BINARY ${SHERPA_ONNX_ENABLE_BINARY}") |
| 144 | message(STATUS "SHERPA_ONNX_ENABLE_TTS ${SHERPA_ONNX_ENABLE_TTS}") | 145 | message(STATUS "SHERPA_ONNX_ENABLE_TTS ${SHERPA_ONNX_ENABLE_TTS}") |
| 146 | +message(STATUS "SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION ${SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION}") | ||
| 145 | message(STATUS "SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY ${SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY}") | 147 | message(STATUS "SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY ${SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY}") |
| 146 | message(STATUS "SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE ${SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE}") | 148 | message(STATUS "SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE ${SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE}") |
| 147 | message(STATUS "SHERPA_ONNX_ENABLE_SANITIZER: ${SHERPA_ONNX_ENABLE_SANITIZER}") | 149 | message(STATUS "SHERPA_ONNX_ENABLE_SANITIZER: ${SHERPA_ONNX_ENABLE_SANITIZER}") |
| @@ -341,6 +343,10 @@ if(SHERPA_ONNX_ENABLE_TTS) | @@ -341,6 +343,10 @@ if(SHERPA_ONNX_ENABLE_TTS) | ||
| 341 | include(cppjieba) # For Chinese TTS. It is a header-only C++ library | 343 | include(cppjieba) # For Chinese TTS. It is a header-only C++ library |
| 342 | endif() | 344 | endif() |
| 343 | 345 | ||
| 346 | +if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) | ||
| 347 | + include(hclust-cpp) | ||
| 348 | +endif() | ||
| 349 | + | ||
| 344 | # if(NOT MSVC AND CMAKE_BUILD_TYPE STREQUAL Debug AND (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")) | 350 | # if(NOT MSVC AND CMAKE_BUILD_TYPE STREQUAL Debug AND (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")) |
| 345 | if(SHERPA_ONNX_ENABLE_SANITIZER) | 351 | if(SHERPA_ONNX_ENABLE_SANITIZER) |
| 346 | message(WARNING "enable ubsan and asan") | 352 | message(WARNING "enable ubsan and asan") |
cmake/hclust-cpp.cmake
0 → 100644
| 1 | +function(download_hclust_cpp) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + # The latest commit as of 2024.09.29 | ||
| 5 | + set(hclust_cpp_URL "https://github.com/csukuangfj/hclust-cpp/archive/refs/tags/2024-09-29.tar.gz") | ||
| 6 | + set(hclust_cpp_HASH "SHA256=abab51448a3cb54272aae07522970306e0b2cc6479d59d7b19e7aee4d6cedd33") | ||
| 7 | + | ||
| 8 | + # If you don't have access to the Internet, | ||
| 9 | + # please pre-download hclust-cpp | ||
| 10 | + set(possible_file_locations | ||
| 11 | + $ENV{HOME}/Downloads/hclust-cpp-2024-09-29.tar.gz | ||
| 12 | + ${CMAKE_SOURCE_DIR}/hclust-cpp-2024-09-29.tar.gz | ||
| 13 | + ${CMAKE_BINARY_DIR}/hclust-cpp-2024-09-29.tar.gz | ||
| 14 | + /tmp/hclust-cpp-2024-09-29.tar.gz | ||
| 15 | + /star-fj/fangjun/download/github/hclust-cpp-2024-09-29.tar.gz | ||
| 16 | + ) | ||
| 17 | + | ||
| 18 | + foreach(f IN LISTS possible_file_locations) | ||
| 19 | + if(EXISTS ${f}) | ||
| 20 | + set(hclust_cpp_URL "${f}") | ||
| 21 | + file(TO_CMAKE_PATH "${hclust_cpp_URL}" hclust_cpp_URL) | ||
| 22 | + message(STATUS "Found local downloaded hclust_cpp: ${hclust_cpp_URL}") | ||
| 23 | + break() | ||
| 24 | + endif() | ||
| 25 | + endforeach() | ||
| 26 | + | ||
| 27 | + FetchContent_Declare(hclust_cpp | ||
| 28 | + URL | ||
| 29 | + ${hclust_cpp_URL} | ||
| 30 | + ${hclust_cpp_URL2} | ||
| 31 | + URL_HASH ${hclust_cpp_HASH} | ||
| 32 | + ) | ||
| 33 | + | ||
| 34 | + FetchContent_GetProperties(hclust_cpp) | ||
| 35 | + if(NOT hclust_cpp_POPULATED) | ||
| 36 | + message(STATUS "Downloading hclust_cpp from ${hclust_cpp_URL}") | ||
| 37 | + FetchContent_Populate(hclust_cpp) | ||
| 38 | + endif() | ||
| 39 | + | ||
| 40 | + message(STATUS "hclust_cpp is downloaded to ${hclust_cpp_SOURCE_DIR}") | ||
| 41 | + message(STATUS "hclust_cpp's binary dir is ${hclust_cpp_BINARY_DIR}") | ||
| 42 | + include_directories(${hclust_cpp_SOURCE_DIR}) | ||
| 43 | +endfunction() | ||
| 44 | + | ||
| 45 | +download_hclust_cpp() |
| @@ -160,6 +160,13 @@ if(SHERPA_ONNX_ENABLE_TTS) | @@ -160,6 +160,13 @@ if(SHERPA_ONNX_ENABLE_TTS) | ||
| 160 | ) | 160 | ) |
| 161 | endif() | 161 | endif() |
| 162 | 162 | ||
| 163 | +if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) | ||
| 164 | + list(APPEND sources | ||
| 165 | + fast-clustering-config.cc | ||
| 166 | + fast-clustering.cc | ||
| 167 | + ) | ||
| 168 | +endif() | ||
| 169 | + | ||
| 163 | if(SHERPA_ONNX_ENABLE_CHECK) | 170 | if(SHERPA_ONNX_ENABLE_CHECK) |
| 164 | list(APPEND sources log.cc) | 171 | list(APPEND sources log.cc) |
| 165 | endif() | 172 | endif() |
| @@ -523,6 +530,12 @@ if(SHERPA_ONNX_ENABLE_TESTS) | @@ -523,6 +530,12 @@ if(SHERPA_ONNX_ENABLE_TESTS) | ||
| 523 | ) | 530 | ) |
| 524 | endif() | 531 | endif() |
| 525 | 532 | ||
| 533 | + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) | ||
| 534 | + list(APPEND sherpa_onnx_test_srcs | ||
| 535 | + fast-clustering-test.cc | ||
| 536 | + ) | ||
| 537 | + endif() | ||
| 538 | + | ||
| 526 | list(APPEND sherpa_onnx_test_srcs | 539 | list(APPEND sherpa_onnx_test_srcs |
| 527 | speaker-embedding-manager-test.cc | 540 | speaker-embedding-manager-test.cc |
| 528 | ) | 541 | ) |
sherpa-onnx/csrc/fast-clustering-config.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/fast-clustering-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/fast-clustering-config.h" | ||
| 6 | + | ||
| 7 | +#include <sstream> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | +std::string FastClusteringConfig::ToString() const { | ||
| 14 | + std::ostringstream os; | ||
| 15 | + | ||
| 16 | + os << "FastClusteringConfig("; | ||
| 17 | + os << "num_clusters=" << num_clusters << ", "; | ||
| 18 | + os << "threshold=" << threshold << ")"; | ||
| 19 | + | ||
| 20 | + return os.str(); | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +void FastClusteringConfig::Register(ParseOptions *po) { | ||
| 24 | + std::string prefix = "ctc"; | ||
| 25 | + ParseOptions p(prefix, po); | ||
| 26 | + | ||
| 27 | + p.Register("num-clusters", &num_clusters, | ||
| 28 | + "Number of cluster. If greater than 0, then --cluster-thresold is " | ||
| 29 | + "ignored"); | ||
| 30 | + | ||
| 31 | + p.Register("cluster-threshold", &threshold, | ||
| 32 | + "If --num-clusters is not specified, then it specifies the " | ||
| 33 | + "distance threshold for clustering."); | ||
| 34 | +} | ||
| 35 | + | ||
| 36 | +bool FastClusteringConfig::Validate() const { | ||
| 37 | + if (num_clusters < 1 && threshold < 0) { | ||
| 38 | + SHERPA_ONNX_LOGE("Please provide either num_clusters or threshold"); | ||
| 39 | + return false; | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + return true; | ||
| 43 | +} | ||
| 44 | + | ||
| 45 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/fast-clustering-config.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/fast-clustering-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct FastClusteringConfig { | ||
| 15 | + // If greater than 0, then threshold is ignored | ||
| 16 | + int32_t num_clusters = -1; | ||
| 17 | + | ||
| 18 | + // distance threshold | ||
| 19 | + float threshold = 0.5; | ||
| 20 | + | ||
| 21 | + std::string ToString() const; | ||
| 22 | + | ||
| 23 | + void Register(ParseOptions *po); | ||
| 24 | + bool Validate() const; | ||
| 25 | +}; | ||
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_ |
sherpa-onnx/csrc/fast-clustering-test.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/fast-clustering-test.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/fast-clustering.h" | ||
| 6 | + | ||
| 7 | +#include <vector> | ||
| 8 | + | ||
| 9 | +#include "gtest/gtest.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +TEST(FastClustering, TestTwoClusters) { | ||
| 14 | + std::vector<float> features = { | ||
| 15 | + // point 0 | ||
| 16 | + 0.1, | ||
| 17 | + 0.1, | ||
| 18 | + // point 2 | ||
| 19 | + 0.4, | ||
| 20 | + -0.5, | ||
| 21 | + // point 3 | ||
| 22 | + 0.6, | ||
| 23 | + -0.7, | ||
| 24 | + // point 1 | ||
| 25 | + 0.2, | ||
| 26 | + 0.3, | ||
| 27 | + }; | ||
| 28 | + | ||
| 29 | + FastClusteringConfig config; | ||
| 30 | + config.num_clusters = 2; | ||
| 31 | + | ||
| 32 | + FastClustering clustering(config); | ||
| 33 | + auto labels = clustering.Cluster(features.data(), 4, 2); | ||
| 34 | + int32_t k = 0; | ||
| 35 | + for (auto i : labels) { | ||
| 36 | + std::cout << "point " << k << ": label " << i << "\n"; | ||
| 37 | + ++k; | ||
| 38 | + } | ||
| 39 | +} | ||
| 40 | + | ||
| 41 | +TEST(FastClustering, TestClusteringWithThreshold) { | ||
| 42 | + std::vector<float> features = { | ||
| 43 | + // point 0 | ||
| 44 | + 0.1, | ||
| 45 | + 0.1, | ||
| 46 | + // point 2 | ||
| 47 | + 0.4, | ||
| 48 | + -0.5, | ||
| 49 | + // point 3 | ||
| 50 | + 0.6, | ||
| 51 | + -0.7, | ||
| 52 | + // point 1 | ||
| 53 | + 0.2, | ||
| 54 | + 0.3, | ||
| 55 | + }; | ||
| 56 | + | ||
| 57 | + FastClusteringConfig config; | ||
| 58 | + config.threshold = 0.5; | ||
| 59 | + | ||
| 60 | + FastClustering clustering(config); | ||
| 61 | + auto labels = clustering.Cluster(features.data(), 4, 2); | ||
| 62 | + int32_t k = 0; | ||
| 63 | + for (auto i : labels) { | ||
| 64 | + std::cout << "point " << k << ": label " << i << "\n"; | ||
| 65 | + ++k; | ||
| 66 | + } | ||
| 67 | +} | ||
| 68 | + | ||
| 69 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/fast-clustering.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/fast-clustering.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/fast-clustering.h" | ||
| 6 | + | ||
| 7 | +#include <vector> | ||
| 8 | + | ||
| 9 | +#include "Eigen/Dense" | ||
| 10 | +#include "fastcluster-all-in-one.h" // NOLINT | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class FastClustering::Impl { | ||
| 15 | + public: | ||
| 16 | + explicit Impl(const FastClusteringConfig &config) : config_(config) {} | ||
| 17 | + | ||
| 18 | + std::vector<int32_t> Cluster(float *features, int32_t num_rows, | ||
| 19 | + int32_t num_cols) { | ||
| 20 | + if (num_rows <= 0) { | ||
| 21 | + return {}; | ||
| 22 | + } | ||
| 23 | + | ||
| 24 | + if (num_rows == 1) { | ||
| 25 | + return {0}; | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + Eigen::Map< | ||
| 29 | + Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> | ||
| 30 | + m(features, num_rows, num_cols); | ||
| 31 | + m.rowwise().normalize(); | ||
| 32 | + | ||
| 33 | + std::vector<double> distance((num_rows * (num_rows - 1)) / 2); | ||
| 34 | + | ||
| 35 | + int32_t k = 0; | ||
| 36 | + for (int32_t i = 0; i != num_rows; ++i) { | ||
| 37 | + auto v = m.row(i); | ||
| 38 | + for (int32_t j = i + 1; j != num_rows; ++j) { | ||
| 39 | + double cosine_similarity = v.dot(m.row(j)); | ||
| 40 | + double consine_dissimilarity = 1 - cosine_similarity; | ||
| 41 | + | ||
| 42 | + if (consine_dissimilarity < 0) { | ||
| 43 | + consine_dissimilarity = 0; | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + distance[k] = consine_dissimilarity; | ||
| 47 | + ++k; | ||
| 48 | + } | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + std::vector<int32_t> merge(2 * (num_rows - 1)); | ||
| 52 | + std::vector<double> height(num_rows - 1); | ||
| 53 | + | ||
| 54 | + fastclustercpp::hclust_fast(num_rows, distance.data(), | ||
| 55 | + fastclustercpp::HCLUST_METHOD_SINGLE, | ||
| 56 | + merge.data(), height.data()); | ||
| 57 | + | ||
| 58 | + std::vector<int32_t> labels(num_rows); | ||
| 59 | + if (config_.num_clusters > 0) { | ||
| 60 | + fastclustercpp::cutree_k(num_rows, merge.data(), config_.num_clusters, | ||
| 61 | + labels.data()); | ||
| 62 | + } else { | ||
| 63 | + fastclustercpp::cutree_cdist(num_rows, merge.data(), height.data(), | ||
| 64 | + config_.threshold, labels.data()); | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + return labels; | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + private: | ||
| 71 | + FastClusteringConfig config_; | ||
| 72 | +}; | ||
| 73 | + | ||
| 74 | +FastClustering::FastClustering(const FastClusteringConfig &config) | ||
| 75 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 76 | + | ||
| 77 | +FastClustering::~FastClustering() = default; | ||
| 78 | + | ||
| 79 | +std::vector<int32_t> FastClustering::Cluster(float *features, int32_t num_rows, | ||
| 80 | + int32_t num_cols) { | ||
| 81 | + return impl_->Cluster(features, num_rows, num_cols); | ||
| 82 | +} | ||
| 83 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/fast-clustering.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/fast-clustering.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/fast-clustering-config.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class FastClustering { | ||
| 16 | + public: | ||
| 17 | + explicit FastClustering(const FastClusteringConfig &config); | ||
| 18 | + ~FastClustering(); | ||
| 19 | + | ||
| 20 | + /** | ||
| 21 | + * @param features Pointer to a 2-D feature matrix in row major. Each row | ||
| 22 | + * is a feature frame. It is changed in-place. We will | ||
| 23 | + * convert each feature frame to a normalized vector. | ||
| 24 | + * That is, the L2-norm of each vector will be equal to 1. | ||
| 25 | + * It uses cosine dissimilarity, | ||
| 26 | + * which is 1 - (cosine similarity) | ||
| 27 | + * @param num_rows Number of feature frames | ||
| 28 | + * @param num-cols The feature dimension. | ||
| 29 | + * | ||
| 30 | + * @return Return a vector of size num_rows. ans[i] contains the label | ||
| 31 | + * for the i-th feature frame, i.e., the i-th row of the feature | ||
| 32 | + * matrix. | ||
| 33 | + */ | ||
| 34 | + std::vector<int32_t> Cluster(float *features, int32_t num_rows, | ||
| 35 | + int32_t num_cols); | ||
| 36 | + | ||
| 37 | + private: | ||
| 38 | + class Impl; | ||
| 39 | + std::unique_ptr<Impl> impl_; | ||
| 40 | +}; | ||
| 41 | + | ||
| 42 | +} // namespace sherpa_onnx | ||
| 43 | +#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_ |
| @@ -8,16 +8,16 @@ | @@ -8,16 +8,16 @@ | ||
| 8 | namespace sherpa_onnx { | 8 | namespace sherpa_onnx { |
| 9 | 9 | ||
| 10 | struct OnlineCNNBiLSTMModelMetaData { | 10 | struct OnlineCNNBiLSTMModelMetaData { |
| 11 | - int32_t comma_id; | ||
| 12 | - int32_t period_id; | ||
| 13 | - int32_t quest_id; | 11 | + int32_t comma_id = -1; |
| 12 | + int32_t period_id = -1; | ||
| 13 | + int32_t quest_id = -1; | ||
| 14 | 14 | ||
| 15 | - int32_t upper_id; | ||
| 16 | - int32_t cap_id; | ||
| 17 | - int32_t mix_case_id; | 15 | + int32_t upper_id = -1; |
| 16 | + int32_t cap_id = -1; | ||
| 17 | + int32_t mix_case_id = -1; | ||
| 18 | 18 | ||
| 19 | - int32_t num_cases; | ||
| 20 | - int32_t num_punctuations; | 19 | + int32_t num_cases = -1; |
| 20 | + int32_t num_punctuations = -1; | ||
| 21 | }; | 21 | }; |
| 22 | 22 | ||
| 23 | } // namespace sherpa_onnx | 23 | } // namespace sherpa_onnx |
| @@ -169,7 +169,7 @@ static std::vector<int64_t> CoquiPhonemesToIds( | @@ -169,7 +169,7 @@ static std::vector<int64_t> CoquiPhonemesToIds( | ||
| 169 | return ans; | 169 | return ans; |
| 170 | } | 170 | } |
| 171 | 171 | ||
| 172 | -void InitEspeak(const std::string &data_dir) { | 172 | +static void InitEspeak(const std::string &data_dir) { |
| 173 | static std::once_flag init_flag; | 173 | static std::once_flag init_flag; |
| 174 | std::call_once(init_flag, [data_dir]() { | 174 | std::call_once(init_flag, [data_dir]() { |
| 175 | int32_t result = | 175 | int32_t result = |
| @@ -41,7 +41,7 @@ | @@ -41,7 +41,7 @@ | ||
| 41 | namespace sherpa_onnx { | 41 | namespace sherpa_onnx { |
| 42 | 42 | ||
| 43 | template <class I> | 43 | template <class I> |
| 44 | -I Gcd(I m, I n) { | 44 | +static I Gcd(I m, I n) { |
| 45 | // this function is copied from kaldi/src/base/kaldi-math.h | 45 | // this function is copied from kaldi/src/base/kaldi-math.h |
| 46 | if (m == 0 || n == 0) { | 46 | if (m == 0 || n == 0) { |
| 47 | if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. | 47 | if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. |
| @@ -65,7 +65,7 @@ I Gcd(I m, I n) { | @@ -65,7 +65,7 @@ I Gcd(I m, I n) { | ||
| 65 | /// Returns the least common multiple of two integers. Will | 65 | /// Returns the least common multiple of two integers. Will |
| 66 | /// crash unless the inputs are positive. | 66 | /// crash unless the inputs are positive. |
| 67 | template <class I> | 67 | template <class I> |
| 68 | -I Lcm(I m, I n) { | 68 | +static I Lcm(I m, I n) { |
| 69 | // This function is copied from kaldi/src/base/kaldi-math.h | 69 | // This function is copied from kaldi/src/base/kaldi-math.h |
| 70 | assert(m > 0 && n > 0); | 70 | assert(m > 0 && n > 0); |
| 71 | I gcd = Gcd(m, n); | 71 | I gcd = Gcd(m, n); |
-
请 注册 或 登录 后发表评论