Fangjun Kuang
Committed by GitHub

Add C++ and Python support for ten-vad (#2377)

This PR adds support for the TEN VAD model alongside the existing Silero VAD in both C++ and Python interfaces.

- Introduces TenVadModelConfig with Python bindings and integrates it into VadModelConfig.
- Implements TenVadModel in C++ and extends the factory (VadModel::Create) and detector logic to choose between Silero and TEN VAD.
- Updates build files (CMake), fixes a spelling typo, and extends the Python example script to demonstrate --ten-vad-model.
function(download_kaldi_native_fbank)
include(FetchContent)
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.2.tar.gz")
set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.2.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=f4bd7d53fe8aeaecc4eda9680c72696bb86bf74e86371d81aacacd6f4ca3914d")
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.3.tar.gz")
set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.3.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=d409eddae5a46dc796f0841880f489ff0728b96ae26218702cd438c28667c70e")
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
... ... @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.21.2.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.2.tar.gz
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.2.tar.gz
/tmp/kaldi-native-fbank-1.21.2.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.21.2.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.21.3.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.3.tar.gz
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.3.tar.gz
/tmp/kaldi-native-fbank-1.21.3.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.21.3.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
// cxx-api-examples/zipformer-transducer-simulate-streaming-microphone-cxx-api.cc
// Copyright (c) 2025 Xiaomi Corporation
//
// This file demonstrates how to use Zipformer transducer with sherpa-onnx's C++ API
// for streaming speech recognition from a microphone.
// This file demonstrates how to use Zipformer transducer with sherpa-onnx's C++
// API for streaming speech recognition from a microphone.
//
// clang-format off
//
... ...
... ... @@ -19,6 +19,12 @@ For instance,
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
or download ten-vad.onnx, for instance
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx
Please replace --silero-vad-model with --ten-vad-model below to use ten-vad.
(1) For paraformer
./python-api-examples/generate-subtitles.py \
... ... @@ -124,8 +130,13 @@ def get_args():
parser.add_argument(
"--silero-vad-model",
type=str,
required=True,
help="Path to silero_vad.onnx",
help="Path to silero_vad.onnx.",
)
parser.add_argument(
"--ten-vad-model",
type=str,
help="Path to ten-vad.onnx",
)
parser.add_argument(
... ... @@ -499,7 +510,12 @@ class Segment:
def main():
args = get_args()
assert_file_exists(args.tokens)
if args.silero_vad_model:
assert_file_exists(args.silero_vad_model)
elif args.ten_vad_model:
assert_file_exists(args.ten_vad_model)
else:
raise ValueError("You need to supply one vad model")
assert args.num_threads > 0, args.num_threads
... ... @@ -536,8 +552,9 @@ def main():
stream = recognizer.create_stream()
config = sherpa_onnx.VadModelConfig()
if args.silero_vad_model:
config.silero_vad.model = args.silero_vad_model
config.silero_vad.threshold = 0.5
config.silero_vad.threshold = 0.2
config.silero_vad.min_silence_duration = 0.25 # seconds
config.silero_vad.min_speech_duration = 0.25 # seconds
... ... @@ -548,6 +565,21 @@ def main():
config.sample_rate = args.sample_rate
window_size = config.silero_vad.window_size
print("use silero-vad")
else:
config.ten_vad.model = args.ten_vad_model
config.ten_vad.threshold = 0.2
config.ten_vad.min_silence_duration = 0.25 # seconds
config.ten_vad.min_speech_duration = 0.25 # seconds
# If the current segment is larger than this value, then it increases
# the threshold to 0.9 internally. After detecting this segment,
# it resets the threshold to its original value.
config.ten_vad.max_speech_duration = 5 # seconds
config.sample_rate = args.sample_rate
window_size = config.ten_vad.window_size
print("use ten-vad")
buffer = []
vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=100)
... ...
... ... @@ -123,6 +123,8 @@ set(sources
spoken-language-identification.cc
stack.cc
symbol-table.cc
ten-vad-model-config.cc
ten-vad-model.cc
text-utils.cc
transducer-keyword-decoder.cc
transpose.cc
... ...
... ... @@ -40,7 +40,7 @@ void SileroVadModelConfig::Register(ParseOptions *po) {
"to the silero VAD model. WARNING! Silero VAD models were trained using "
"512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples "
"for 8000 sample rate. Values other than these may affect model "
"perfomance!");
"performance!");
}
bool SileroVadModelConfig::Validate() const {
... ...
... ... @@ -24,7 +24,6 @@ struct SileroVadModelConfig {
float min_speech_duration = 0.25; // in seconds
// 512, 1024, 1536 samples for 16000 Hz
// 256, 512, 768 samples for 8000 Hz
int32_t window_size = 512; // in samples
// If a speech segment is longer than this value, then we increase
... ...
// sherpa-onnx/csrc/ten-vad-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/ten-vad-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Registers all --ten-vad-* command-line options with the option parser.
//
// Fixes vs. previous revision (help text only, no behavior change in
// parsing): the threshold help contained the garbled phrase "but lazy 0.5",
// and the --ten-vad-min-speech-duration help described silence instead of
// speech (it was copied from the min-silence text). The new wording matches
// the actual semantics in TenVadModel::Impl::IsSpeech, where a segment must
// last at least min_speech_duration before it is reported as speech.
void TenVadModelConfig::Register(ParseOptions *po) {
  po->Register("ten-vad-model", &model, "Path to TEN VAD ONNX model.");

  po->Register("ten-vad-threshold", &threshold,
               "Speech threshold. TEN VAD outputs speech probabilities for "
               "each audio chunk, probabilities ABOVE this value are "
               "considered as SPEECH. It is better to tune this parameter for "
               "each dataset separately, but the default "
               "0.5 is pretty good for most datasets.");

  po->Register("ten-vad-min-silence-duration", &min_silence_duration,
               "In seconds. In the end of each speech chunk wait for "
               "--ten-vad-min-silence-duration seconds before separating it");

  po->Register("ten-vad-min-speech-duration", &min_speech_duration,
               "In seconds. A speech region must last at least "
               "--ten-vad-min-speech-duration seconds before it is "
               "reported as speech");

  po->Register(
      "ten-vad-max-speech-duration", &max_speech_duration,
      "In seconds. If a speech segment is longer than this value, then we "
      "increase the threshold to 0.9. After finishing detecting the segment, "
      "the threshold value is reset to its original value.");

  po->Register(
      "ten-vad-window-size", &window_size,
      "In samples. Audio chunks of --ten-vad-window-size samples are fed "
      "to the ten VAD model. WARNING! Please use 160 or 256 ");
}
// Validates the configuration.
//
// Checks, in order: the model path is non-empty and exists on disk, the
// threshold lies in [0.01, 1), and each of the three durations is strictly
// positive. Logs an error and returns false on the first failing check;
// returns true when everything is acceptable.
bool TenVadModelConfig::Validate() const {
  if (model.empty()) {
    SHERPA_ONNX_LOGE("Please provide --ten-vad-model");
    return false;
  }

  if (!FileExists(model)) {
    SHERPA_ONNX_LOGE("TEN vad model file '%s' does not exist", model.c_str());
    return false;
  }

  // Threshold must be a usable probability cutoff: not (near) zero ...
  if (threshold < 0.01) {
    SHERPA_ONNX_LOGE(
        "Please use a larger value for --ten-vad-threshold. Given: %f",
        threshold);
    return false;
  }

  // ... and strictly below 1, otherwise no chunk can ever be speech.
  if (threshold >= 1) {
    SHERPA_ONNX_LOGE(
        "Please use a smaller value for --ten-vad-threshold. Given: %f",
        threshold);
    return false;
  }

  if (min_silence_duration <= 0) {
    SHERPA_ONNX_LOGE(
        "Please use a larger value for --ten-vad-min-silence-duration. "
        "Given: "
        "%f",
        min_silence_duration);
    return false;
  }

  if (min_speech_duration <= 0) {
    SHERPA_ONNX_LOGE(
        "Please use a larger value for --ten-vad-min-speech-duration. "
        "Given: "
        "%f",
        min_speech_duration);
    return false;
  }

  if (max_speech_duration <= 0) {
    SHERPA_ONNX_LOGE(
        "Please use a larger value for --ten-vad-max-speech-duration. "
        "Given: "
        "%f",
        max_speech_duration);
    return false;
  }

  // NOTE(review): window_size is not validated here even though the help
  // text says to use 160 or 256; TenVadModel::Impl::Init only rejects
  // values > 768 — confirm whether stricter validation is desired.
  return true;
}
// Returns a single-line, human-readable dump of every field, e.g.
// TenVadModelConfig(model="...", threshold=0.5, ..., window_size=256)
std::string TenVadModelConfig::ToString() const {
  std::ostringstream os;

  os << "TenVadModelConfig("
     << "model=\"" << model << "\", "
     << "threshold=" << threshold << ", "
     << "min_silence_duration=" << min_silence_duration << ", "
     << "min_speech_duration=" << min_speech_duration << ", "
     << "max_speech_duration=" << max_speech_duration << ", "
     << "window_size=" << window_size << ")";

  return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/ten-vad-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_TEN_VAD_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_TEN_VAD_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Configuration for the TEN VAD model.
//
// Mirrors SileroVadModelConfig; exposed on the command line via the
// --ten-vad-* options registered in Register().
struct TenVadModelConfig {
  // Path to the TEN VAD ONNX model file.
  std::string model;

  // threshold to classify a segment as speech
  //
  // If the predicted probability of a segment is larger than this
  // value, then it is classified as speech.
  float threshold = 0.5;

  // At the end of a speech chunk, wait this long before closing the segment.
  float min_silence_duration = 0.5;  // in seconds

  // A speech region shorter than this is not reported as speech.
  float min_speech_duration = 0.25;  // in seconds

  // 160 or 256
  int32_t window_size = 256;  // in samples

  // If a speech segment is longer than this value, then we increase
  // the threshold to 0.9. After finishing detecting the segment,
  // the threshold value is reset to its original value.
  float max_speech_duration = 20;  // in seconds

  TenVadModelConfig() = default;

  void Register(ParseOptions *po);

  // Returns true if the configuration is usable; logs an error otherwise.
  bool Validate() const;

  std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TEN_VAD_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/ten-vad-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/ten-vad-model.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "kaldi-native-fbank/csrc/mel-computations.h"
#include "kaldi-native-fbank/csrc/rfft.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// Pimpl implementation of TenVadModel.
//
// Per input chunk of window_size samples, ComputeFeatures() runs the
// pipeline: scale to int16 range -> pre-emphasis -> window -> 1024-point
// real FFT -> power spectrum -> 40-bin mel filterbank -> log -> pitch
// placeholder (0) -> mean/inv-stddev normalization. The last 3 feature
// frames (3 x 41) are fed to the ONNX model together with 4 recurrent
// states; the returned speech probability drives a hysteresis state
// machine in IsSpeech().
class TenVadModel::Impl {
 public:
  explicit Impl(const VadModelConfig &config)
      : config_(config),
        rfft_(1024),  // 64 ms frame at 16 kHz -> 1024-point FFT
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{},
        sample_rate_(config.sample_rate) {
    auto buf = ReadFile(config.ten_vad.model);
    Init(buf.data(), buf.size());
  }

  // Same as above, but loads the model via a platform resource manager
  // (Android AAssetManager / OHOS NativeResourceManager).
  template <typename Manager>
  Impl(Manager *mgr, const VadModelConfig &config)
      : config_(config),
        rfft_(1024),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{},
        sample_rate_(config.sample_rate) {
    auto buf = ReadFile(mgr, config.ten_vad.model);
    Init(buf.data(), buf.size());
  }

  // Resets the trigger state machine, the pre-emphasis carry-over,
  // the 3-frame feature history and the recurrent ONNX states.
  void Reset() {
    triggered_ = false;
    current_sample_ = 0;
    temp_start_ = 0;
    temp_end_ = 0;
    last_sample_ = 0;

    // Feature history: 3 frames x 41 features, zero-filled.
    last_features_.resize(3 * 41);
    std::fill(last_features_.begin(), last_features_.end(), 0.0f);

    // Scratch buffer sized for the 1024-point FFT (input is zero-padded).
    tmp_samples_.resize(1024);

    ResetStates();
  }

  // Returns true while the detector considers the stream to be inside a
  // speech segment. n must equal WindowSize(); otherwise the process exits.
  bool IsSpeech(const float *samples, int32_t n) {
    if (n != WindowSize()) {
      SHERPA_ONNX_LOGE("n: %d != window_size: %d", n, WindowSize());
      SHERPA_ONNX_EXIT(-1);
    }

    float prob = Run(samples, n);

    float threshold = config_.ten_vad.threshold;

    current_sample_ += config_.ten_vad.window_size;

    // Speech resumed before min_silence elapsed: cancel the pending end.
    if (prob > threshold && temp_end_ != 0) {
      temp_end_ = 0;
    }

    if (prob > threshold && temp_start_ == 0) {
      // start speaking, but we require that it must satisfy
      // min_speech_duration
      temp_start_ = current_sample_;
      return false;
    }

    if (prob > threshold && temp_start_ != 0 && !triggered_) {
      // Candidate speech must last at least min_speech_duration
      // before we actually trigger.
      if (current_sample_ - temp_start_ < min_speech_samples_) {
        return false;
      }

      triggered_ = true;

      return true;
    }

    if ((prob < threshold) && !triggered_) {
      // silence
      temp_start_ = 0;
      temp_end_ = 0;
      return false;
    }

    // Hysteresis: once triggered, a lower bar (threshold - 0.15)
    // keeps the segment open.
    if ((prob > threshold - 0.15) && triggered_) {
      // speaking
      return true;
    }

    if ((prob > threshold) && !triggered_) {
      // start speaking
      triggered_ = true;

      return true;
    }

    if ((prob < threshold) && triggered_) {
      // stop to speak
      if (temp_end_ == 0) {
        temp_end_ = current_sample_;
      }

      // Require min_silence_duration of silence before closing the segment.
      if (current_sample_ - temp_end_ < min_silence_samples_) {
        // continue speaking
        return true;
      }

      // stopped speaking
      temp_start_ = 0;
      temp_end_ = 0;
      triggered_ = false;
      return false;
    }

    return false;
  }

  // Frames do not overlap: shift == size.
  int32_t WindowShift() const { return config_.ten_vad.window_size; }

  int32_t WindowSize() const { return config_.ten_vad.window_size; }

  int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }

  int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }

  void SetMinSilenceDuration(float s) {
    min_silence_samples_ = sample_rate_ * s;
  }

  void SetThreshold(float threshold) { config_.ten_vad.threshold = threshold; }

 private:
  // Validates the config, creates the ONNX session, reads model metadata
  // and puts the detector into its initial state.
  void Init(void *model_data, size_t model_data_length) {
    // TEN VAD only supports 16 kHz input.
    if (sample_rate_ != 16000) {
      SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
                       config_.sample_rate);
      SHERPA_ONNX_EXIT(-1);
    }

    // The window vector from the model metadata has 768 entries, so
    // larger windows cannot be processed.
    if (config_.ten_vad.window_size > 768) {
      SHERPA_ONNX_LOGE("Windows size %d for ten-vad is too large",
                       config_.ten_vad.window_size);
      SHERPA_ONNX_EXIT(-1);
    }

    min_silence_samples_ = sample_rate_ * config_.ten_vad.min_silence_duration;
    min_speech_samples_ = sample_rate_ * config_.ten_vad.min_speech_duration;

    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    InitMelBanks();

    Check();

    Reset();
  }

  // Allocates the 4 recurrent model states, each of shape (1, 64),
  // and zero-fills them.
  void ResetStates() {
    std::array<int64_t, 2> shape{1, 64};

    states_.clear();
    states_.reserve(4);

    for (int32_t i = 0; i != 4; ++i) {
      Ort::Value s = Ort::Value::CreateTensor<float>(allocator_, shape.data(),
                                                     shape.size());
      Fill<float>(&s, 0);
      states_.push_back(std::move(s));
    }
  }

  // Builds a librosa-compatible (Slaney-scale, un-normalized) 40-bin mel
  // filterbank covering 0-8000 Hz.
  void InitMelBanks() {
    knf::FrameExtractionOptions frame_opts;
    // 16 kHz, so num_fft is 16000*64/1000 = 1024
    frame_opts.frame_length_ms = 64;

    knf::MelBanksOptions mel_opts;
    mel_opts.is_librosa = true;
    mel_opts.norm = "";
    mel_opts.use_slaney_mel_scale = true;
    mel_opts.floor_to_int_bin = true;
    mel_opts.low_freq = 0;
    mel_opts.high_freq = 8000;
    mel_opts.num_bins = 40;

    mel_banks_ = std::make_unique<knf::MelBanks>(mel_opts, frame_opts, 1.0f);

    // 40 mel bins + 1 pitch slot (the pitch slot is filled with 0).
    features_.resize(41);
  }

  // Verifies the model metadata (model_type == "ten-vad") and loads the
  // normalization vectors (mean, inv_stddev) and the analysis window.
  void Check() {
    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---ten-vad---\n";
      PrintModelMetadata(os, meta_data);
#if __OHOS__
      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below

    std::string model_type;
    SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(model_type, "model_type");
    if (model_type.empty()) {
      SHERPA_ONNX_LOGE(
          "Please download ten-vad.onnx or ten-vad.int8.onnx from\n"
          "https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models"
          "\nWe have added meta data to the original ten-vad.onnx from\n"
          "https://github.com/TEN-framework/ten-vad");
      SHERPA_ONNX_EXIT(-1);
    }

    if (model_type != "ten-vad") {
      SHERPA_ONNX_LOGE("Expect model type 'ten-vad', given '%s'",
                       model_type.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(mean_, "mean");
    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(window_, "window");

    if (mean_.size() != 41) {
      SHERPA_ONNX_LOGE(
          "Incorrect size of the mean vector. Given %d, expected 41",
          static_cast<int32_t>(mean_.size()));
      SHERPA_ONNX_EXIT(-1);
    }

    if (inv_stddev_.size() != 41) {
      SHERPA_ONNX_LOGE(
          "Incorrect size of the inv_stddev vector. Given %d, expected 41",
          static_cast<int32_t>(inv_stddev_.size()));
      SHERPA_ONNX_EXIT(-1);
    }

    if (window_.size() != 768) {
      SHERPA_ONNX_LOGE(
          "Incorrect size of the window vector. Given %d, expected 768",
          static_cast<int32_t>(window_.size()));
      SHERPA_ONNX_EXIT(-1);
    }
  }

  // Converts samples from [-1, 1] to the int16 range expected by the
  // feature pipeline. Safe for in-place use (out may alias samples).
  static void Scale(const float *samples, int32_t n, float *out) {
    for (int32_t i = 0; i != n; ++i) {
      out[i] = samples[i] * 32768;
    }
  }

  // Pre-emphasis filter y[i] = x[i] - 0.97 * x[i-1]; last_sample_ carries
  // the final sample of the previous chunk across calls. Iterating from the
  // back makes in-place use (out == samples) safe.
  void Preemphasis(const float *samples, int32_t n, float *out) {
    float t = samples[n - 1];

    for (int32_t i = n - 1; i > 0; --i) {
      out[i] = samples[i] - 0.97 * samples[i - 1];
    }

    out[0] = samples[0] - 0.97 * last_sample_;

    last_sample_ = t;
  }

  // Element-wise multiplication with the analysis window from the model
  // metadata. Safe for in-place use.
  static void ApplyWindow(const float *samples, const float *window, int32_t n,
                          float *out) {
    for (int32_t i = 0; i != n; ++i) {
      out[i] = samples[i] * window[i];
    }
  }

  // Converts packed FFT output to a power spectrum. fft_bins uses knf::Rfft's
  // packed layout: bin 0 and (presumably) the Nyquist bin share the first two
  // slots, the rest are (real, imag) pairs.
  // NOTE(review): the Nyquist power lands at out[n - 1], not at out[n/2];
  // downstream only the first half of the buffer is consumed (see
  // ComputeFeatures), so verify the filterbank never needs the Nyquist bin.
  static void ComputePowerSpectrum(const float *fft_bins, int32_t n,
                                   float *out) {
    out[0] = fft_bins[0] * fft_bins[0];
    out[n - 1] = fft_bins[1] * fft_bins[1];

    for (int32_t i = 1; i < n / 2; ++i) {
      float real = fft_bins[2 * i];
      float imag = fft_bins[2 * i + 1];
      out[i] = real * real + imag * imag;
    }
  }

  // log of the mel energies; the constant subtraction undoes the earlier
  // *32768 scaling in the log domain. Safe for in-place use.
  // NOTE(review): std::logf is not provided by every toolchain's <cmath>;
  // std::log would be the portable spelling — confirm on all target builds.
  static void LogMel(const float *in, int32_t n, float *out) {
    for (int32_t i = 0; i != n; ++i) {
      // 20.79441541679836 is log(32768*32768)
      out[i] = std::logf(in[i] + 1e-10) - 20.79441541679836f;
    }
  }

  // (x - mean) * inv_stddev, element-wise over all 41 features.
  void ApplyNormalization(const float *in, float *out) const {
    for (int32_t i = 0; i != static_cast<int32_t>(mean_.size()); ++i) {
      out[i] = (in[i] - mean_[i]) * inv_stddev_[i];
    }
  }

  // Runs the full feature pipeline on one chunk and pushes the result into
  // the 3-frame history (oldest frame is dropped).
  void ComputeFeatures(const float *samples, int32_t n) {
    // Zero-pad from n up to the 1024-sample FFT length.
    std::fill(tmp_samples_.begin() + n, tmp_samples_.end(), 0.0f);

    Scale(samples, n, tmp_samples_.data());

    Preemphasis(tmp_samples_.data(), n, tmp_samples_.data());

    ApplyWindow(tmp_samples_.data(), window_.data(), n, tmp_samples_.data());

    rfft_.Compute(tmp_samples_.data());

    auto &power_spectrum = tmp_samples_;
    ComputePowerSpectrum(tmp_samples_.data(), tmp_samples_.size(),
                         power_spectrum.data());

    // note only the first half of power_spectrum is used inside Compute()
    mel_banks_->Compute(power_spectrum.data(), features_.data());

    LogMel(features_.data(), static_cast<int32_t>(features_.size()) - 1,
           features_.data());

    // Note(fangjun): The ten-vad model expects a pitch feature, but we set it
    // to 0 as a simplification. This may reduce performance as noted
    // in the PR #2377
    features_.back() = 0;

    ApplyNormalization(features_.data(), features_.data());

    // Shift frames 1..2 to slots 0..1, then append the new frame at slot 2.
    std::memmove(last_features_.data(),
                 last_features_.data() + features_.size(),
                 2 * features_.size() * sizeof(float));

    std::copy(features_.begin(), features_.end(),
              last_features_.begin() + 2 * features_.size());
  }

  // Computes features for one chunk, runs the model, updates the recurrent
  // states and returns the speech probability.
  float Run(const float *samples, int32_t n) {
    ComputeFeatures(samples, n);

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // Input x: the 3-frame feature history, shape (1, 3, 41).
    std::array<int64_t, 3> x_shape = {1, 3, 41};
    Ort::Value x = Ort::Value::CreateTensor(memory_info, last_features_.data(),
                                            last_features_.size(),
                                            x_shape.data(), x_shape.size());

    std::vector<Ort::Value> inputs;
    inputs.reserve(input_names_.size());
    inputs.push_back(std::move(x));
    for (auto &s : states_) {
      inputs.push_back(std::move(s));
    }

    auto out =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());

    // Output 0 is the probability; outputs 1.. are the next states.
    for (int32_t i = 1; i != static_cast<int32_t>(output_names_.size()); ++i) {
      states_[i - 1] = std::move(out[i]);
    }

    float prob = out[0].GetTensorData<float>()[0];

    return prob;
  }

 private:
  VadModelConfig config_;
  knf::Rfft rfft_;
  std::unique_ptr<knf::MelBanks> mel_banks_;

  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  // Recurrent model states, re-fed on every Run().
  std::vector<Ort::Value> states_;

  int64_t sample_rate_;
  int32_t min_silence_samples_;
  int32_t min_speech_samples_;

  // --- trigger state machine ---
  bool triggered_ = false;        // currently inside a speech segment
  int32_t current_sample_ = 0;    // total samples processed so far
  int32_t temp_start_ = 0;        // candidate speech start (samples)
  int32_t temp_end_ = 0;          // candidate speech end (samples)

  // Last sample of the previous chunk, for the pre-emphasis filter.
  float last_sample_ = 0;

  // Normalization and windowing vectors from the model metadata.
  std::vector<float> mean_;
  std::vector<float> inv_stddev_;
  std::vector<float> window_;

  std::vector<float> features_;
  std::vector<float> last_features_;  // (3, 41), row major
  std::vector<float> tmp_samples_;    // (1024,)
};
// --- TenVadModel: thin forwarders to the pimpl ---

TenVadModel::TenVadModel(const VadModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

template <typename Manager>
TenVadModel::TenVadModel(Manager *mgr, const VadModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}

// Defined here so ~Impl is visible to the unique_ptr deleter.
TenVadModel::~TenVadModel() = default;

void TenVadModel::Reset() { return impl_->Reset(); }

bool TenVadModel::IsSpeech(const float *samples, int32_t n) {
  return impl_->IsSpeech(samples, n);
}

int32_t TenVadModel::WindowSize() const { return impl_->WindowSize(); }

int32_t TenVadModel::WindowShift() const { return impl_->WindowShift(); }

int32_t TenVadModel::MinSilenceDurationSamples() const {
  return impl_->MinSilenceDurationSamples();
}

int32_t TenVadModel::MinSpeechDurationSamples() const {
  return impl_->MinSpeechDurationSamples();
}

void TenVadModel::SetMinSilenceDuration(float s) {
  impl_->SetMinSilenceDuration(s);
}

void TenVadModel::SetThreshold(float threshold) {
  impl_->SetThreshold(threshold);
}

// Explicit instantiations of the resource-manager constructor for the
// platforms that need it.
#if __ANDROID_API__ >= 9
template TenVadModel::TenVadModel(AAssetManager *mgr,
                                  const VadModelConfig &config);
#endif

#if __OHOS__
template TenVadModel::TenVadModel(NativeResourceManager *mgr,
                                  const VadModelConfig &config);
#endif
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/ten-vad-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_TEN_VAD_MODEL_H_
#define SHERPA_ONNX_CSRC_TEN_VAD_MODEL_H_
#include <memory>
#include "sherpa-onnx/csrc/vad-model.h"
namespace sherpa_onnx {
// VadModel implementation backed by the TEN VAD ONNX model
// (https://github.com/TEN-framework/ten-vad).
class TenVadModel : public VadModel {
 public:
  explicit TenVadModel(const VadModelConfig &config);

  // Loads the model through a platform resource manager
  // (e.g. Android AAssetManager, OHOS NativeResourceManager).
  template <typename Manager>
  TenVadModel(Manager *mgr, const VadModelConfig &config);

  ~TenVadModel() override;

  // reset the internal model states
  void Reset() override;

  /**
   * @param samples Pointer to a 1-d array containing audio samples.
   *                Each sample should be normalized to the range [-1, 1].
   * @param n Number of samples. Must equal WindowSize().
   *
   * @return Return true if speech is detected. Return false otherwise.
   */
  bool IsSpeech(const float *samples, int32_t n) override;

  // 256 or 160
  int32_t WindowSize() const override;

  // 256 or 128
  int32_t WindowShift() const override;

  int32_t MinSilenceDurationSamples() const override;
  int32_t MinSpeechDurationSamples() const override;

  void SetMinSilenceDuration(float s) override;
  void SetThreshold(float threshold) override;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TEN_VAD_MODEL_H_
... ...
... ... @@ -10,9 +10,9 @@ namespace sherpa_onnx {
/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
*
* @param allocator
* @param v A 3-D tensor of shape (B, T, C). Its dataype is type.
* @param v A 3-D tensor of shape (B, T, C). Its data type is type.
*
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is type.
* @return Return a 3-D tensor of shape (T, B, C). Its data type is type.
*/
template <typename type = float>
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
... ... @@ -20,9 +20,9 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
/** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T).
*
* @param allocator
* @param v A 3-D tensor of shape (B, T, C). Its dataype is type.
* @param v A 3-D tensor of shape (B, T, C). Its data type is type.
*
* @return Return a 3-D tensor of shape (B, C, T). Its datatype is type.
* @return Return a 3-D tensor of shape (B, C, T). Its data type is type.
*/
template <typename type = float>
Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v);
... ...
... ... @@ -14,6 +14,7 @@ namespace sherpa_onnx {
void VadModelConfig::Register(ParseOptions *po) {
silero_vad.Register(po);
ten_vad.Register(po);
po->Register("vad-sample-rate", &sample_rate,
"Sample rate expected by the VAD model");
... ... @@ -48,7 +49,17 @@ bool VadModelConfig::Validate() const {
}
}
if (!silero_vad.model.empty()) {
return silero_vad.Validate();
}
if (!ten_vad.model.empty()) {
return ten_vad.Validate();
}
SHERPA_ONNX_LOGE("Please provide one VAD model.");
return false;
}
std::string VadModelConfig::ToString() const {
... ... @@ -56,6 +67,7 @@ std::string VadModelConfig::ToString() const {
os << "VadModelConfig(";
os << "silero_vad=" << silero_vad.ToString() << ", ";
os << "ten_vad=" << ten_vad.ToString() << ", ";
os << "sample_rate=" << sample_rate << ", ";
os << "num_threads=" << num_threads << ", ";
os << "provider=\"" << provider << "\", ";
... ...
... ... @@ -8,11 +8,13 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/silero-vad-model-config.h"
#include "sherpa-onnx/csrc/ten-vad-model-config.h"
namespace sherpa_onnx {
struct VadModelConfig {
SileroVadModelConfig silero_vad;
TenVadModelConfig ten_vad;
int32_t sample_rate = 16000;
int32_t num_threads = 1;
... ... @@ -23,9 +25,11 @@ struct VadModelConfig {
VadModelConfig() = default;
VadModelConfig(const SileroVadModelConfig &silero_vad, int32_t sample_rate,
VadModelConfig(const SileroVadModelConfig &silero_vad,
const TenVadModelConfig &ten_vad, int32_t sample_rate,
int32_t num_threads, const std::string &provider, bool debug)
: silero_vad(silero_vad),
ten_vad(ten_vad),
sample_rate(sample_rate),
num_threads(num_threads),
provider(provider),
... ...
... ... @@ -19,13 +19,19 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/silero-vad-model.h"
#include "sherpa-onnx/csrc/ten-vad-model.h"
namespace sherpa_onnx {
std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
if (config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
if (!config.silero_vad.model.empty()) {
return std::make_unique<SileroVadModelRknn>(config);
} else {
SHERPA_ONNX_LOGE("Only silero-vad is supported for RKNN at present");
SHERPA_ONNX_EXIT(-1);
}
#else
SHERPA_ONNX_LOGE(
"Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
... ... @@ -34,7 +40,17 @@ std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
return nullptr;
#endif
}
if (!config.silero_vad.model.empty()) {
return std::make_unique<SileroVadModel>(config);
}
if (!config.ten_vad.model.empty()) {
return std::make_unique<TenVadModel>(config);
}
SHERPA_ONNX_LOGE("Please provide a vad model");
return nullptr;
}
template <typename Manager>
... ... @@ -42,7 +58,12 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr,
const VadModelConfig &config) {
if (config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
if (!config.silero_vad.model.empty()) {
return std::make_unique<SileroVadModelRknn>(mgr, config);
} else {
SHERPA_ONNX_LOGE("Only silero-vad is supported for RKNN at present");
SHERPA_ONNX_EXIT(-1);
}
#else
SHERPA_ONNX_LOGE(
"Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
... ... @@ -51,7 +72,16 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr,
return nullptr;
#endif
}
if (!config.silero_vad.model.empty()) {
return std::make_unique<SileroVadModel>(mgr, config);
}
if (!config.ten_vad.model.empty()) {
return std::make_unique<TenVadModel>(mgr, config);
}
SHERPA_ONNX_LOGE("Please provide a vad model");
return nullptr;
}
#if __ANDROID_API__ >= 9
... ...
... ... @@ -18,6 +18,7 @@
#endif
#include "sherpa-onnx/csrc/circular-buffer.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/vad-model.h"
namespace sherpa_onnx {
... ... @@ -45,8 +46,16 @@ class VoiceActivityDetector::Impl {
model_->SetMinSilenceDuration(new_min_silence_duration_s_);
model_->SetThreshold(new_threshold_);
} else {
if (!config_.silero_vad.model.empty()) {
model_->SetMinSilenceDuration(config_.silero_vad.min_silence_duration);
model_->SetThreshold(config_.silero_vad.threshold);
} else if (!config_.ten_vad.model.empty()) {
model_->SetMinSilenceDuration(config_.ten_vad.min_silence_duration);
model_->SetThreshold(config_.ten_vad.threshold);
} else {
SHERPA_ONNX_LOGE("Unknown vad model");
SHERPA_ONNX_EXIT(-1);
}
}
int32_t window_size = model_->WindowSize();
... ... @@ -160,11 +169,16 @@ class VoiceActivityDetector::Impl {
private:
void Init() {
// TODO(fangjun): Currently, we support only one vad model.
// If a new vad model is added, we need to change the place
// where max_speech_duration is placed.
if (!config_.silero_vad.model.empty()) {
max_utterance_length_ =
config_.sample_rate * config_.silero_vad.max_speech_duration;
} else if (!config_.ten_vad.model.empty()) {
max_utterance_length_ =
config_.sample_rate * config_.ten_vad.max_speech_duration;
} else {
SHERPA_ONNX_LOGE("Unsupported VAD model");
SHERPA_ONNX_EXIT(-1);
}
}
private:
... ...
... ... @@ -51,6 +51,7 @@ set(srcs
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
spoken-language-identification.cc
ten-vad-model-config.cc
tensorrt-config.cc
vad-model-config.cc
vad-model.cc
... ...
// sherpa-onnx/python/csrc/ten-vad-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/ten-vad-model-config.h"
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/ten-vad-model-config.h"
namespace sherpa_onnx {
// Exposes TenVadModelConfig to Python as sherpa_onnx.TenVadModelConfig.
void PybindTenVadModelConfig(py::module *m) {
  using PyClass = TenVadModelConfig;

  py::class_<PyClass>(*m, "TenVadModelConfig")
      // Default constructor: all fields keep their C++ defaults.
      .def(py::init<>())
      // Keyword-argument constructor; defaults mirror the C++ struct.
      .def(py::init([](const std::string &model, float threshold,
                       float min_silence_duration, float min_speech_duration,
                       int32_t window_size,
                       float max_speech_duration) -> std::unique_ptr<PyClass> {
             auto config = std::make_unique<PyClass>();

             config->model = model;
             config->threshold = threshold;
             config->min_silence_duration = min_silence_duration;
             config->min_speech_duration = min_speech_duration;
             config->window_size = window_size;
             config->max_speech_duration = max_speech_duration;

             return config;
           }),
           py::arg("model"), py::arg("threshold") = 0.5,
           py::arg("min_silence_duration") = 0.5,
           py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 256,
           py::arg("max_speech_duration") = 20)
      // Read/write access to every field.
      .def_readwrite("model", &PyClass::model)
      .def_readwrite("threshold", &PyClass::threshold)
      .def_readwrite("min_silence_duration", &PyClass::min_silence_duration)
      .def_readwrite("min_speech_duration", &PyClass::min_speech_duration)
      .def_readwrite("window_size", &PyClass::window_size)
      .def_readwrite("max_speech_duration", &PyClass::max_speech_duration)
      .def("__str__", &PyClass::ToString)
      .def("validate", &PyClass::Validate);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/ten-vad-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_TEN_VAD_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_TEN_VAD_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
// Registers the TenVadModelConfig Python class with the given module.
void PybindTenVadModelConfig(py::module *m);
}  // namespace sherpa_onnx
#endif  // SHERPA_ONNX_PYTHON_CSRC_TEN_VAD_MODEL_CONFIG_H_
... ...
... ... @@ -8,21 +8,25 @@
#include "sherpa-onnx/csrc/vad-model-config.h"
#include "sherpa-onnx/python/csrc/silero-vad-model-config.h"
#include "sherpa-onnx/python/csrc/ten-vad-model-config.h"
namespace sherpa_onnx {
void PybindVadModelConfig(py::module *m) {
PybindSileroVadModelConfig(m);
PybindTenVadModelConfig(m);
using PyClass = VadModelConfig;
py::class_<PyClass>(*m, "VadModelConfig")
.def(py::init<>())
.def(py::init<const SileroVadModelConfig &, int32_t, int32_t,
const std::string &, bool>(),
py::arg("silero_vad"), py::arg("sample_rate") = 16000,
py::arg("num_threads") = 1, py::arg("provider") = "cpu",
py::arg("debug") = false)
.def(py::init<const SileroVadModelConfig &, const TenVadModelConfig &,
int32_t, int32_t, const std::string &, bool>(),
py::arg("silero_vad") = SileroVadModelConfig{},
py::arg("ten_vad") = TenVadModelConfig{},
py::arg("sample_rate") = 16000, py::arg("num_threads") = 1,
py::arg("provider") = "cpu", py::arg("debug") = false)
.def_readwrite("silero_vad", &PyClass::silero_vad)
.def_readwrite("ten_vad", &PyClass::ten_vad)
.def_readwrite("sample_rate", &PyClass::sample_rate)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("provider", &PyClass::provider)
... ...
... ... @@ -64,6 +64,7 @@ from _sherpa_onnx import (
SpokenLanguageIdentification,
SpokenLanguageIdentificationConfig,
SpokenLanguageIdentificationWhisperConfig,
TenVadModelConfig,
VadModel,
VadModelConfig,
VoiceActivityDetector,
... ...