Fangjun Kuang
Committed by GitHub

Add C++ and Python support for ten-vad (#2377)

This PR adds support for the TEN VAD model alongside the existing Silero VAD in both C++ and Python interfaces.

- Introduces TenVadModelConfig with Python bindings and integrates it into VadModelConfig.
- Implements TenVadModel in C++ and extends the factory (VadModel::Create) and detector logic to choose between Silero and TEN VAD.
- Updates build files (CMake), fixes a spelling typo, and extends the Python example script to demonstrate --ten-vad-model.
1 function(download_kaldi_native_fbank) 1 function(download_kaldi_native_fbank)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.2.tar.gz")  
5 - set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.2.tar.gz")  
6 - set(kaldi_native_fbank_HASH "SHA256=f4bd7d53fe8aeaecc4eda9680c72696bb86bf74e86371d81aacacd6f4ca3914d") 4 + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.3.tar.gz")
  5 + set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.3.tar.gz")
  6 + set(kaldi_native_fbank_HASH "SHA256=d409eddae5a46dc796f0841880f489ff0728b96ae26218702cd438c28667c70e")
7 7
8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) 8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
12 # If you don't have access to the Internet, 12 # If you don't have access to the Internet,
13 # please pre-download kaldi-native-fbank 13 # please pre-download kaldi-native-fbank
14 set(possible_file_locations 14 set(possible_file_locations
15 - $ENV{HOME}/Downloads/kaldi-native-fbank-1.21.2.tar.gz  
16 - ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.2.tar.gz  
17 - ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.2.tar.gz  
18 - /tmp/kaldi-native-fbank-1.21.2.tar.gz  
19 - /star-fj/fangjun/download/github/kaldi-native-fbank-1.21.2.tar.gz 15 + $ENV{HOME}/Downloads/kaldi-native-fbank-1.21.3.tar.gz
  16 + ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.3.tar.gz
  17 + ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.3.tar.gz
  18 + /tmp/kaldi-native-fbank-1.21.3.tar.gz
  19 + /star-fj/fangjun/download/github/kaldi-native-fbank-1.21.3.tar.gz
20 ) 20 )
21 21
22 foreach(f IN LISTS possible_file_locations) 22 foreach(f IN LISTS possible_file_locations)
1 // cxx-api-examples/zipformer-transducer-simulate-streaming-microphone-cxx-api.cc 1 // cxx-api-examples/zipformer-transducer-simulate-streaming-microphone-cxx-api.cc
2 // Copyright (c) 2025 Xiaomi Corporation 2 // Copyright (c) 2025 Xiaomi Corporation
3 // 3 //
4 -// This file demonstrates how to use Zipformer transducer with sherpa-onnx's C++ API  
5 -// for streaming speech recognition from a microphone. 4 +// This file demonstrates how to use Zipformer transducer with sherpa-onnx's C++
  5 +// API for streaming speech recognition from a microphone.
6 // 6 //
7 // clang-format off 7 // clang-format off
8 // 8 //
@@ -19,6 +19,12 @@ For instance, @@ -19,6 +19,12 @@ For instance,
19 19
20 wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx 20 wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
21 21
  22 +or download ten-vad.onnx, for instance
  23 +
  24 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
  25 +
  26 +Please replace --silero-vad-model with --ten-vad-model below to use ten-vad.
  27 +
22 (1) For paraformer 28 (1) For paraformer
23 29
24 ./python-api-examples/generate-subtitles.py \ 30 ./python-api-examples/generate-subtitles.py \
@@ -124,8 +130,13 @@ def get_args(): @@ -124,8 +130,13 @@ def get_args():
124 parser.add_argument( 130 parser.add_argument(
125 "--silero-vad-model", 131 "--silero-vad-model",
126 type=str, 132 type=str,
127 - required=True,  
128 - help="Path to silero_vad.onnx", 133 + help="Path to silero_vad.onnx.",
  134 + )
  135 +
  136 + parser.add_argument(
  137 + "--ten-vad-model",
  138 + type=str,
  139 + help="Path to ten-vad.onnx",
129 ) 140 )
130 141
131 parser.add_argument( 142 parser.add_argument(
@@ -499,7 +510,12 @@ class Segment: @@ -499,7 +510,12 @@ class Segment:
499 def main(): 510 def main():
500 args = get_args() 511 args = get_args()
501 assert_file_exists(args.tokens) 512 assert_file_exists(args.tokens)
502 - assert_file_exists(args.silero_vad_model) 513 + if args.silero_vad_model:
  514 + assert_file_exists(args.silero_vad_model)
  515 + elif args.ten_vad_model:
  516 + assert_file_exists(args.ten_vad_model)
  517 + else:
  518 + raise ValueError("You need to supply one vad model")
503 519
504 assert args.num_threads > 0, args.num_threads 520 assert args.num_threads > 0, args.num_threads
505 521
@@ -536,18 +552,34 @@ def main(): @@ -536,18 +552,34 @@ def main():
536 stream = recognizer.create_stream() 552 stream = recognizer.create_stream()
537 553
538 config = sherpa_onnx.VadModelConfig() 554 config = sherpa_onnx.VadModelConfig()
539 - config.silero_vad.model = args.silero_vad_model  
540 - config.silero_vad.threshold = 0.5  
541 - config.silero_vad.min_silence_duration = 0.25 # seconds  
542 - config.silero_vad.min_speech_duration = 0.25 # seconds  
543 -  
544 - # If the current segment is larger than this value, then it increases  
545 - # the threshold to 0.9 internally. After detecting this segment,  
546 - # it resets the threshold to its original value.  
547 - config.silero_vad.max_speech_duration = 5 # seconds  
548 - config.sample_rate = args.sample_rate  
549 -  
550 - window_size = config.silero_vad.window_size 555 + if args.silero_vad_model:
  556 + config.silero_vad.model = args.silero_vad_model
  557 + config.silero_vad.threshold = 0.2
  558 + config.silero_vad.min_silence_duration = 0.25 # seconds
  559 + config.silero_vad.min_speech_duration = 0.25 # seconds
  560 +
  561 + # If the current segment is larger than this value, then it increases
  562 + # the threshold to 0.9 internally. After detecting this segment,
  563 + # it resets the threshold to its original value.
  564 + config.silero_vad.max_speech_duration = 5 # seconds
  565 + config.sample_rate = args.sample_rate
  566 +
  567 + window_size = config.silero_vad.window_size
  568 + print("use silero-vad")
  569 + else:
  570 + config.ten_vad.model = args.ten_vad_model
  571 + config.ten_vad.threshold = 0.2
  572 + config.ten_vad.min_silence_duration = 0.25 # seconds
  573 + config.ten_vad.min_speech_duration = 0.25 # seconds
  574 +
  575 + # If the current segment is larger than this value, then it increases
  576 + # the threshold to 0.9 internally. After detecting this segment,
  577 + # it resets the threshold to its original value.
  578 + config.ten_vad.max_speech_duration = 5 # seconds
  579 + config.sample_rate = args.sample_rate
  580 +
  581 + window_size = config.ten_vad.window_size
  582 + print("use ten-vad")
551 583
552 buffer = [] 584 buffer = []
553 vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=100) 585 vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=100)
@@ -123,6 +123,8 @@ set(sources @@ -123,6 +123,8 @@ set(sources
123 spoken-language-identification.cc 123 spoken-language-identification.cc
124 stack.cc 124 stack.cc
125 symbol-table.cc 125 symbol-table.cc
  126 + ten-vad-model-config.cc
  127 + ten-vad-model.cc
126 text-utils.cc 128 text-utils.cc
127 transducer-keyword-decoder.cc 129 transducer-keyword-decoder.cc
128 transpose.cc 130 transpose.cc
@@ -40,7 +40,7 @@ void SileroVadModelConfig::Register(ParseOptions *po) { @@ -40,7 +40,7 @@ void SileroVadModelConfig::Register(ParseOptions *po) {
40 "to the silero VAD model. WARNING! Silero VAD models were trained using " 40 "to the silero VAD model. WARNING! Silero VAD models were trained using "
41 "512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples " 41 "512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples "
42 "for 8000 sample rate. Values other than these may affect model " 42 "for 8000 sample rate. Values other than these may affect model "
43 - "perfomance!"); 43 + "performance!");
44 } 44 }
45 45
46 bool SileroVadModelConfig::Validate() const { 46 bool SileroVadModelConfig::Validate() const {
@@ -24,7 +24,6 @@ struct SileroVadModelConfig { @@ -24,7 +24,6 @@ struct SileroVadModelConfig {
24 float min_speech_duration = 0.25; // in seconds 24 float min_speech_duration = 0.25; // in seconds
25 25
26 // 512, 1024, 1536 samples for 16000 Hz 26 // 512, 1024, 1536 samples for 16000 Hz
27 - // 256, 512, 768 samples for 800 Hz  
28 int32_t window_size = 512; // in samples 27 int32_t window_size = 512; // in samples
29 28
30 // If a speech segment is longer than this value, then we increase 29 // If a speech segment is longer than this value, then we increase
  1 +// sherpa-onnx/csrc/ten-vad-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/ten-vad-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void TenVadModelConfig::Register(ParseOptions *po) {
  13 + po->Register("ten-vad-model", &model, "Path to TEN VAD ONNX model.");
  14 +
  15 + po->Register("ten-vad-threshold", &threshold,
  16 + "Speech threshold. TEN VAD outputs speech probabilities for "
  17 + "each audio chunk, probabilities ABOVE this value are "
  18 + "considered as SPEECH. It is better to tune this parameter for "
  19 + "each dataset separately, but lazy "
  20 + "0.5 is pretty good for most datasets.");
  21 +
  22 + po->Register("ten-vad-min-silence-duration", &min_silence_duration,
  23 + "In seconds. In the end of each speech chunk wait for "
  24 + "--ten-vad-min-silence-duration seconds before separating it");
  25 +
  26 + po->Register("ten-vad-min-speech-duration", &min_speech_duration,
  27 + "In seconds. In the end of each silence chunk wait for "
  28 + "--ten-vad-min-speech-duration seconds before separating it");
  29 +
  30 + po->Register(
  31 + "ten-vad-max-speech-duration", &max_speech_duration,
  32 + "In seconds. If a speech segment is longer than this value, then we "
  33 + "increase the threshold to 0.9. After finishing detecting the segment, "
  34 + "the threshold value is reset to its original value.");
  35 +
  36 + po->Register(
  37 + "ten-vad-window-size", &window_size,
  38 + "In samples. Audio chunks of --ten-vad-window-size samples are fed "
  39 + "to the ten VAD model. WARNING! Please use 160 or 256 ");
  40 +}
  41 +
  42 +bool TenVadModelConfig::Validate() const {
  43 + if (model.empty()) {
  44 + SHERPA_ONNX_LOGE("Please provide --ten-vad-model");
  45 + return false;
  46 + }
  47 +
  48 + if (!FileExists(model)) {
  49 + SHERPA_ONNX_LOGE("TEN vad model file '%s' does not exist", model.c_str());
  50 + return false;
  51 + }
  52 +
  53 + if (threshold < 0.01) {
  54 + SHERPA_ONNX_LOGE(
  55 + "Please use a larger value for --ten-vad-threshold. Given: %f",
  56 + threshold);
  57 + return false;
  58 + }
  59 +
  60 + if (threshold >= 1) {
  61 + SHERPA_ONNX_LOGE(
  62 + "Please use a smaller value for --ten-vad-threshold. Given: %f",
  63 + threshold);
  64 + return false;
  65 + }
  66 +
  67 + if (min_silence_duration <= 0) {
  68 + SHERPA_ONNX_LOGE(
  69 + "Please use a larger value for --ten-vad-min-silence-duration. "
  70 + "Given: "
  71 + "%f",
  72 + min_silence_duration);
  73 + return false;
  74 + }
  75 +
  76 + if (min_speech_duration <= 0) {
  77 + SHERPA_ONNX_LOGE(
  78 + "Please use a larger value for --ten-vad-min-speech-duration. "
  79 + "Given: "
  80 + "%f",
  81 + min_speech_duration);
  82 + return false;
  83 + }
  84 +
  85 + if (max_speech_duration <= 0) {
  86 + SHERPA_ONNX_LOGE(
  87 + "Please use a larger value for --ten-vad-max-speech-duration. "
  88 + "Given: "
  89 + "%f",
  90 + max_speech_duration);
  91 + return false;
  92 + }
  93 +
  94 + return true;
  95 +}
  96 +
  97 +std::string TenVadModelConfig::ToString() const {
  98 + std::ostringstream os;
  99 +
  100 + os << "TenVadModelConfig(";
  101 + os << "model=\"" << model << "\", ";
  102 + os << "threshold=" << threshold << ", ";
  103 + os << "min_silence_duration=" << min_silence_duration << ", ";
  104 + os << "min_speech_duration=" << min_speech_duration << ", ";
  105 + os << "max_speech_duration=" << max_speech_duration << ", ";
  106 + os << "window_size=" << window_size << ")";
  107 +
  108 + return os.str();
  109 +}
  110 +
  111 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/ten-vad-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_TEN_VAD_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_TEN_VAD_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct TenVadModelConfig {
  14 + std::string model;
  15 +
  16 + // threshold to classify a segment as speech
  17 + //
  18 + // If the predicted probability of a segment is larger than this
  19 + // value, then it is classified as speech.
  20 + float threshold = 0.5;
  21 +
  22 + float min_silence_duration = 0.5; // in seconds
  23 +
  24 + float min_speech_duration = 0.25; // in seconds
  25 +
  26 + // 160 or 256
  27 + int32_t window_size = 256; // in samples
  28 +
  29 + // If a speech segment is longer than this value, then we increase
  30 + // the threshold to 0.9. After finishing detecting the segment,
  31 + // the threshold value is reset to its original value.
  32 + float max_speech_duration = 20; // in seconds
  33 +
  34 + TenVadModelConfig() = default;
  35 +
  36 + void Register(ParseOptions *po);
  37 +
  38 + bool Validate() const;
  39 +
  40 + std::string ToString() const;
  41 +};
  42 +
  43 +} // namespace sherpa_onnx
  44 +
  45 +#endif // SHERPA_ONNX_CSRC_TEN_VAD_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/ten-vad-model.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/ten-vad-model.h"
  6 +
  7 +#include <algorithm>
  8 +#include <cmath>
  9 +#include <cstring>
  10 +#include <memory>
  11 +#include <string>
  12 +#include <utility>
  13 +#include <vector>
  14 +
  15 +#if __ANDROID_API__ >= 9
  16 +#include "android/asset_manager.h"
  17 +#include "android/asset_manager_jni.h"
  18 +#endif
  19 +
  20 +#if __OHOS__
  21 +#include "rawfile/raw_file_manager.h"
  22 +#endif
  23 +
  24 +#include "kaldi-native-fbank/csrc/mel-computations.h"
  25 +#include "kaldi-native-fbank/csrc/rfft.h"
  26 +#include "sherpa-onnx/csrc/file-utils.h"
  27 +#include "sherpa-onnx/csrc/macros.h"
  28 +#include "sherpa-onnx/csrc/onnx-utils.h"
  29 +#include "sherpa-onnx/csrc/session.h"
  30 +#include "sherpa-onnx/csrc/text-utils.h"
  31 +
  32 +namespace sherpa_onnx {
  33 +
  34 +class TenVadModel::Impl {
  35 + public:
  36 + explicit Impl(const VadModelConfig &config)
  37 + : config_(config),
  38 + rfft_(1024),
  39 + env_(ORT_LOGGING_LEVEL_ERROR),
  40 + sess_opts_(GetSessionOptions(config)),
  41 + allocator_{},
  42 + sample_rate_(config.sample_rate) {
  43 + auto buf = ReadFile(config.ten_vad.model);
  44 + Init(buf.data(), buf.size());
  45 + }
  46 +
  47 + template <typename Manager>
  48 + Impl(Manager *mgr, const VadModelConfig &config)
  49 + : config_(config),
  50 + rfft_(1024),
  51 + env_(ORT_LOGGING_LEVEL_ERROR),
  52 + sess_opts_(GetSessionOptions(config)),
  53 + allocator_{},
  54 + sample_rate_(config.sample_rate) {
  55 + auto buf = ReadFile(mgr, config.ten_vad.model);
  56 + Init(buf.data(), buf.size());
  57 + }
  58 +
  59 + void Reset() {
  60 + triggered_ = false;
  61 + current_sample_ = 0;
  62 + temp_start_ = 0;
  63 + temp_end_ = 0;
  64 +
  65 + last_sample_ = 0;
  66 +
  67 + last_features_.resize(3 * 41);
  68 + std::fill(last_features_.begin(), last_features_.end(), 0.0f);
  69 + tmp_samples_.resize(1024);
  70 +
  71 + ResetStates();
  72 + }
  73 +
  74 + bool IsSpeech(const float *samples, int32_t n) {
  75 + if (n != WindowSize()) {
  76 + SHERPA_ONNX_LOGE("n: %d != window_size: %d", n, WindowSize());
  77 + SHERPA_ONNX_EXIT(-1);
  78 + }
  79 +
  80 + float prob = Run(samples, n);
  81 +
  82 + float threshold = config_.ten_vad.threshold;
  83 +
  84 + current_sample_ += config_.ten_vad.window_size;
  85 +
  86 + if (prob > threshold && temp_end_ != 0) {
  87 + temp_end_ = 0;
  88 + }
  89 +
  90 + if (prob > threshold && temp_start_ == 0) {
  91 + // start speaking, but we require that it must satisfy
  92 + // min_speech_duration
  93 + temp_start_ = current_sample_;
  94 + return false;
  95 + }
  96 +
  97 + if (prob > threshold && temp_start_ != 0 && !triggered_) {
  98 + if (current_sample_ - temp_start_ < min_speech_samples_) {
  99 + return false;
  100 + }
  101 +
  102 + triggered_ = true;
  103 +
  104 + return true;
  105 + }
  106 +
  107 + if ((prob < threshold) && !triggered_) {
  108 + // silence
  109 + temp_start_ = 0;
  110 + temp_end_ = 0;
  111 + return false;
  112 + }
  113 +
  114 + if ((prob > threshold - 0.15) && triggered_) {
  115 + // speaking
  116 + return true;
  117 + }
  118 +
  119 + if ((prob > threshold) && !triggered_) {
  120 + // start speaking
  121 + triggered_ = true;
  122 +
  123 + return true;
  124 + }
  125 +
  126 + if ((prob < threshold) && triggered_) {
  127 + // stop to speak
  128 + if (temp_end_ == 0) {
  129 + temp_end_ = current_sample_;
  130 + }
  131 +
  132 + if (current_sample_ - temp_end_ < min_silence_samples_) {
  133 + // continue speaking
  134 + return true;
  135 + }
  136 + // stopped speaking
  137 + temp_start_ = 0;
  138 + temp_end_ = 0;
  139 + triggered_ = false;
  140 + return false;
  141 + }
  142 +
  143 + return false;
  144 + }
  145 +
  146 + int32_t WindowShift() const { return config_.ten_vad.window_size; }
  147 +
  148 + int32_t WindowSize() const { return config_.ten_vad.window_size; }
  149 +
  150 + int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }
  151 +
  152 + int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }
  153 +
  154 + void SetMinSilenceDuration(float s) {
  155 + min_silence_samples_ = sample_rate_ * s;
  156 + }
  157 +
  158 + void SetThreshold(float threshold) { config_.ten_vad.threshold = threshold; }
  159 +
  160 + private:
  161 + void Init(void *model_data, size_t model_data_length) {
  162 + if (sample_rate_ != 16000) {
  163 + SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
  164 + config_.sample_rate);
  165 + SHERPA_ONNX_EXIT(-1);
  166 + }
  167 +
  168 + if (config_.ten_vad.window_size > 768) {
  169 + SHERPA_ONNX_LOGE("Windows size %d for ten-vad is too large",
  170 + config_.ten_vad.window_size);
  171 + SHERPA_ONNX_EXIT(-1);
  172 + }
  173 +
  174 + min_silence_samples_ = sample_rate_ * config_.ten_vad.min_silence_duration;
  175 +
  176 + min_speech_samples_ = sample_rate_ * config_.ten_vad.min_speech_duration;
  177 +
  178 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  179 + sess_opts_);
  180 +
  181 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  182 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  183 +
  184 + InitMelBanks();
  185 +
  186 + Check();
  187 +
  188 + Reset();
  189 + }
  190 +
  191 + void ResetStates() {
  192 + std::array<int64_t, 2> shape{1, 64};
  193 +
  194 + states_.clear();
  195 + states_.reserve(4);
  196 + for (int32_t i = 0; i != 4; ++i) {
  197 + Ort::Value s = Ort::Value::CreateTensor<float>(allocator_, shape.data(),
  198 + shape.size());
  199 +
  200 + Fill<float>(&s, 0);
  201 + states_.push_back(std::move(s));
  202 + }
  203 + }
  204 +
  205 + void InitMelBanks() {
  206 + knf::FrameExtractionOptions frame_opts;
  207 +
  208 + // 16 kHz, so num_fft is 16000*64/1000 = 1024
  209 + frame_opts.frame_length_ms = 64;
  210 +
  211 + knf::MelBanksOptions mel_opts;
  212 + mel_opts.is_librosa = true;
  213 + mel_opts.norm = "";
  214 + mel_opts.use_slaney_mel_scale = true;
  215 + mel_opts.floor_to_int_bin = true;
  216 + mel_opts.low_freq = 0;
  217 + mel_opts.high_freq = 8000;
  218 + mel_opts.num_bins = 40;
  219 +
  220 + mel_banks_ = std::make_unique<knf::MelBanks>(mel_opts, frame_opts, 1.0f);
  221 +
  222 + features_.resize(41);
  223 + }
  224 +
  225 + void Check() {
  226 + // get meta data
  227 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  228 + if (config_.debug) {
  229 + std::ostringstream os;
  230 + os << "---ten-vad---\n";
  231 + PrintModelMetadata(os, meta_data);
  232 +#if __OHOS__
  233 + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
  234 +#else
  235 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  236 +#endif
  237 + }
  238 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  239 +
  240 + std::string model_type;
  241 + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(model_type, "model_type");
  242 +
  243 + if (model_type.empty()) {
  244 + SHERPA_ONNX_LOGE(
  245 + "Please download ten-vad.onnx or ten-vad.int8.onnx from\n"
  246 + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models"
  247 + "\nWe have added meta data to the original ten-vad.onnx from\n"
  248 + "https://github.com/TEN-framework/ten-vad");
  249 + SHERPA_ONNX_EXIT(-1);
  250 + }
  251 +
  252 + if (model_type != "ten-vad") {
  253 + SHERPA_ONNX_LOGE("Expect model type 'ten-vad', given '%s'",
  254 + model_type.c_str());
  255 + SHERPA_ONNX_EXIT(-1);
  256 + }
  257 +
  258 + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(mean_, "mean");
  259 + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
  260 + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(window_, "window");
  261 +
  262 + if (mean_.size() != 41) {
  263 + SHERPA_ONNX_LOGE(
  264 + "Incorrect size of the mean vector. Given %d, expected 41",
  265 + static_cast<int32_t>(mean_.size()));
  266 + SHERPA_ONNX_EXIT(-1);
  267 + }
  268 +
  269 + if (inv_stddev_.size() != 41) {
  270 + SHERPA_ONNX_LOGE(
  271 + "Incorrect size of the inv_stddev vector. Given %d, expected 41",
  272 + static_cast<int32_t>(inv_stddev_.size()));
  273 + SHERPA_ONNX_EXIT(-1);
  274 + }
  275 +
  276 + if (window_.size() != 768) {
  277 + SHERPA_ONNX_LOGE(
  278 + "Incorrect size of the window vector. Given %d, expected 768",
  279 + static_cast<int32_t>(window_.size()));
  280 + SHERPA_ONNX_EXIT(-1);
  281 + }
  282 + }
  283 +
  284 + static void Scale(const float *samples, int32_t n, float *out) {
  285 + for (int32_t i = 0; i != n; ++i) {
  286 + out[i] = samples[i] * 32768;
  287 + }
  288 + }
  289 +
  290 + void Preemphasis(const float *samples, int32_t n, float *out) {
  291 + float t = samples[n - 1];
  292 +
  293 + for (int32_t i = n - 1; i > 0; --i) {
  294 + out[i] = samples[i] - 0.97 * samples[i - 1];
  295 + }
  296 +
  297 + out[0] = samples[0] - 0.97 * last_sample_;
  298 +
  299 + last_sample_ = t;
  300 + }
  301 +
  302 + static void ApplyWindow(const float *samples, const float *window, int32_t n,
  303 + float *out) {
  304 + for (int32_t i = 0; i != n; ++i) {
  305 + out[i] = samples[i] * window[i];
  306 + }
  307 + }
  308 +
  309 + static void ComputePowerSpectrum(const float *fft_bins, int32_t n,
  310 + float *out) {
  311 + out[0] = fft_bins[0] * fft_bins[0];
  312 + out[n - 1] = fft_bins[1] * fft_bins[1];
  313 +
  314 + for (int32_t i = 1; i < n / 2; ++i) {
  315 + float real = fft_bins[2 * i];
  316 + float imag = fft_bins[2 * i + 1];
  317 + out[i] = real * real + imag * imag;
  318 + }
  319 + }
  320 +
  321 + static void LogMel(const float *in, int32_t n, float *out) {
  322 + for (int32_t i = 0; i != n; ++i) {
  323 + // 20.79441541679836 is log(32768*32768)
  324 + out[i] = std::logf(in[i] + 1e-10) - 20.79441541679836f;
  325 + }
  326 + }
  327 +
  328 + void ApplyNormalization(const float *in, float *out) const {
  329 + for (int32_t i = 0; i != static_cast<int32_t>(mean_.size()); ++i) {
  330 + out[i] = (in[i] - mean_[i]) * inv_stddev_[i];
  331 + }
  332 + }
  333 +
  334 + void ComputeFeatures(const float *samples, int32_t n) {
  335 + std::fill(tmp_samples_.begin() + n, tmp_samples_.end(), 0.0f);
  336 +
  337 + Scale(samples, n, tmp_samples_.data());
  338 +
  339 + Preemphasis(tmp_samples_.data(), n, tmp_samples_.data());
  340 + ApplyWindow(tmp_samples_.data(), window_.data(), n, tmp_samples_.data());
  341 +
  342 + rfft_.Compute(tmp_samples_.data());
  343 + auto &power_spectrum = tmp_samples_;
  344 + ComputePowerSpectrum(tmp_samples_.data(), tmp_samples_.size(),
  345 + power_spectrum.data());
  346 +
  347 + // note only the first half of power_spectrum is used inside Compute()
  348 + mel_banks_->Compute(power_spectrum.data(), features_.data());
  349 + LogMel(features_.data(), static_cast<int32_t>(features_.size()) - 1,
  350 + features_.data());
  351 +
  352 + // Note(fangjun): The ten-vad model expects a pitch feature, but we set it
  353 + // to 0 as a simplification. This may reduce performance as noted
  354 + // in the PR #2377
  355 + features_.back() = 0;
  356 +
  357 + ApplyNormalization(features_.data(), features_.data());
  358 +
  359 + std::memmove(last_features_.data(),
  360 + last_features_.data() + features_.size(),
  361 + 2 * features_.size() * sizeof(float));
  362 + std::copy(features_.begin(), features_.end(),
  363 + last_features_.begin() + 2 * features_.size());
  364 + }
  365 +
  366 + float Run(const float *samples, int32_t n) {
  367 + ComputeFeatures(samples, n);
  368 +
  369 + auto memory_info =
  370 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  371 +
  372 + std::array<int64_t, 3> x_shape = {1, 3, 41};
  373 +
  374 + Ort::Value x = Ort::Value::CreateTensor(memory_info, last_features_.data(),
  375 + last_features_.size(),
  376 + x_shape.data(), x_shape.size());
  377 +
  378 + std::vector<Ort::Value> inputs;
  379 + inputs.reserve(input_names_.size());
  380 +
  381 + inputs.push_back(std::move(x));
  382 + for (auto &s : states_) {
  383 + inputs.push_back(std::move(s));
  384 + }
  385 +
  386 + auto out =
  387 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  388 + output_names_ptr_.data(), output_names_ptr_.size());
  389 +
  390 + for (int32_t i = 1; i != static_cast<int32_t>(output_names_.size()); ++i) {
  391 + states_[i - 1] = std::move(out[i]);
  392 + }
  393 +
  394 + float prob = out[0].GetTensorData<float>()[0];
  395 +
  396 + return prob;
  397 + }
  398 +
  399 + private:
  400 + VadModelConfig config_;
  401 + knf::Rfft rfft_;
  402 + std::unique_ptr<knf::MelBanks> mel_banks_;
  403 +
  404 + Ort::Env env_;
  405 + Ort::SessionOptions sess_opts_;
  406 + Ort::AllocatorWithDefaultOptions allocator_;
  407 +
  408 + std::unique_ptr<Ort::Session> sess_;
  409 +
  410 + std::vector<std::string> input_names_;
  411 + std::vector<const char *> input_names_ptr_;
  412 +
  413 + std::vector<std::string> output_names_;
  414 + std::vector<const char *> output_names_ptr_;
  415 +
  416 + std::vector<Ort::Value> states_;
  417 + int64_t sample_rate_;
  418 + int32_t min_silence_samples_;
  419 + int32_t min_speech_samples_;
  420 +
  421 + bool triggered_ = false;
  422 + int32_t current_sample_ = 0;
  423 + int32_t temp_start_ = 0;
  424 + int32_t temp_end_ = 0;
  425 +
  426 + float last_sample_ = 0;
  427 +
  428 + std::vector<float> mean_;
  429 + std::vector<float> inv_stddev_;
  430 + std::vector<float> window_;
  431 +
  432 + std::vector<float> features_;
  433 + std::vector<float> last_features_; // (3, 41), row major
  434 + std::vector<float> tmp_samples_; // (1024,)
  435 +};
  436 +
  437 +TenVadModel::TenVadModel(const VadModelConfig &config)
  438 + : impl_(std::make_unique<Impl>(config)) {}
  439 +
  440 +template <typename Manager>
  441 +TenVadModel::TenVadModel(Manager *mgr, const VadModelConfig &config)
  442 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  443 +
  444 +TenVadModel::~TenVadModel() = default;
  445 +
  446 +void TenVadModel::Reset() { return impl_->Reset(); }
  447 +
  448 +bool TenVadModel::IsSpeech(const float *samples, int32_t n) {
  449 + return impl_->IsSpeech(samples, n);
  450 +}
  451 +
  452 +int32_t TenVadModel::WindowSize() const { return impl_->WindowSize(); }
  453 +
  454 +int32_t TenVadModel::WindowShift() const { return impl_->WindowShift(); }
  455 +
  456 +int32_t TenVadModel::MinSilenceDurationSamples() const {
  457 + return impl_->MinSilenceDurationSamples();
  458 +}
  459 +
  460 +int32_t TenVadModel::MinSpeechDurationSamples() const {
  461 + return impl_->MinSpeechDurationSamples();
  462 +}
  463 +
  464 +void TenVadModel::SetMinSilenceDuration(float s) {
  465 + impl_->SetMinSilenceDuration(s);
  466 +}
  467 +
  468 +void TenVadModel::SetThreshold(float threshold) {
  469 + impl_->SetThreshold(threshold);
  470 +}
  471 +
  472 +#if __ANDROID_API__ >= 9
  473 +template TenVadModel::TenVadModel(AAssetManager *mgr,
  474 + const VadModelConfig &config);
  475 +#endif
  476 +
  477 +#if __OHOS__
  478 +template TenVadModel::TenVadModel(NativeResourceManager *mgr,
  479 + const VadModelConfig &config);
  480 +#endif
  481 +
  482 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/ten-vad-model.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_TEN_VAD_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_TEN_VAD_MODEL_H_
  6 +
  7 +#include <memory>
  8 +
  9 +#include "sherpa-onnx/csrc/vad-model.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +class TenVadModel : public VadModel {
  14 + public:
  15 + explicit TenVadModel(const VadModelConfig &config);
  16 +
  17 + template <typename Manager>
  18 + TenVadModel(Manager *mgr, const VadModelConfig &config);
  19 +
  20 + ~TenVadModel() override;
  21 +
  22 + // reset the internal model states
  23 + void Reset() override;
  24 +
  25 + /**
  26 + * @param samples Pointer to a 1-d array containing audio samples.
  27 + * Each sample should be normalized to the range [-1, 1].
  28 + * @param n Number of samples.
  29 + *
  30 + * @return Return true if speech is detected. Return false otherwise.
  31 + */
  32 + bool IsSpeech(const float *samples, int32_t n) override;
  33 +
  34 + // 256 or 160
  35 + int32_t WindowSize() const override;
  36 +
  37 + // 256 or 128
  38 + int32_t WindowShift() const override;
  39 +
  40 + int32_t MinSilenceDurationSamples() const override;
  41 + int32_t MinSpeechDurationSamples() const override;
  42 +
  43 + void SetMinSilenceDuration(float s) override;
  44 + void SetThreshold(float threshold) override;
  45 +
  46 + private:
  47 + class Impl;
  48 + std::unique_ptr<Impl> impl_;
  49 +};
  50 +
  51 +} // namespace sherpa_onnx
  52 +
  53 +#endif // SHERPA_ONNX_CSRC_TEN_VAD_MODEL_H_
@@ -10,9 +10,9 @@ namespace sherpa_onnx { @@ -10,9 +10,9 @@ namespace sherpa_onnx {
10 /** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). 10 /** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
11 * 11 *
12 * @param allocator 12 * @param allocator
13 - * @param v A 3-D tensor of shape (B, T, C). Its dataype is type. 13 + * @param v A 3-D tensor of shape (B, T, C). Its data type is type.
14 * 14 *
15 - * @return Return a 3-D tensor of shape (T, B, C). Its datatype is type. 15 + * @return Return a 3-D tensor of shape (T, B, C). Its data type is type.
16 */ 16 */
17 template <typename type = float> 17 template <typename type = float>
18 Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); 18 Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
@@ -20,9 +20,9 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); @@ -20,9 +20,9 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
20 /** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T). 20 /** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T).
21 * 21 *
22 * @param allocator 22 * @param allocator
23 - * @param v A 3-D tensor of shape (B, T, C). Its dataype is type. 23 + * @param v A 3-D tensor of shape (B, T, C). Its data type is type.
24 * 24 *
25 - * @return Return a 3-D tensor of shape (B, C, T). Its datatype is type. 25 + * @return Return a 3-D tensor of shape (B, C, T). Its data type is type.
26 */ 26 */
27 template <typename type = float> 27 template <typename type = float>
28 Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v); 28 Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v);
@@ -14,6 +14,7 @@ namespace sherpa_onnx { @@ -14,6 +14,7 @@ namespace sherpa_onnx {
14 14
15 void VadModelConfig::Register(ParseOptions *po) { 15 void VadModelConfig::Register(ParseOptions *po) {
16 silero_vad.Register(po); 16 silero_vad.Register(po);
  17 + ten_vad.Register(po);
17 18
18 po->Register("vad-sample-rate", &sample_rate, 19 po->Register("vad-sample-rate", &sample_rate,
19 "Sample rate expected by the VAD model"); 20 "Sample rate expected by the VAD model");
@@ -48,7 +49,17 @@ bool VadModelConfig::Validate() const { @@ -48,7 +49,17 @@ bool VadModelConfig::Validate() const {
48 } 49 }
49 } 50 }
50 51
51 - return silero_vad.Validate(); 52 + if (!silero_vad.model.empty()) {
  53 + return silero_vad.Validate();
  54 + }
  55 +
  56 + if (!ten_vad.model.empty()) {
  57 + return ten_vad.Validate();
  58 + }
  59 +
  60 + SHERPA_ONNX_LOGE("Please provide one VAD model.");
  61 +
  62 + return false;
52 } 63 }
53 64
54 std::string VadModelConfig::ToString() const { 65 std::string VadModelConfig::ToString() const {
@@ -56,6 +67,7 @@ std::string VadModelConfig::ToString() const { @@ -56,6 +67,7 @@ std::string VadModelConfig::ToString() const {
56 67
57 os << "VadModelConfig("; 68 os << "VadModelConfig(";
58 os << "silero_vad=" << silero_vad.ToString() << ", "; 69 os << "silero_vad=" << silero_vad.ToString() << ", ";
  70 + os << "ten_vad=" << ten_vad.ToString() << ", ";
59 os << "sample_rate=" << sample_rate << ", "; 71 os << "sample_rate=" << sample_rate << ", ";
60 os << "num_threads=" << num_threads << ", "; 72 os << "num_threads=" << num_threads << ", ";
61 os << "provider=\"" << provider << "\", "; 73 os << "provider=\"" << provider << "\", ";
@@ -8,11 +8,13 @@ @@ -8,11 +8,13 @@
8 8
9 #include "sherpa-onnx/csrc/parse-options.h" 9 #include "sherpa-onnx/csrc/parse-options.h"
10 #include "sherpa-onnx/csrc/silero-vad-model-config.h" 10 #include "sherpa-onnx/csrc/silero-vad-model-config.h"
  11 +#include "sherpa-onnx/csrc/ten-vad-model-config.h"
11 12
12 namespace sherpa_onnx { 13 namespace sherpa_onnx {
13 14
14 struct VadModelConfig { 15 struct VadModelConfig {
15 SileroVadModelConfig silero_vad; 16 SileroVadModelConfig silero_vad;
  17 + TenVadModelConfig ten_vad;
16 18
17 int32_t sample_rate = 16000; 19 int32_t sample_rate = 16000;
18 int32_t num_threads = 1; 20 int32_t num_threads = 1;
@@ -23,9 +25,11 @@ struct VadModelConfig { @@ -23,9 +25,11 @@ struct VadModelConfig {
23 25
24 VadModelConfig() = default; 26 VadModelConfig() = default;
25 27
26 - VadModelConfig(const SileroVadModelConfig &silero_vad, int32_t sample_rate, 28 + VadModelConfig(const SileroVadModelConfig &silero_vad,
  29 + const TenVadModelConfig &ten_vad, int32_t sample_rate,
27 int32_t num_threads, const std::string &provider, bool debug) 30 int32_t num_threads, const std::string &provider, bool debug)
28 : silero_vad(silero_vad), 31 : silero_vad(silero_vad),
  32 + ten_vad(ten_vad),
29 sample_rate(sample_rate), 33 sample_rate(sample_rate),
30 num_threads(num_threads), 34 num_threads(num_threads),
31 provider(provider), 35 provider(provider),
@@ -19,13 +19,19 @@ @@ -19,13 +19,19 @@
19 19
20 #include "sherpa-onnx/csrc/macros.h" 20 #include "sherpa-onnx/csrc/macros.h"
21 #include "sherpa-onnx/csrc/silero-vad-model.h" 21 #include "sherpa-onnx/csrc/silero-vad-model.h"
  22 +#include "sherpa-onnx/csrc/ten-vad-model.h"
22 23
23 namespace sherpa_onnx { 24 namespace sherpa_onnx {
24 25
25 std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) { 26 std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
26 if (config.provider == "rknn") { 27 if (config.provider == "rknn") {
27 #if SHERPA_ONNX_ENABLE_RKNN 28 #if SHERPA_ONNX_ENABLE_RKNN
28 - return std::make_unique<SileroVadModelRknn>(config); 29 + if (!config.silero_vad.model.empty()) {
  30 + return std::make_unique<SileroVadModelRknn>(config);
  31 + } else {
  32 + SHERPA_ONNX_LOGE("Only silero-vad is supported for RKNN at present");
  33 + SHERPA_ONNX_EXIT(-1);
  34 + }
29 #else 35 #else
30 SHERPA_ONNX_LOGE( 36 SHERPA_ONNX_LOGE(
31 "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " 37 "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
@@ -34,7 +40,17 @@ std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) { @@ -34,7 +40,17 @@ std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
34 return nullptr; 40 return nullptr;
35 #endif 41 #endif
36 } 42 }
37 - return std::make_unique<SileroVadModel>(config); 43 +
  44 + if (!config.silero_vad.model.empty()) {
  45 + return std::make_unique<SileroVadModel>(config);
  46 + }
  47 +
  48 + if (!config.ten_vad.model.empty()) {
  49 + return std::make_unique<TenVadModel>(config);
  50 + }
  51 +
  52 + SHERPA_ONNX_LOGE("Please provide a vad model");
  53 + return nullptr;
38 } 54 }
39 55
40 template <typename Manager> 56 template <typename Manager>
@@ -42,7 +58,12 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr, @@ -42,7 +58,12 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr,
42 const VadModelConfig &config) { 58 const VadModelConfig &config) {
43 if (config.provider == "rknn") { 59 if (config.provider == "rknn") {
44 #if SHERPA_ONNX_ENABLE_RKNN 60 #if SHERPA_ONNX_ENABLE_RKNN
45 - return std::make_unique<SileroVadModelRknn>(mgr, config); 61 + if (!config.silero_vad.model.empty()) {
  62 + return std::make_unique<SileroVadModelRknn>(mgr, config);
  63 + } else {
  64 + SHERPA_ONNX_LOGE("Only silero-vad is supported for RKNN at present");
  65 + SHERPA_ONNX_EXIT(-1);
  66 + }
46 #else 67 #else
47 SHERPA_ONNX_LOGE( 68 SHERPA_ONNX_LOGE(
48 "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " 69 "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
@@ -51,7 +72,16 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr, @@ -51,7 +72,16 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr,
51 return nullptr; 72 return nullptr;
52 #endif 73 #endif
53 } 74 }
54 - return std::make_unique<SileroVadModel>(mgr, config); 75 + if (!config.silero_vad.model.empty()) {
  76 + return std::make_unique<SileroVadModel>(mgr, config);
  77 + }
  78 +
  79 + if (!config.ten_vad.model.empty()) {
  80 + return std::make_unique<TenVadModel>(mgr, config);
  81 + }
  82 +
  83 + SHERPA_ONNX_LOGE("Please provide a vad model");
  84 + return nullptr;
55 } 85 }
56 86
57 #if __ANDROID_API__ >= 9 87 #if __ANDROID_API__ >= 9
@@ -18,6 +18,7 @@ @@ -18,6 +18,7 @@
18 #endif 18 #endif
19 19
20 #include "sherpa-onnx/csrc/circular-buffer.h" 20 #include "sherpa-onnx/csrc/circular-buffer.h"
  21 +#include "sherpa-onnx/csrc/macros.h"
21 #include "sherpa-onnx/csrc/vad-model.h" 22 #include "sherpa-onnx/csrc/vad-model.h"
22 23
23 namespace sherpa_onnx { 24 namespace sherpa_onnx {
@@ -45,8 +46,16 @@ class VoiceActivityDetector::Impl { @@ -45,8 +46,16 @@ class VoiceActivityDetector::Impl {
45 model_->SetMinSilenceDuration(new_min_silence_duration_s_); 46 model_->SetMinSilenceDuration(new_min_silence_duration_s_);
46 model_->SetThreshold(new_threshold_); 47 model_->SetThreshold(new_threshold_);
47 } else { 48 } else {
48 - model_->SetMinSilenceDuration(config_.silero_vad.min_silence_duration);  
49 - model_->SetThreshold(config_.silero_vad.threshold); 49 + if (!config_.silero_vad.model.empty()) {
  50 + model_->SetMinSilenceDuration(config_.silero_vad.min_silence_duration);
  51 + model_->SetThreshold(config_.silero_vad.threshold);
  52 + } else if (!config_.ten_vad.model.empty()) {
  53 + model_->SetMinSilenceDuration(config_.ten_vad.min_silence_duration);
  54 + model_->SetThreshold(config_.ten_vad.threshold);
  55 + } else {
  56 + SHERPA_ONNX_LOGE("Unknown vad model");
  57 + SHERPA_ONNX_EXIT(-1);
  58 + }
50 } 59 }
51 60
52 int32_t window_size = model_->WindowSize(); 61 int32_t window_size = model_->WindowSize();
@@ -160,11 +169,16 @@ class VoiceActivityDetector::Impl { @@ -160,11 +169,16 @@ class VoiceActivityDetector::Impl {
160 169
161 private: 170 private:
  void Init() {
    // Compute the maximum utterance length in samples from whichever VAD
    // model is configured. Note the precedence: if both model paths are
    // set, silero-vad wins (this mirrors the selection in VadModel::Create).
    if (!config_.silero_vad.model.empty()) {
      max_utterance_length_ =
          config_.sample_rate * config_.silero_vad.max_speech_duration;
    } else if (!config_.ten_vad.model.empty()) {
      max_utterance_length_ =
          config_.sample_rate * config_.ten_vad.max_speech_duration;
    } else {
      // Neither model was provided; config validation should normally
      // catch this earlier, so treat it as fatal here.
      SHERPA_ONNX_LOGE("Unsupported VAD model");
      SHERPA_ONNX_EXIT(-1);
    }
  }
169 183
170 private: 184 private:
@@ -51,6 +51,7 @@ set(srcs @@ -51,6 +51,7 @@ set(srcs
51 speaker-embedding-extractor.cc 51 speaker-embedding-extractor.cc
52 speaker-embedding-manager.cc 52 speaker-embedding-manager.cc
53 spoken-language-identification.cc 53 spoken-language-identification.cc
  54 + ten-vad-model-config.cc
54 tensorrt-config.cc 55 tensorrt-config.cc
55 vad-model-config.cc 56 vad-model-config.cc
56 vad-model.cc 57 vad-model.cc
  1 +// sherpa-onnx/python/csrc/ten-vad-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/ten-vad-model-config.h"
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/ten-vad-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindTenVadModelConfig(py::module *m) {
  15 + using PyClass = TenVadModelConfig;
  16 + py::class_<PyClass>(*m, "TenVadModelConfig")
  17 + .def(py::init<>())
  18 + .def(py::init([](const std::string &model, float threshold,
  19 + float min_silence_duration, float min_speech_duration,
  20 + int32_t window_size,
  21 + float max_speech_duration) -> std::unique_ptr<PyClass> {
  22 + auto ans = std::make_unique<PyClass>();
  23 +
  24 + ans->model = model;
  25 + ans->threshold = threshold;
  26 + ans->min_silence_duration = min_silence_duration;
  27 + ans->min_speech_duration = min_speech_duration;
  28 + ans->window_size = window_size;
  29 + ans->max_speech_duration = max_speech_duration;
  30 +
  31 + return ans;
  32 + }),
  33 + py::arg("model"), py::arg("threshold") = 0.5,
  34 + py::arg("min_silence_duration") = 0.5,
  35 + py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 256,
  36 + py::arg("max_speech_duration") = 20)
  37 + .def_readwrite("model", &PyClass::model)
  38 + .def_readwrite("threshold", &PyClass::threshold)
  39 + .def_readwrite("min_silence_duration", &PyClass::min_silence_duration)
  40 + .def_readwrite("min_speech_duration", &PyClass::min_speech_duration)
  41 + .def_readwrite("window_size", &PyClass::window_size)
  42 + .def_readwrite("max_speech_duration", &PyClass::max_speech_duration)
  43 + .def("__str__", &PyClass::ToString)
  44 + .def("validate", &PyClass::Validate);
  45 +}
  46 +
  47 +} // namespace sherpa_onnx
// sherpa-onnx/python/csrc/ten-vad-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_TEN_VAD_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_TEN_VAD_MODEL_CONFIG_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

// Registers the TenVadModelConfig Python binding on module *m.
void PybindTenVadModelConfig(py::module *m);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_PYTHON_CSRC_TEN_VAD_MODEL_CONFIG_H_
@@ -8,21 +8,25 @@ @@ -8,21 +8,25 @@
8 8
9 #include "sherpa-onnx/csrc/vad-model-config.h" 9 #include "sherpa-onnx/csrc/vad-model-config.h"
10 #include "sherpa-onnx/python/csrc/silero-vad-model-config.h" 10 #include "sherpa-onnx/python/csrc/silero-vad-model-config.h"
  11 +#include "sherpa-onnx/python/csrc/ten-vad-model-config.h"
11 12
12 namespace sherpa_onnx { 13 namespace sherpa_onnx {
13 14
14 void PybindVadModelConfig(py::module *m) { 15 void PybindVadModelConfig(py::module *m) {
15 PybindSileroVadModelConfig(m); 16 PybindSileroVadModelConfig(m);
  17 + PybindTenVadModelConfig(m);
16 18
17 using PyClass = VadModelConfig; 19 using PyClass = VadModelConfig;
18 py::class_<PyClass>(*m, "VadModelConfig") 20 py::class_<PyClass>(*m, "VadModelConfig")
19 .def(py::init<>()) 21 .def(py::init<>())
20 - .def(py::init<const SileroVadModelConfig &, int32_t, int32_t,  
21 - const std::string &, bool>(),  
22 - py::arg("silero_vad"), py::arg("sample_rate") = 16000,  
23 - py::arg("num_threads") = 1, py::arg("provider") = "cpu",  
24 - py::arg("debug") = false) 22 + .def(py::init<const SileroVadModelConfig &, const TenVadModelConfig &,
  23 + int32_t, int32_t, const std::string &, bool>(),
  24 + py::arg("silero_vad") = SileroVadModelConfig{},
  25 + py::arg("ten_vad") = TenVadModelConfig{},
  26 + py::arg("sample_rate") = 16000, py::arg("num_threads") = 1,
  27 + py::arg("provider") = "cpu", py::arg("debug") = false)
25 .def_readwrite("silero_vad", &PyClass::silero_vad) 28 .def_readwrite("silero_vad", &PyClass::silero_vad)
  29 + .def_readwrite("ten_vad", &PyClass::ten_vad)
26 .def_readwrite("sample_rate", &PyClass::sample_rate) 30 .def_readwrite("sample_rate", &PyClass::sample_rate)
27 .def_readwrite("num_threads", &PyClass::num_threads) 31 .def_readwrite("num_threads", &PyClass::num_threads)
28 .def_readwrite("provider", &PyClass::provider) 32 .def_readwrite("provider", &PyClass::provider)
@@ -64,6 +64,7 @@ from _sherpa_onnx import ( @@ -64,6 +64,7 @@ from _sherpa_onnx import (
64 SpokenLanguageIdentification, 64 SpokenLanguageIdentification,
65 SpokenLanguageIdentificationConfig, 65 SpokenLanguageIdentificationConfig,
66 SpokenLanguageIdentificationWhisperConfig, 66 SpokenLanguageIdentificationWhisperConfig,
  67 + TenVadModelConfig,
67 VadModel, 68 VadModel,
68 VadModelConfig, 69 VadModelConfig,
69 VoiceActivityDetector, 70 VoiceActivityDetector,