Committed by
GitHub
Add C++ and Python support for ten-vad (#2377)
This PR adds support for the TEN VAD model alongside the existing Silero VAD in both C++ and Python interfaces. - Introduces TenVadModelConfig with Python bindings and integrates it into VadModelConfig. - Implements TenVadModel in C++ and extends the factory (VadModel::Create) and detector logic to choose between Silero and TEN VAD. - Updates build files (CMake), fixes a spelling typo, and extends the Python example script to demonstrate --ten-vad-model.
正在显示 20 个修改的文件,包含 902 行增加和 49 行删除。
| 1 | function(download_kaldi_native_fbank) | 1 | function(download_kaldi_native_fbank) |
| 2 | include(FetchContent) | 2 | include(FetchContent) |
| 3 | 3 | ||
| 4 | - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.2.tar.gz") | ||
| 5 | - set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.2.tar.gz") | ||
| 6 | - set(kaldi_native_fbank_HASH "SHA256=f4bd7d53fe8aeaecc4eda9680c72696bb86bf74e86371d81aacacd6f4ca3914d") | 4 | + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.3.tar.gz") |
| 5 | + set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.3.tar.gz") | ||
| 6 | + set(kaldi_native_fbank_HASH "SHA256=d409eddae5a46dc796f0841880f489ff0728b96ae26218702cd438c28667c70e") | ||
| 7 | 7 | ||
| 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) | 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) |
| 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) |
| @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) | @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) | ||
| 12 | # If you don't have access to the Internet, | 12 | # If you don't have access to the Internet, |
| 13 | # please pre-download kaldi-native-fbank | 13 | # please pre-download kaldi-native-fbank |
| 14 | set(possible_file_locations | 14 | set(possible_file_locations |
| 15 | - $ENV{HOME}/Downloads/kaldi-native-fbank-1.21.2.tar.gz | ||
| 16 | - ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.2.tar.gz | ||
| 17 | - ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.2.tar.gz | ||
| 18 | - /tmp/kaldi-native-fbank-1.21.2.tar.gz | ||
| 19 | - /star-fj/fangjun/download/github/kaldi-native-fbank-1.21.2.tar.gz | 15 | + $ENV{HOME}/Downloads/kaldi-native-fbank-1.21.3.tar.gz |
| 16 | + ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.3.tar.gz | ||
| 17 | + ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.3.tar.gz | ||
| 18 | + /tmp/kaldi-native-fbank-1.21.3.tar.gz | ||
| 19 | + /star-fj/fangjun/download/github/kaldi-native-fbank-1.21.3.tar.gz | ||
| 20 | ) | 20 | ) |
| 21 | 21 | ||
| 22 | foreach(f IN LISTS possible_file_locations) | 22 | foreach(f IN LISTS possible_file_locations) |
| 1 | // cxx-api-examples/zipformer-transducer-simulate-streaming-microphone-cxx-api.cc | 1 | // cxx-api-examples/zipformer-transducer-simulate-streaming-microphone-cxx-api.cc |
| 2 | // Copyright (c) 2025 Xiaomi Corporation | 2 | // Copyright (c) 2025 Xiaomi Corporation |
| 3 | // | 3 | // |
| 4 | -// This file demonstrates how to use Zipformer transducer with sherpa-onnx's C++ API | ||
| 5 | -// for streaming speech recognition from a microphone. | 4 | +// This file demonstrates how to use Zipformer transducer with sherpa-onnx's C++ |
| 5 | +// API for streaming speech recognition from a microphone. | ||
| 6 | // | 6 | // |
| 7 | // clang-format off | 7 | // clang-format off |
| 8 | // | 8 | // |
| @@ -19,6 +19,12 @@ For instance, | @@ -19,6 +19,12 @@ For instance, | ||
| 19 | 19 | ||
| 20 | wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx | 20 | wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx |
| 21 | 21 | ||
| 22 | +or download ten-vad.onnx, for instance | ||
| 23 | + | ||
| 24 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/ten-vad.onnx | ||
| 25 | + | ||
| 26 | +Please replace --silero-vad-model with --ten-vad-model below to use ten-vad. | ||
| 27 | + | ||
| 22 | (1) For paraformer | 28 | (1) For paraformer |
| 23 | 29 | ||
| 24 | ./python-api-examples/generate-subtitles.py \ | 30 | ./python-api-examples/generate-subtitles.py \ |
| @@ -124,8 +130,13 @@ def get_args(): | @@ -124,8 +130,13 @@ def get_args(): | ||
| 124 | parser.add_argument( | 130 | parser.add_argument( |
| 125 | "--silero-vad-model", | 131 | "--silero-vad-model", |
| 126 | type=str, | 132 | type=str, |
| 127 | - required=True, | ||
| 128 | - help="Path to silero_vad.onnx", | 133 | + help="Path to silero_vad.onnx.", |
| 134 | + ) | ||
| 135 | + | ||
| 136 | + parser.add_argument( | ||
| 137 | + "--ten-vad-model", | ||
| 138 | + type=str, | ||
| 139 | + help="Path to ten-vad.onnx", | ||
| 129 | ) | 140 | ) |
| 130 | 141 | ||
| 131 | parser.add_argument( | 142 | parser.add_argument( |
| @@ -499,7 +510,12 @@ class Segment: | @@ -499,7 +510,12 @@ class Segment: | ||
| 499 | def main(): | 510 | def main(): |
| 500 | args = get_args() | 511 | args = get_args() |
| 501 | assert_file_exists(args.tokens) | 512 | assert_file_exists(args.tokens) |
| 502 | - assert_file_exists(args.silero_vad_model) | 513 | + if args.silero_vad_model: |
| 514 | + assert_file_exists(args.silero_vad_model) | ||
| 515 | + elif args.ten_vad_model: | ||
| 516 | + assert_file_exists(args.ten_vad_model) | ||
| 517 | + else: | ||
| 518 | + raise ValueError("You need to supply one vad model") | ||
| 503 | 519 | ||
| 504 | assert args.num_threads > 0, args.num_threads | 520 | assert args.num_threads > 0, args.num_threads |
| 505 | 521 | ||
| @@ -536,18 +552,34 @@ def main(): | @@ -536,18 +552,34 @@ def main(): | ||
| 536 | stream = recognizer.create_stream() | 552 | stream = recognizer.create_stream() |
| 537 | 553 | ||
| 538 | config = sherpa_onnx.VadModelConfig() | 554 | config = sherpa_onnx.VadModelConfig() |
| 539 | - config.silero_vad.model = args.silero_vad_model | ||
| 540 | - config.silero_vad.threshold = 0.5 | ||
| 541 | - config.silero_vad.min_silence_duration = 0.25 # seconds | ||
| 542 | - config.silero_vad.min_speech_duration = 0.25 # seconds | ||
| 543 | - | ||
| 544 | - # If the current segment is larger than this value, then it increases | ||
| 545 | - # the threshold to 0.9 internally. After detecting this segment, | ||
| 546 | - # it resets the threshold to its original value. | ||
| 547 | - config.silero_vad.max_speech_duration = 5 # seconds | ||
| 548 | - config.sample_rate = args.sample_rate | ||
| 549 | - | ||
| 550 | - window_size = config.silero_vad.window_size | 555 | + if args.silero_vad_model: |
| 556 | + config.silero_vad.model = args.silero_vad_model | ||
| 557 | + config.silero_vad.threshold = 0.2 | ||
| 558 | + config.silero_vad.min_silence_duration = 0.25 # seconds | ||
| 559 | + config.silero_vad.min_speech_duration = 0.25 # seconds | ||
| 560 | + | ||
| 561 | + # If the current segment is larger than this value, then it increases | ||
| 562 | + # the threshold to 0.9 internally. After detecting this segment, | ||
| 563 | + # it resets the threshold to its original value. | ||
| 564 | + config.silero_vad.max_speech_duration = 5 # seconds | ||
| 565 | + config.sample_rate = args.sample_rate | ||
| 566 | + | ||
| 567 | + window_size = config.silero_vad.window_size | ||
| 568 | + print("use silero-vad") | ||
| 569 | + else: | ||
| 570 | + config.ten_vad.model = args.ten_vad_model | ||
| 571 | + config.ten_vad.threshold = 0.2 | ||
| 572 | + config.ten_vad.min_silence_duration = 0.25 # seconds | ||
| 573 | + config.ten_vad.min_speech_duration = 0.25 # seconds | ||
| 574 | + | ||
| 575 | + # If the current segment is larger than this value, then it increases | ||
| 576 | + # the threshold to 0.9 internally. After detecting this segment, | ||
| 577 | + # it resets the threshold to its original value. | ||
| 578 | + config.ten_vad.max_speech_duration = 5 # seconds | ||
| 579 | + config.sample_rate = args.sample_rate | ||
| 580 | + | ||
| 581 | + window_size = config.ten_vad.window_size | ||
| 582 | + print("use ten-vad") | ||
| 551 | 583 | ||
| 552 | buffer = [] | 584 | buffer = [] |
| 553 | vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=100) | 585 | vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=100) |
| @@ -123,6 +123,8 @@ set(sources | @@ -123,6 +123,8 @@ set(sources | ||
| 123 | spoken-language-identification.cc | 123 | spoken-language-identification.cc |
| 124 | stack.cc | 124 | stack.cc |
| 125 | symbol-table.cc | 125 | symbol-table.cc |
| 126 | + ten-vad-model-config.cc | ||
| 127 | + ten-vad-model.cc | ||
| 126 | text-utils.cc | 128 | text-utils.cc |
| 127 | transducer-keyword-decoder.cc | 129 | transducer-keyword-decoder.cc |
| 128 | transpose.cc | 130 | transpose.cc |
| @@ -40,7 +40,7 @@ void SileroVadModelConfig::Register(ParseOptions *po) { | @@ -40,7 +40,7 @@ void SileroVadModelConfig::Register(ParseOptions *po) { | ||
| 40 | "to the silero VAD model. WARNING! Silero VAD models were trained using " | 40 | "to the silero VAD model. WARNING! Silero VAD models were trained using " |
| 41 | "512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples " | 41 | "512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples " |
| 42 | "for 8000 sample rate. Values other than these may affect model " | 42 | "for 8000 sample rate. Values other than these may affect model " |
| 43 | - "perfomance!"); | 43 | + "performance!"); |
| 44 | } | 44 | } |
| 45 | 45 | ||
| 46 | bool SileroVadModelConfig::Validate() const { | 46 | bool SileroVadModelConfig::Validate() const { |
| @@ -24,7 +24,6 @@ struct SileroVadModelConfig { | @@ -24,7 +24,6 @@ struct SileroVadModelConfig { | ||
| 24 | float min_speech_duration = 0.25; // in seconds | 24 | float min_speech_duration = 0.25; // in seconds |
| 25 | 25 | ||
| 26 | // 512, 1024, 1536 samples for 16000 Hz | 26 | // 512, 1024, 1536 samples for 16000 Hz |
| 27 | - // 256, 512, 768 samples for 800 Hz | ||
| 28 | int32_t window_size = 512; // in samples | 27 | int32_t window_size = 512; // in samples |
| 29 | 28 | ||
| 30 | // If a speech segment is longer than this value, then we increase | 29 | // If a speech segment is longer than this value, then we increase |
sherpa-onnx/csrc/ten-vad-model-config.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/ten-vad-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/ten-vad-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void TenVadModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("ten-vad-model", &model, "Path to TEN VAD ONNX model."); | ||
| 14 | + | ||
| 15 | + po->Register("ten-vad-threshold", &threshold, | ||
| 16 | + "Speech threshold. TEN VAD outputs speech probabilities for " | ||
| 17 | + "each audio chunk, probabilities ABOVE this value are " | ||
| 18 | + "considered as SPEECH. It is better to tune this parameter for " | ||
| 19 | + "each dataset separately, but lazy " | ||
| 20 | + "0.5 is pretty good for most datasets."); | ||
| 21 | + | ||
| 22 | + po->Register("ten-vad-min-silence-duration", &min_silence_duration, | ||
| 23 | + "In seconds. In the end of each speech chunk wait for " | ||
| 24 | + "--ten-vad-min-silence-duration seconds before separating it"); | ||
| 25 | + | ||
| 26 | + po->Register("ten-vad-min-speech-duration", &min_speech_duration, | ||
| 27 | + "In seconds. In the end of each silence chunk wait for " | ||
| 28 | + "--ten-vad-min-speech-duration seconds before separating it"); | ||
| 29 | + | ||
| 30 | + po->Register( | ||
| 31 | + "ten-vad-max-speech-duration", &max_speech_duration, | ||
| 32 | + "In seconds. If a speech segment is longer than this value, then we " | ||
| 33 | + "increase the threshold to 0.9. After finishing detecting the segment, " | ||
| 34 | + "the threshold value is reset to its original value."); | ||
| 35 | + | ||
| 36 | + po->Register( | ||
| 37 | + "ten-vad-window-size", &window_size, | ||
| 38 | + "In samples. Audio chunks of --ten-vad-window-size samples are fed " | ||
| 39 | + "to the ten VAD model. WARNING! Please use 160 or 256 "); | ||
| 40 | +} | ||
| 41 | + | ||
| 42 | +bool TenVadModelConfig::Validate() const { | ||
| 43 | + if (model.empty()) { | ||
| 44 | + SHERPA_ONNX_LOGE("Please provide --ten-vad-model"); | ||
| 45 | + return false; | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + if (!FileExists(model)) { | ||
| 49 | + SHERPA_ONNX_LOGE("TEN vad model file '%s' does not exist", model.c_str()); | ||
| 50 | + return false; | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + if (threshold < 0.01) { | ||
| 54 | + SHERPA_ONNX_LOGE( | ||
| 55 | + "Please use a larger value for --ten-vad-threshold. Given: %f", | ||
| 56 | + threshold); | ||
| 57 | + return false; | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + if (threshold >= 1) { | ||
| 61 | + SHERPA_ONNX_LOGE( | ||
| 62 | + "Please use a smaller value for --ten-vad-threshold. Given: %f", | ||
| 63 | + threshold); | ||
| 64 | + return false; | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + if (min_silence_duration <= 0) { | ||
| 68 | + SHERPA_ONNX_LOGE( | ||
| 69 | + "Please use a larger value for --ten-vad-min-silence-duration. " | ||
| 70 | + "Given: " | ||
| 71 | + "%f", | ||
| 72 | + min_silence_duration); | ||
| 73 | + return false; | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + if (min_speech_duration <= 0) { | ||
| 77 | + SHERPA_ONNX_LOGE( | ||
| 78 | + "Please use a larger value for --ten-vad-min-speech-duration. " | ||
| 79 | + "Given: " | ||
| 80 | + "%f", | ||
| 81 | + min_speech_duration); | ||
| 82 | + return false; | ||
| 83 | + } | ||
| 84 | + | ||
| 85 | + if (max_speech_duration <= 0) { | ||
| 86 | + SHERPA_ONNX_LOGE( | ||
| 87 | + "Please use a larger value for --ten-vad-max-speech-duration. " | ||
| 88 | + "Given: " | ||
| 89 | + "%f", | ||
| 90 | + max_speech_duration); | ||
| 91 | + return false; | ||
| 92 | + } | ||
| 93 | + | ||
| 94 | + return true; | ||
| 95 | +} | ||
| 96 | + | ||
| 97 | +std::string TenVadModelConfig::ToString() const { | ||
| 98 | + std::ostringstream os; | ||
| 99 | + | ||
| 100 | + os << "TenVadModelConfig("; | ||
| 101 | + os << "model=\"" << model << "\", "; | ||
| 102 | + os << "threshold=" << threshold << ", "; | ||
| 103 | + os << "min_silence_duration=" << min_silence_duration << ", "; | ||
| 104 | + os << "min_speech_duration=" << min_speech_duration << ", "; | ||
| 105 | + os << "max_speech_duration=" << max_speech_duration << ", "; | ||
| 106 | + os << "window_size=" << window_size << ")"; | ||
| 107 | + | ||
| 108 | + return os.str(); | ||
| 109 | +} | ||
| 110 | + | ||
| 111 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/ten-vad-model-config.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/ten-vad-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_TEN_VAD_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_TEN_VAD_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
// Configuration for the TEN VAD model.
//
// Mirrors SileroVadModelConfig but with TEN-VAD-specific defaults
// (window_size 256, min_silence_duration 0.5 s, max_speech_duration 20 s).
struct TenVadModelConfig {
  // Path to the TEN VAD ONNX model file.
  std::string model;

  // threshold to classify a segment as speech
  //
  // If the predicted probability of a segment is larger than this
  // value, then it is classified as speech.
  float threshold = 0.5;

  // Wait this long after speech ends before closing a segment.
  float min_silence_duration = 0.5;  // in seconds

  // A speech region must last at least this long to count as speech.
  float min_speech_duration = 0.25;  // in seconds

  // Number of samples fed to the model per chunk at 16 kHz.
  // 160 or 256
  int32_t window_size = 256;  // in samples

  // If a speech segment is longer than this value, then we increase
  // the threshold to 0.9. After finishing detecting the segment,
  // the threshold value is reset to its original value.
  float max_speech_duration = 20;  // in seconds

  TenVadModelConfig() = default;

  // Registers the --ten-vad-* command-line options.
  void Register(ParseOptions *po);

  // Returns true if the configuration is usable (model file exists,
  // thresholds/durations are in range).
  bool Validate() const;

  // Human-readable representation for logging.
  std::string ToString() const;
};
| 42 | + | ||
| 43 | +} // namespace sherpa_onnx | ||
| 44 | + | ||
| 45 | +#endif // SHERPA_ONNX_CSRC_TEN_VAD_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/ten-vad-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/ten-vad-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/ten-vad-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <cmath> | ||
| 9 | +#include <cstring> | ||
| 10 | +#include <memory> | ||
| 11 | +#include <string> | ||
| 12 | +#include <utility> | ||
| 13 | +#include <vector> | ||
| 14 | + | ||
| 15 | +#if __ANDROID_API__ >= 9 | ||
| 16 | +#include "android/asset_manager.h" | ||
| 17 | +#include "android/asset_manager_jni.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 20 | +#if __OHOS__ | ||
| 21 | +#include "rawfile/raw_file_manager.h" | ||
| 22 | +#endif | ||
| 23 | + | ||
| 24 | +#include "kaldi-native-fbank/csrc/mel-computations.h" | ||
| 25 | +#include "kaldi-native-fbank/csrc/rfft.h" | ||
| 26 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 27 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 28 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 29 | +#include "sherpa-onnx/csrc/session.h" | ||
| 30 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 31 | + | ||
| 32 | +namespace sherpa_onnx { | ||
| 33 | + | ||
// Implementation (pimpl) of TenVadModel.
//
// Runs the TEN VAD ONNX model on 16 kHz audio. Each call to IsSpeech()
// receives exactly one window of samples, computes a 41-dim feature frame
// (40 log-mel bins + one pitch slot), feeds a sliding context of the last
// 3 frames to the model together with 4 recurrent state tensors, and then
// applies a speech/silence hysteresis state machine to the returned
// probability.
class TenVadModel::Impl {
 public:
  explicit Impl(const VadModelConfig &config)
      : config_(config),
        rfft_(1024),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{},
        sample_rate_(config.sample_rate) {
    auto buf = ReadFile(config.ten_vad.model);
    Init(buf.data(), buf.size());
  }

  // Same as the ctor above, but loads the model through a platform asset
  // manager (Android AAssetManager / OHOS NativeResourceManager).
  template <typename Manager>
  Impl(Manager *mgr, const VadModelConfig &config)
      : config_(config),
        rfft_(1024),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{},
        sample_rate_(config.sample_rate) {
    auto buf = ReadFile(mgr, config.ten_vad.model);
    Init(buf.data(), buf.size());
  }

  // Resets the detector to its initial state: clears the speech/silence
  // state machine, zeroes the 3-frame feature context and the recurrent
  // model states.
  void Reset() {
    triggered_ = false;
    current_sample_ = 0;
    temp_start_ = 0;
    temp_end_ = 0;

    last_sample_ = 0;

    // 3 stacked frames of 41 features each, all zeros.
    last_features_.resize(3 * 41);
    std::fill(last_features_.begin(), last_features_.end(), 0.0f);
    tmp_samples_.resize(1024);

    ResetStates();
  }

  // Classifies one window of samples as speech/non-speech.
  //
  // `n` must equal WindowSize(); anything else is a programming error and
  // aborts the process. The return value already includes the
  // min-speech/min-silence hysteresis, so a single noisy probability does
  // not flip the decision.
  bool IsSpeech(const float *samples, int32_t n) {
    if (n != WindowSize()) {
      SHERPA_ONNX_LOGE("n: %d != window_size: %d", n, WindowSize());
      SHERPA_ONNX_EXIT(-1);
    }

    float prob = Run(samples, n);

    float threshold = config_.ten_vad.threshold;

    current_sample_ += config_.ten_vad.window_size;

    // Speech resumed before the pending silence ran out; cancel it.
    if (prob > threshold && temp_end_ != 0) {
      temp_end_ = 0;
    }

    if (prob > threshold && temp_start_ == 0) {
      // start speaking, but we require that it must satisfy
      // min_speech_duration
      temp_start_ = current_sample_;
      return false;
    }

    if (prob > threshold && temp_start_ != 0 && !triggered_) {
      if (current_sample_ - temp_start_ < min_speech_samples_) {
        return false;
      }

      triggered_ = true;

      return true;
    }

    if ((prob < threshold) && !triggered_) {
      // silence
      temp_start_ = 0;
      temp_end_ = 0;
      return false;
    }

    // Hysteresis: once triggered, a slightly lower probability
    // (threshold - 0.15) still counts as continued speech.
    if ((prob > threshold - 0.15) && triggered_) {
      // speaking
      return true;
    }

    if ((prob > threshold) && !triggered_) {
      // start speaking
      triggered_ = true;

      return true;
    }

    if ((prob < threshold) && triggered_) {
      // stop to speak
      if (temp_end_ == 0) {
        temp_end_ = current_sample_;
      }

      if (current_sample_ - temp_end_ < min_silence_samples_) {
        // continue speaking
        return true;
      }
      // stopped speaking
      temp_start_ = 0;
      temp_end_ = 0;
      triggered_ = false;
      return false;
    }

    return false;
  }

  // TEN VAD uses non-overlapping windows: shift == size.
  int32_t WindowShift() const { return config_.ten_vad.window_size; }

  int32_t WindowSize() const { return config_.ten_vad.window_size; }

  int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }

  int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }

  void SetMinSilenceDuration(float s) {
    min_silence_samples_ = sample_rate_ * s;
  }

  void SetThreshold(float threshold) { config_.ten_vad.threshold = threshold; }

 private:
  // Validates the config, creates the ORT session, loads the mel filter
  // bank, checks the model metadata, and resets all state.
  void Init(void *model_data, size_t model_data_length) {
    // The feature pipeline (1024-point FFT, mel bank layout) is hard-coded
    // for 16 kHz input.
    if (sample_rate_ != 16000) {
      SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
                       config_.sample_rate);
      SHERPA_ONNX_EXIT(-1);
    }

    // window_ (read from model metadata) has 768 entries, so larger
    // windows cannot be processed.
    if (config_.ten_vad.window_size > 768) {
      SHERPA_ONNX_LOGE("Windows size %d for ten-vad is too large",
                       config_.ten_vad.window_size);
      SHERPA_ONNX_EXIT(-1);
    }

    min_silence_samples_ = sample_rate_ * config_.ten_vad.min_silence_duration;

    min_speech_samples_ = sample_rate_ * config_.ten_vad.min_speech_duration;

    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    InitMelBanks();

    Check();

    Reset();
  }

  // Allocates and zeroes the 4 recurrent state tensors, each of shape
  // (1, 64).
  void ResetStates() {
    std::array<int64_t, 2> shape{1, 64};

    states_.clear();
    states_.reserve(4);
    for (int32_t i = 0; i != 4; ++i) {
      Ort::Value s = Ort::Value::CreateTensor<float>(allocator_, shape.data(),
                                                     shape.size());

      Fill<float>(&s, 0);
      states_.push_back(std::move(s));
    }
  }

  // Builds a librosa-style 40-bin mel filter bank covering 0-8000 Hz for
  // a 1024-point FFT.
  void InitMelBanks() {
    knf::FrameExtractionOptions frame_opts;

    // 16 kHz, so num_fft is 16000*64/1000 = 1024
    frame_opts.frame_length_ms = 64;

    knf::MelBanksOptions mel_opts;
    mel_opts.is_librosa = true;
    mel_opts.norm = "";
    mel_opts.use_slaney_mel_scale = true;
    mel_opts.floor_to_int_bin = true;
    mel_opts.low_freq = 0;
    mel_opts.high_freq = 8000;
    mel_opts.num_bins = 40;

    mel_banks_ = std::make_unique<knf::MelBanks>(mel_opts, frame_opts, 1.0f);

    // 40 mel bins + 1 pitch slot (see ComputeFeatures()).
    features_.resize(41);
  }

  // Verifies the model metadata (must be a sherpa-onnx re-exported
  // ten-vad model) and loads the normalization stats and analysis window.
  void Check() {
    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---ten-vad---\n";
      PrintModelMetadata(os, meta_data);
#if __OHOS__
      SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
    }
    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below

    std::string model_type;
    SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(model_type, "model_type");

    if (model_type.empty()) {
      SHERPA_ONNX_LOGE(
          "Please download ten-vad.onnx or ten-vad.int8.onnx from\n"
          "https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models"
          "\nWe have added meta data to the original ten-vad.onnx from\n"
          "https://github.com/TEN-framework/ten-vad");
      SHERPA_ONNX_EXIT(-1);
    }

    if (model_type != "ten-vad") {
      SHERPA_ONNX_LOGE("Expect model type 'ten-vad', given '%s'",
                       model_type.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(mean_, "mean");
    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(window_, "window");

    if (mean_.size() != 41) {
      SHERPA_ONNX_LOGE(
          "Incorrect size of the mean vector. Given %d, expected 41",
          static_cast<int32_t>(mean_.size()));
      SHERPA_ONNX_EXIT(-1);
    }

    if (inv_stddev_.size() != 41) {
      SHERPA_ONNX_LOGE(
          "Incorrect size of the inv_stddev vector. Given %d, expected 41",
          static_cast<int32_t>(inv_stddev_.size()));
      SHERPA_ONNX_EXIT(-1);
    }

    if (window_.size() != 768) {
      SHERPA_ONNX_LOGE(
          "Incorrect size of the window vector. Given %d, expected 768",
          static_cast<int32_t>(window_.size()));
      SHERPA_ONNX_EXIT(-1);
    }
  }

  // Converts [-1, 1] float samples to 16-bit integer scale.
  static void Scale(const float *samples, int32_t n, float *out) {
    for (int32_t i = 0; i != n; ++i) {
      out[i] = samples[i] * 32768;
    }
  }

  // Applies a 0.97 pre-emphasis filter in place. last_sample_ carries the
  // final sample of the previous window across calls so the filter is
  // continuous at window boundaries.
  void Preemphasis(const float *samples, int32_t n, float *out) {
    // Save samples[n-1] before it may be overwritten (out can alias
    // samples).
    float t = samples[n - 1];

    for (int32_t i = n - 1; i > 0; --i) {
      out[i] = samples[i] - 0.97 * samples[i - 1];
    }

    out[0] = samples[0] - 0.97 * last_sample_;

    last_sample_ = t;
  }

  static void ApplyWindow(const float *samples, const float *window, int32_t n,
                          float *out) {
    for (int32_t i = 0; i != n; ++i) {
      out[i] = samples[i] * window[i];
    }
  }

  // Converts packed real-FFT output to a power spectrum.
  //
  // knf::Rfft packs the DC component at index 0 and the Nyquist component
  // at index 1; out[n-1] receives the Nyquist power. In-place use
  // (out == fft_bins) is safe because every read index stays ahead of the
  // corresponding write index.
  static void ComputePowerSpectrum(const float *fft_bins, int32_t n,
                                   float *out) {
    out[0] = fft_bins[0] * fft_bins[0];
    out[n - 1] = fft_bins[1] * fft_bins[1];

    for (int32_t i = 1; i < n / 2; ++i) {
      float real = fft_bins[2 * i];
      float imag = fft_bins[2 * i + 1];
      out[i] = real * real + imag * imag;
    }
  }

  // Takes the log of each mel energy and subtracts log(32768^2) to undo
  // the 16-bit scaling applied in Scale().
  // NOTE(review): std::logf is not declared by <cmath> on some older
  // toolchains; std::log (float overload) would be more portable — verify
  // against the project's minimum compiler versions.
  static void LogMel(const float *in, int32_t n, float *out) {
    for (int32_t i = 0; i != n; ++i) {
      // 20.79441541679836 is log(32768*32768)
      out[i] = std::logf(in[i] + 1e-10) - 20.79441541679836f;
    }
  }

  // Applies per-feature mean/inv-stddev normalization loaded from the
  // model metadata.
  void ApplyNormalization(const float *in, float *out) const {
    for (int32_t i = 0; i != static_cast<int32_t>(mean_.size()); ++i) {
      out[i] = (in[i] - mean_[i]) * inv_stddev_[i];
    }
  }

  // Computes one 41-dim feature frame from the current window and pushes
  // it into the 3-frame sliding context last_features_.
  //
  // Pipeline: zero-pad to 1024 -> int16 scale -> pre-emphasis -> window
  // -> real FFT -> power spectrum -> mel bank -> log -> normalization.
  void ComputeFeatures(const float *samples, int32_t n) {
    // Zero the tail so the FFT sees n real samples padded to 1024.
    std::fill(tmp_samples_.begin() + n, tmp_samples_.end(), 0.0f);

    Scale(samples, n, tmp_samples_.data());

    Preemphasis(tmp_samples_.data(), n, tmp_samples_.data());
    ApplyWindow(tmp_samples_.data(), window_.data(), n, tmp_samples_.data());

    rfft_.Compute(tmp_samples_.data());
    auto &power_spectrum = tmp_samples_;
    ComputePowerSpectrum(tmp_samples_.data(), tmp_samples_.size(),
                         power_spectrum.data());

    // note only the first half of power_spectrum is used inside Compute()
    mel_banks_->Compute(power_spectrum.data(), features_.data());
    LogMel(features_.data(), static_cast<int32_t>(features_.size()) - 1,
           features_.data());

    // Note(fangjun): The ten-vad model expects a pitch feature, but we set it
    // to 0 as a simplification. This may reduce performance as noted
    // in the PR #2377
    features_.back() = 0;

    ApplyNormalization(features_.data(), features_.data());

    // Shift frames 1 and 2 down to slots 0 and 1, then append the new
    // frame at slot 2 (last_features_ is (3, 41), row major).
    std::memmove(last_features_.data(),
                 last_features_.data() + features_.size(),
                 2 * features_.size() * sizeof(float));
    std::copy(features_.begin(), features_.end(),
              last_features_.begin() + 2 * features_.size());
  }

  // Runs one inference step and returns the speech probability.
  //
  // Inputs are the 3-frame feature context plus the 4 recurrent states;
  // output 0 is the probability and outputs 1..4 become the new states
  // for the next call.
  float Run(const float *samples, int32_t n) {
    ComputeFeatures(samples, n);

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::array<int64_t, 3> x_shape = {1, 3, 41};

    Ort::Value x = Ort::Value::CreateTensor(memory_info, last_features_.data(),
                                            last_features_.size(),
                                            x_shape.data(), x_shape.size());

    std::vector<Ort::Value> inputs;
    inputs.reserve(input_names_.size());

    // states_ entries are moved into the input list and re-populated from
    // the outputs below.
    inputs.push_back(std::move(x));
    for (auto &s : states_) {
      inputs.push_back(std::move(s));
    }

    auto out =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());

    for (int32_t i = 1; i != static_cast<int32_t>(output_names_.size()); ++i) {
      states_[i - 1] = std::move(out[i]);
    }

    float prob = out[0].GetTensorData<float>()[0];

    return prob;
  }

 private:
  VadModelConfig config_;
  knf::Rfft rfft_;  // 1024-point real FFT
  std::unique_ptr<knf::MelBanks> mel_banks_;

  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  // Recurrent model states, 4 tensors of shape (1, 64).
  std::vector<Ort::Value> states_;
  int64_t sample_rate_;
  int32_t min_silence_samples_;
  int32_t min_speech_samples_;

  // --- speech/silence state machine ---
  bool triggered_ = false;        // currently inside a speech segment
  int32_t current_sample_ = 0;    // running sample counter
  int32_t temp_start_ = 0;        // tentative speech start (sample index)
  int32_t temp_end_ = 0;          // tentative speech end (sample index)

  // Last sample of the previous window, for continuous pre-emphasis.
  float last_sample_ = 0;

  // Loaded from model metadata in Check().
  std::vector<float> mean_;        // (41,) feature means
  std::vector<float> inv_stddev_;  // (41,) inverse standard deviations
  std::vector<float> window_;      // (768,) analysis window

  std::vector<float> features_;       // (41,) current frame
  std::vector<float> last_features_;  // (3, 41), row major
  std::vector<float> tmp_samples_;    // (1024,)
};
| 436 | + | ||
| 437 | +TenVadModel::TenVadModel(const VadModelConfig &config) | ||
| 438 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 439 | + | ||
| 440 | +template <typename Manager> | ||
| 441 | +TenVadModel::TenVadModel(Manager *mgr, const VadModelConfig &config) | ||
| 442 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 443 | + | ||
| 444 | +TenVadModel::~TenVadModel() = default; | ||
| 445 | + | ||
| 446 | +void TenVadModel::Reset() { return impl_->Reset(); } | ||
| 447 | + | ||
| 448 | +bool TenVadModel::IsSpeech(const float *samples, int32_t n) { | ||
| 449 | + return impl_->IsSpeech(samples, n); | ||
| 450 | +} | ||
| 451 | + | ||
| 452 | +int32_t TenVadModel::WindowSize() const { return impl_->WindowSize(); } | ||
| 453 | + | ||
| 454 | +int32_t TenVadModel::WindowShift() const { return impl_->WindowShift(); } | ||
| 455 | + | ||
| 456 | +int32_t TenVadModel::MinSilenceDurationSamples() const { | ||
| 457 | + return impl_->MinSilenceDurationSamples(); | ||
| 458 | +} | ||
| 459 | + | ||
| 460 | +int32_t TenVadModel::MinSpeechDurationSamples() const { | ||
| 461 | + return impl_->MinSpeechDurationSamples(); | ||
| 462 | +} | ||
| 463 | + | ||
| 464 | +void TenVadModel::SetMinSilenceDuration(float s) { | ||
| 465 | + impl_->SetMinSilenceDuration(s); | ||
| 466 | +} | ||
| 467 | + | ||
| 468 | +void TenVadModel::SetThreshold(float threshold) { | ||
| 469 | + impl_->SetThreshold(threshold); | ||
| 470 | +} | ||
| 471 | + | ||
| 472 | +#if __ANDROID_API__ >= 9 | ||
| 473 | +template TenVadModel::TenVadModel(AAssetManager *mgr, | ||
| 474 | + const VadModelConfig &config); | ||
| 475 | +#endif | ||
| 476 | + | ||
| 477 | +#if __OHOS__ | ||
| 478 | +template TenVadModel::TenVadModel(NativeResourceManager *mgr, | ||
| 479 | + const VadModelConfig &config); | ||
| 480 | +#endif | ||
| 481 | + | ||
| 482 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/ten-vad-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/ten-vad-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_TEN_VAD_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_TEN_VAD_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/vad-model.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +class TenVadModel : public VadModel { | ||
| 14 | + public: | ||
| 15 | + explicit TenVadModel(const VadModelConfig &config); | ||
| 16 | + | ||
| 17 | + template <typename Manager> | ||
| 18 | + TenVadModel(Manager *mgr, const VadModelConfig &config); | ||
| 19 | + | ||
| 20 | + ~TenVadModel() override; | ||
| 21 | + | ||
| 22 | + // reset the internal model states | ||
| 23 | + void Reset() override; | ||
| 24 | + | ||
| 25 | + /** | ||
| 26 | + * @param samples Pointer to a 1-d array containing audio samples. | ||
| 27 | + * Each sample should be normalized to the range [-1, 1]. | ||
| 28 | + * @param n Number of samples. | ||
| 29 | + * | ||
| 30 | + * @return Return true if speech is detected. Return false otherwise. | ||
| 31 | + */ | ||
| 32 | + bool IsSpeech(const float *samples, int32_t n) override; | ||
| 33 | + | ||
| 34 | + // 256 or 160 | ||
| 35 | + int32_t WindowSize() const override; | ||
| 36 | + | ||
| 37 | + // 256 or 128 | ||
| 38 | + int32_t WindowShift() const override; | ||
| 39 | + | ||
| 40 | + int32_t MinSilenceDurationSamples() const override; | ||
| 41 | + int32_t MinSpeechDurationSamples() const override; | ||
| 42 | + | ||
| 43 | + void SetMinSilenceDuration(float s) override; | ||
| 44 | + void SetThreshold(float threshold) override; | ||
| 45 | + | ||
| 46 | + private: | ||
| 47 | + class Impl; | ||
| 48 | + std::unique_ptr<Impl> impl_; | ||
| 49 | +}; | ||
| 50 | + | ||
| 51 | +} // namespace sherpa_onnx | ||
| 52 | + | ||
| 53 | +#endif // SHERPA_ONNX_CSRC_TEN_VAD_MODEL_H_ |
| @@ -10,9 +10,9 @@ namespace sherpa_onnx { | @@ -10,9 +10,9 @@ namespace sherpa_onnx { | ||
| 10 | /** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). | 10 | /** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). |
| 11 | * | 11 | * |
| 12 | * @param allocator | 12 | * @param allocator |
| 13 | - * @param v A 3-D tensor of shape (B, T, C). Its dataype is type. | 13 | + * @param v A 3-D tensor of shape (B, T, C). Its data type is type. |
| 14 | * | 14 | * |
| 15 | - * @return Return a 3-D tensor of shape (T, B, C). Its datatype is type. | 15 | + * @return Return a 3-D tensor of shape (T, B, C). Its data type is type. |
| 16 | */ | 16 | */ |
| 17 | template <typename type = float> | 17 | template <typename type = float> |
| 18 | Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); | 18 | Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); |
| @@ -20,9 +20,9 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); | @@ -20,9 +20,9 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v); | ||
| 20 | /** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T). | 20 | /** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T). |
| 21 | * | 21 | * |
| 22 | * @param allocator | 22 | * @param allocator |
| 23 | - * @param v A 3-D tensor of shape (B, T, C). Its dataype is type. | 23 | + * @param v A 3-D tensor of shape (B, T, C). Its data type is type. |
| 24 | * | 24 | * |
| 25 | - * @return Return a 3-D tensor of shape (B, C, T). Its datatype is type. | 25 | + * @return Return a 3-D tensor of shape (B, C, T). Its data type is type. |
| 26 | */ | 26 | */ |
| 27 | template <typename type = float> | 27 | template <typename type = float> |
| 28 | Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v); | 28 | Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v); |
| @@ -14,6 +14,7 @@ namespace sherpa_onnx { | @@ -14,6 +14,7 @@ namespace sherpa_onnx { | ||
| 14 | 14 | ||
| 15 | void VadModelConfig::Register(ParseOptions *po) { | 15 | void VadModelConfig::Register(ParseOptions *po) { |
| 16 | silero_vad.Register(po); | 16 | silero_vad.Register(po); |
| 17 | + ten_vad.Register(po); | ||
| 17 | 18 | ||
| 18 | po->Register("vad-sample-rate", &sample_rate, | 19 | po->Register("vad-sample-rate", &sample_rate, |
| 19 | "Sample rate expected by the VAD model"); | 20 | "Sample rate expected by the VAD model"); |
| @@ -48,7 +49,17 @@ bool VadModelConfig::Validate() const { | @@ -48,7 +49,17 @@ bool VadModelConfig::Validate() const { | ||
| 48 | } | 49 | } |
| 49 | } | 50 | } |
| 50 | 51 | ||
| 51 | - return silero_vad.Validate(); | 52 | + if (!silero_vad.model.empty()) { |
| 53 | + return silero_vad.Validate(); | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + if (!ten_vad.model.empty()) { | ||
| 57 | + return ten_vad.Validate(); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + SHERPA_ONNX_LOGE("Please provide one VAD model."); | ||
| 61 | + | ||
| 62 | + return false; | ||
| 52 | } | 63 | } |
| 53 | 64 | ||
| 54 | std::string VadModelConfig::ToString() const { | 65 | std::string VadModelConfig::ToString() const { |
| @@ -56,6 +67,7 @@ std::string VadModelConfig::ToString() const { | @@ -56,6 +67,7 @@ std::string VadModelConfig::ToString() const { | ||
| 56 | 67 | ||
| 57 | os << "VadModelConfig("; | 68 | os << "VadModelConfig("; |
| 58 | os << "silero_vad=" << silero_vad.ToString() << ", "; | 69 | os << "silero_vad=" << silero_vad.ToString() << ", "; |
| 70 | + os << "ten_vad=" << ten_vad.ToString() << ", "; | ||
| 59 | os << "sample_rate=" << sample_rate << ", "; | 71 | os << "sample_rate=" << sample_rate << ", "; |
| 60 | os << "num_threads=" << num_threads << ", "; | 72 | os << "num_threads=" << num_threads << ", "; |
| 61 | os << "provider=\"" << provider << "\", "; | 73 | os << "provider=\"" << provider << "\", "; |
| @@ -8,11 +8,13 @@ | @@ -8,11 +8,13 @@ | ||
| 8 | 8 | ||
| 9 | #include "sherpa-onnx/csrc/parse-options.h" | 9 | #include "sherpa-onnx/csrc/parse-options.h" |
| 10 | #include "sherpa-onnx/csrc/silero-vad-model-config.h" | 10 | #include "sherpa-onnx/csrc/silero-vad-model-config.h" |
| 11 | +#include "sherpa-onnx/csrc/ten-vad-model-config.h" | ||
| 11 | 12 | ||
| 12 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 13 | 14 | ||
| 14 | struct VadModelConfig { | 15 | struct VadModelConfig { |
| 15 | SileroVadModelConfig silero_vad; | 16 | SileroVadModelConfig silero_vad; |
| 17 | + TenVadModelConfig ten_vad; | ||
| 16 | 18 | ||
| 17 | int32_t sample_rate = 16000; | 19 | int32_t sample_rate = 16000; |
| 18 | int32_t num_threads = 1; | 20 | int32_t num_threads = 1; |
| @@ -23,9 +25,11 @@ struct VadModelConfig { | @@ -23,9 +25,11 @@ struct VadModelConfig { | ||
| 23 | 25 | ||
| 24 | VadModelConfig() = default; | 26 | VadModelConfig() = default; |
| 25 | 27 | ||
| 26 | - VadModelConfig(const SileroVadModelConfig &silero_vad, int32_t sample_rate, | 28 | + VadModelConfig(const SileroVadModelConfig &silero_vad, |
| 29 | + const TenVadModelConfig &ten_vad, int32_t sample_rate, | ||
| 27 | int32_t num_threads, const std::string &provider, bool debug) | 30 | int32_t num_threads, const std::string &provider, bool debug) |
| 28 | : silero_vad(silero_vad), | 31 | : silero_vad(silero_vad), |
| 32 | + ten_vad(ten_vad), | ||
| 29 | sample_rate(sample_rate), | 33 | sample_rate(sample_rate), |
| 30 | num_threads(num_threads), | 34 | num_threads(num_threads), |
| 31 | provider(provider), | 35 | provider(provider), |
| @@ -19,13 +19,19 @@ | @@ -19,13 +19,19 @@ | ||
| 19 | 19 | ||
| 20 | #include "sherpa-onnx/csrc/macros.h" | 20 | #include "sherpa-onnx/csrc/macros.h" |
| 21 | #include "sherpa-onnx/csrc/silero-vad-model.h" | 21 | #include "sherpa-onnx/csrc/silero-vad-model.h" |
| 22 | +#include "sherpa-onnx/csrc/ten-vad-model.h" | ||
| 22 | 23 | ||
| 23 | namespace sherpa_onnx { | 24 | namespace sherpa_onnx { |
| 24 | 25 | ||
| 25 | std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) { | 26 | std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) { |
| 26 | if (config.provider == "rknn") { | 27 | if (config.provider == "rknn") { |
| 27 | #if SHERPA_ONNX_ENABLE_RKNN | 28 | #if SHERPA_ONNX_ENABLE_RKNN |
| 28 | - return std::make_unique<SileroVadModelRknn>(config); | 29 | + if (!config.silero_vad.model.empty()) { |
| 30 | + return std::make_unique<SileroVadModelRknn>(config); | ||
| 31 | + } else { | ||
| 32 | + SHERPA_ONNX_LOGE("Only silero-vad is supported for RKNN at present"); | ||
| 33 | + SHERPA_ONNX_EXIT(-1); | ||
| 34 | + } | ||
| 29 | #else | 35 | #else |
| 30 | SHERPA_ONNX_LOGE( | 36 | SHERPA_ONNX_LOGE( |
| 31 | "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " | 37 | "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " |
| @@ -34,7 +40,17 @@ std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) { | @@ -34,7 +40,17 @@ std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) { | ||
| 34 | return nullptr; | 40 | return nullptr; |
| 35 | #endif | 41 | #endif |
| 36 | } | 42 | } |
| 37 | - return std::make_unique<SileroVadModel>(config); | 43 | + |
| 44 | + if (!config.silero_vad.model.empty()) { | ||
| 45 | + return std::make_unique<SileroVadModel>(config); | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + if (!config.ten_vad.model.empty()) { | ||
| 49 | + return std::make_unique<TenVadModel>(config); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + SHERPA_ONNX_LOGE("Please provide a vad model"); | ||
| 53 | + return nullptr; | ||
| 38 | } | 54 | } |
| 39 | 55 | ||
| 40 | template <typename Manager> | 56 | template <typename Manager> |
| @@ -42,7 +58,12 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr, | @@ -42,7 +58,12 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr, | ||
| 42 | const VadModelConfig &config) { | 58 | const VadModelConfig &config) { |
| 43 | if (config.provider == "rknn") { | 59 | if (config.provider == "rknn") { |
| 44 | #if SHERPA_ONNX_ENABLE_RKNN | 60 | #if SHERPA_ONNX_ENABLE_RKNN |
| 45 | - return std::make_unique<SileroVadModelRknn>(mgr, config); | 61 | + if (!config.silero_vad.model.empty()) { |
| 62 | + return std::make_unique<SileroVadModelRknn>(mgr, config); | ||
| 63 | + } else { | ||
| 64 | + SHERPA_ONNX_LOGE("Only silero-vad is supported for RKNN at present"); | ||
| 65 | + SHERPA_ONNX_EXIT(-1); | ||
| 66 | + } | ||
| 46 | #else | 67 | #else |
| 47 | SHERPA_ONNX_LOGE( | 68 | SHERPA_ONNX_LOGE( |
| 48 | "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " | 69 | "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " |
| @@ -51,7 +72,16 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr, | @@ -51,7 +72,16 @@ std::unique_ptr<VadModel> VadModel::Create(Manager *mgr, | ||
| 51 | return nullptr; | 72 | return nullptr; |
| 52 | #endif | 73 | #endif |
| 53 | } | 74 | } |
| 54 | - return std::make_unique<SileroVadModel>(mgr, config); | 75 | + if (!config.silero_vad.model.empty()) { |
| 76 | + return std::make_unique<SileroVadModel>(mgr, config); | ||
| 77 | + } | ||
| 78 | + | ||
| 79 | + if (!config.ten_vad.model.empty()) { | ||
| 80 | + return std::make_unique<TenVadModel>(mgr, config); | ||
| 81 | + } | ||
| 82 | + | ||
| 83 | + SHERPA_ONNX_LOGE("Please provide a vad model"); | ||
| 84 | + return nullptr; | ||
| 55 | } | 85 | } |
| 56 | 86 | ||
| 57 | #if __ANDROID_API__ >= 9 | 87 | #if __ANDROID_API__ >= 9 |
| @@ -18,6 +18,7 @@ | @@ -18,6 +18,7 @@ | ||
| 18 | #endif | 18 | #endif |
| 19 | 19 | ||
| 20 | #include "sherpa-onnx/csrc/circular-buffer.h" | 20 | #include "sherpa-onnx/csrc/circular-buffer.h" |
| 21 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 21 | #include "sherpa-onnx/csrc/vad-model.h" | 22 | #include "sherpa-onnx/csrc/vad-model.h" |
| 22 | 23 | ||
| 23 | namespace sherpa_onnx { | 24 | namespace sherpa_onnx { |
| @@ -45,8 +46,16 @@ class VoiceActivityDetector::Impl { | @@ -45,8 +46,16 @@ class VoiceActivityDetector::Impl { | ||
| 45 | model_->SetMinSilenceDuration(new_min_silence_duration_s_); | 46 | model_->SetMinSilenceDuration(new_min_silence_duration_s_); |
| 46 | model_->SetThreshold(new_threshold_); | 47 | model_->SetThreshold(new_threshold_); |
| 47 | } else { | 48 | } else { |
| 48 | - model_->SetMinSilenceDuration(config_.silero_vad.min_silence_duration); | ||
| 49 | - model_->SetThreshold(config_.silero_vad.threshold); | 49 | + if (!config_.silero_vad.model.empty()) { |
| 50 | + model_->SetMinSilenceDuration(config_.silero_vad.min_silence_duration); | ||
| 51 | + model_->SetThreshold(config_.silero_vad.threshold); | ||
| 52 | + } else if (!config_.ten_vad.model.empty()) { | ||
| 53 | + model_->SetMinSilenceDuration(config_.ten_vad.min_silence_duration); | ||
| 54 | + model_->SetThreshold(config_.ten_vad.threshold); | ||
| 55 | + } else { | ||
| 56 | + SHERPA_ONNX_LOGE("Unknown vad model"); | ||
| 57 | + SHERPA_ONNX_EXIT(-1); | ||
| 58 | + } | ||
| 50 | } | 59 | } |
| 51 | 60 | ||
| 52 | int32_t window_size = model_->WindowSize(); | 61 | int32_t window_size = model_->WindowSize(); |
| @@ -160,11 +169,16 @@ class VoiceActivityDetector::Impl { | @@ -160,11 +169,16 @@ class VoiceActivityDetector::Impl { | ||
| 160 | 169 | ||
| 161 | private: | 170 | private: |
| 162 | void Init() { | 171 | void Init() { |
| 163 | - // TODO(fangjun): Currently, we support only one vad model. | ||
| 164 | - // If a new vad model is added, we need to change the place | ||
| 165 | - // where max_speech_duration is placed. | ||
| 166 | - max_utterance_length_ = | ||
| 167 | - config_.sample_rate * config_.silero_vad.max_speech_duration; | 172 | + if (!config_.silero_vad.model.empty()) { |
| 173 | + max_utterance_length_ = | ||
| 174 | + config_.sample_rate * config_.silero_vad.max_speech_duration; | ||
| 175 | + } else if (!config_.ten_vad.model.empty()) { | ||
| 176 | + max_utterance_length_ = | ||
| 177 | + config_.sample_rate * config_.ten_vad.max_speech_duration; | ||
| 178 | + } else { | ||
| 179 | + SHERPA_ONNX_LOGE("Unsupported VAD model"); | ||
| 180 | + SHERPA_ONNX_EXIT(-1); | ||
| 181 | + } | ||
| 168 | } | 182 | } |
| 169 | 183 | ||
| 170 | private: | 184 | private: |
| @@ -51,6 +51,7 @@ set(srcs | @@ -51,6 +51,7 @@ set(srcs | ||
| 51 | speaker-embedding-extractor.cc | 51 | speaker-embedding-extractor.cc |
| 52 | speaker-embedding-manager.cc | 52 | speaker-embedding-manager.cc |
| 53 | spoken-language-identification.cc | 53 | spoken-language-identification.cc |
| 54 | + ten-vad-model-config.cc | ||
| 54 | tensorrt-config.cc | 55 | tensorrt-config.cc |
| 55 | vad-model-config.cc | 56 | vad-model-config.cc |
| 56 | vad-model.cc | 57 | vad-model.cc |
| 1 | +// sherpa-onnx/python/csrc/ten-vad-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/ten-vad-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/ten-vad-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindTenVadModelConfig(py::module *m) { | ||
| 15 | + using PyClass = TenVadModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "TenVadModelConfig") | ||
| 17 | + .def(py::init<>()) | ||
| 18 | + .def(py::init([](const std::string &model, float threshold, | ||
| 19 | + float min_silence_duration, float min_speech_duration, | ||
| 20 | + int32_t window_size, | ||
| 21 | + float max_speech_duration) -> std::unique_ptr<PyClass> { | ||
| 22 | + auto ans = std::make_unique<PyClass>(); | ||
| 23 | + | ||
| 24 | + ans->model = model; | ||
| 25 | + ans->threshold = threshold; | ||
| 26 | + ans->min_silence_duration = min_silence_duration; | ||
| 27 | + ans->min_speech_duration = min_speech_duration; | ||
| 28 | + ans->window_size = window_size; | ||
| 29 | + ans->max_speech_duration = max_speech_duration; | ||
| 30 | + | ||
| 31 | + return ans; | ||
| 32 | + }), | ||
| 33 | + py::arg("model"), py::arg("threshold") = 0.5, | ||
| 34 | + py::arg("min_silence_duration") = 0.5, | ||
| 35 | + py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 256, | ||
| 36 | + py::arg("max_speech_duration") = 20) | ||
| 37 | + .def_readwrite("model", &PyClass::model) | ||
| 38 | + .def_readwrite("threshold", &PyClass::threshold) | ||
| 39 | + .def_readwrite("min_silence_duration", &PyClass::min_silence_duration) | ||
| 40 | + .def_readwrite("min_speech_duration", &PyClass::min_speech_duration) | ||
| 41 | + .def_readwrite("window_size", &PyClass::window_size) | ||
| 42 | + .def_readwrite("max_speech_duration", &PyClass::max_speech_duration) | ||
| 43 | + .def("__str__", &PyClass::ToString) | ||
| 44 | + .def("validate", &PyClass::Validate); | ||
| 45 | +} | ||
| 46 | + | ||
| 47 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/ten-vad-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_TEN_VAD_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_TEN_VAD_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindTenVadModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_TEN_VAD_MODEL_CONFIG_H_ |
| @@ -8,21 +8,25 @@ | @@ -8,21 +8,25 @@ | ||
| 8 | 8 | ||
| 9 | #include "sherpa-onnx/csrc/vad-model-config.h" | 9 | #include "sherpa-onnx/csrc/vad-model-config.h" |
| 10 | #include "sherpa-onnx/python/csrc/silero-vad-model-config.h" | 10 | #include "sherpa-onnx/python/csrc/silero-vad-model-config.h" |
| 11 | +#include "sherpa-onnx/python/csrc/ten-vad-model-config.h" | ||
| 11 | 12 | ||
| 12 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 13 | 14 | ||
| 14 | void PybindVadModelConfig(py::module *m) { | 15 | void PybindVadModelConfig(py::module *m) { |
| 15 | PybindSileroVadModelConfig(m); | 16 | PybindSileroVadModelConfig(m); |
| 17 | + PybindTenVadModelConfig(m); | ||
| 16 | 18 | ||
| 17 | using PyClass = VadModelConfig; | 19 | using PyClass = VadModelConfig; |
| 18 | py::class_<PyClass>(*m, "VadModelConfig") | 20 | py::class_<PyClass>(*m, "VadModelConfig") |
| 19 | .def(py::init<>()) | 21 | .def(py::init<>()) |
| 20 | - .def(py::init<const SileroVadModelConfig &, int32_t, int32_t, | ||
| 21 | - const std::string &, bool>(), | ||
| 22 | - py::arg("silero_vad"), py::arg("sample_rate") = 16000, | ||
| 23 | - py::arg("num_threads") = 1, py::arg("provider") = "cpu", | ||
| 24 | - py::arg("debug") = false) | 22 | + .def(py::init<const SileroVadModelConfig &, const TenVadModelConfig &, |
| 23 | + int32_t, int32_t, const std::string &, bool>(), | ||
| 24 | + py::arg("silero_vad") = SileroVadModelConfig{}, | ||
| 25 | + py::arg("ten_vad") = TenVadModelConfig{}, | ||
| 26 | + py::arg("sample_rate") = 16000, py::arg("num_threads") = 1, | ||
| 27 | + py::arg("provider") = "cpu", py::arg("debug") = false) | ||
| 25 | .def_readwrite("silero_vad", &PyClass::silero_vad) | 28 | .def_readwrite("silero_vad", &PyClass::silero_vad) |
| 29 | + .def_readwrite("ten_vad", &PyClass::ten_vad) | ||
| 26 | .def_readwrite("sample_rate", &PyClass::sample_rate) | 30 | .def_readwrite("sample_rate", &PyClass::sample_rate) |
| 27 | .def_readwrite("num_threads", &PyClass::num_threads) | 31 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 28 | .def_readwrite("provider", &PyClass::provider) | 32 | .def_readwrite("provider", &PyClass::provider) |
| @@ -64,6 +64,7 @@ from _sherpa_onnx import ( | @@ -64,6 +64,7 @@ from _sherpa_onnx import ( | ||
| 64 | SpokenLanguageIdentification, | 64 | SpokenLanguageIdentification, |
| 65 | SpokenLanguageIdentificationConfig, | 65 | SpokenLanguageIdentificationConfig, |
| 66 | SpokenLanguageIdentificationWhisperConfig, | 66 | SpokenLanguageIdentificationWhisperConfig, |
| 67 | + TenVadModelConfig, | ||
| 67 | VadModel, | 68 | VadModel, |
| 68 | VadModelConfig, | 69 | VadModelConfig, |
| 69 | VoiceActivityDetector, | 70 | VoiceActivityDetector, |
-
请注册或登录后发表评论