Fangjun Kuang
Committed by GitHub

Support specifying max speech duration for VAD. (#1348)

@@ -406,7 +406,14 @@ def main(): @@ -406,7 +406,14 @@ def main():
406 406
407 config = sherpa_onnx.VadModelConfig() 407 config = sherpa_onnx.VadModelConfig()
408 config.silero_vad.model = args.silero_vad_model 408 config.silero_vad.model = args.silero_vad_model
409 - config.silero_vad.min_silence_duration = 0.25 409 + config.silero_vad.threshold = 0.5
  410 + config.silero_vad.min_silence_duration = 0.25 # seconds
  411 + config.silero_vad.min_speech_duration = 0.25 # seconds
  412 +
  413 + # If the current speech segment is longer than this value, the VAD
  414 + # raises its threshold to 0.9 internally. Once the segment has been
  415 + # detected, the threshold is reset to its original value.
  416 + config.silero_vad.max_speech_duration = 5 # seconds
410 config.sample_rate = args.sample_rate 417 config.sample_rate = args.sample_rate
411 418
412 window_size = config.silero_vad.window_size 419 window_size = config.silero_vad.window_size
@@ -29,6 +29,12 @@ void SileroVadModelConfig::Register(ParseOptions *po) { @@ -29,6 +29,12 @@ void SileroVadModelConfig::Register(ParseOptions *po) {
29 "--silero-vad-min-speech-duration seconds before separating it"); 29 "--silero-vad-min-speech-duration seconds before separating it");
30 30
31 po->Register( 31 po->Register(
  32 + "silero-vad-max-speech-duration", &max_speech_duration,
  33 + "In seconds. If a speech segment is longer than this value, then we "
  34 + "increase the threshold to 0.9. After finishing detecting the segment, "
  35 + "the threshold value is reset to its original value.");
  36 +
  37 + po->Register(
32 "silero-vad-window-size", &window_size, 38 "silero-vad-window-size", &window_size,
33 "In samples. Audio chunks of --silero-vad-window-size samples are fed " 39 "In samples. Audio chunks of --silero-vad-window-size samples are fed "
34 "to the silero VAD model. WARNING! Silero VAD models were trained using " 40 "to the silero VAD model. WARNING! Silero VAD models were trained using "
@@ -63,6 +69,33 @@ bool SileroVadModelConfig::Validate() const { @@ -63,6 +69,33 @@ bool SileroVadModelConfig::Validate() const {
63 return false; 69 return false;
64 } 70 }
65 71
  72 + if (min_silence_duration <= 0) {
  73 + SHERPA_ONNX_LOGE(
  74 + "Please use a larger value for --silero-vad-min-silence-duration. "
  75 + "Given: "
  76 + "%f",
  77 + min_silence_duration);
  78 + return false;
  79 + }
  80 +
  81 + if (min_speech_duration <= 0) {
  82 + SHERPA_ONNX_LOGE(
  83 + "Please use a larger value for --silero-vad-min-speech-duration. "
  84 + "Given: "
  85 + "%f",
  86 + min_speech_duration);
  87 + return false;
  88 + }
  89 +
  90 + if (max_speech_duration <= 0) {
  91 + SHERPA_ONNX_LOGE(
  92 + "Please use a larger value for --silero-vad-max-speech-duration. "
  93 + "Given: "
  94 + "%f",
  95 + max_speech_duration);
  96 + return false;
  97 + }
  98 +
66 return true; 99 return true;
67 } 100 }
68 101
@@ -74,6 +107,7 @@ std::string SileroVadModelConfig::ToString() const { @@ -74,6 +107,7 @@ std::string SileroVadModelConfig::ToString() const {
74 os << "threshold=" << threshold << ", "; 107 os << "threshold=" << threshold << ", ";
75 os << "min_silence_duration=" << min_silence_duration << ", "; 108 os << "min_silence_duration=" << min_silence_duration << ", ";
76 os << "min_speech_duration=" << min_speech_duration << ", "; 109 os << "min_speech_duration=" << min_speech_duration << ", ";
  110 + os << "max_speech_duration=" << max_speech_duration << ", ";
77 os << "window_size=" << window_size << ")"; 111 os << "window_size=" << window_size << ")";
78 112
79 return os.str(); 113 return os.str();
@@ -27,6 +27,11 @@ struct SileroVadModelConfig { @@ -27,6 +27,11 @@ struct SileroVadModelConfig {
27 // 256, 512, 768 samples for 8000 Hz 27 // 256, 512, 768 samples for 8000 Hz
28 int32_t window_size = 512; // in samples 28 int32_t window_size = 512; // in samples
29 29
  30 + // If a speech segment is longer than this value, then we increase
  31 + // the threshold to 0.9. After finishing detecting the segment,
  32 + // the threshold value is reset to its original value.
  33 + float max_speech_duration = 20; // in seconds
  34 +
30 SileroVadModelConfig() = default; 35 SileroVadModelConfig() = default;
31 36
32 void Register(ParseOptions *po); 37 void Register(ParseOptions *po);
@@ -18,14 +18,18 @@ class VoiceActivityDetector::Impl { @@ -18,14 +18,18 @@ class VoiceActivityDetector::Impl {
18 explicit Impl(const VadModelConfig &config, float buffer_size_in_seconds = 60) 18 explicit Impl(const VadModelConfig &config, float buffer_size_in_seconds = 60)
19 : model_(VadModel::Create(config)), 19 : model_(VadModel::Create(config)),
20 config_(config), 20 config_(config),
21 - buffer_(buffer_size_in_seconds * config.sample_rate) {} 21 + buffer_(buffer_size_in_seconds * config.sample_rate) {
  22 + Init();
  23 + }
22 24
23 #if __ANDROID_API__ >= 9 25 #if __ANDROID_API__ >= 9
24 Impl(AAssetManager *mgr, const VadModelConfig &config, 26 Impl(AAssetManager *mgr, const VadModelConfig &config,
25 float buffer_size_in_seconds = 60) 27 float buffer_size_in_seconds = 60)
26 : model_(VadModel::Create(mgr, config)), 28 : model_(VadModel::Create(mgr, config)),
27 config_(config), 29 config_(config),
28 - buffer_(buffer_size_in_seconds * config.sample_rate) {} 30 + buffer_(buffer_size_in_seconds * config.sample_rate) {
  31 + Init();
  32 + }
29 #endif 33 #endif
30 34
31 void AcceptWaveform(const float *samples, int32_t n) { 35 void AcceptWaveform(const float *samples, int32_t n) {
@@ -146,6 +150,15 @@ class VoiceActivityDetector::Impl { @@ -146,6 +150,15 @@ class VoiceActivityDetector::Impl {
146 const VadModelConfig &GetConfig() const { return config_; } 150 const VadModelConfig &GetConfig() const { return config_; }
147 151
148 private: 152 private:
  153 + void Init() {
  154 + // TODO(fangjun): Currently, we support only one VAD model.
  155 + // If a new VAD model is added, we need to change where
  156 + // max_speech_duration is stored.
  157 + max_utterance_length_ =
  158 + config_.sample_rate * config_.silero_vad.max_speech_duration;
  159 + }
  160 +
  161 + private:
149 std::queue<SpeechSegment> segments_; 162 std::queue<SpeechSegment> segments_;
150 163
151 std::unique_ptr<VadModel> model_; 164 std::unique_ptr<VadModel> model_;
@@ -153,9 +166,9 @@ class VoiceActivityDetector::Impl { @@ -153,9 +166,9 @@ class VoiceActivityDetector::Impl {
153 CircularBuffer buffer_; 166 CircularBuffer buffer_;
154 std::vector<float> last_; 167 std::vector<float> last_;
155 168
156 - int max_utterance_length_ = 16000 * 20; // in samples 169 + int max_utterance_length_ = -1; // in samples
157 float new_min_silence_duration_s_ = 0.1; 170 float new_min_silence_duration_s_ = 0.1;
158 - float new_threshold_ = 1.10; 171 + float new_threshold_ = 0.90;
159 172
160 int32_t start_ = -1; 173 int32_t start_ = -1;
161 }; 174 };
@@ -17,7 +17,8 @@ void PybindSileroVadModelConfig(py::module *m) { @@ -17,7 +17,8 @@ void PybindSileroVadModelConfig(py::module *m) {
17 .def(py::init<>()) 17 .def(py::init<>())
18 .def(py::init([](const std::string &model, float threshold, 18 .def(py::init([](const std::string &model, float threshold,
19 float min_silence_duration, float min_speech_duration, 19 float min_silence_duration, float min_speech_duration,
20 - int32_t window_size) -> std::unique_ptr<PyClass> { 20 + int32_t window_size,
  21 + float max_speech_duration) -> std::unique_ptr<PyClass> {
21 auto ans = std::make_unique<PyClass>(); 22 auto ans = std::make_unique<PyClass>();
22 23
23 ans->model = model; 24 ans->model = model;
@@ -25,17 +26,20 @@ void PybindSileroVadModelConfig(py::module *m) { @@ -25,17 +26,20 @@ void PybindSileroVadModelConfig(py::module *m) {
25 ans->min_silence_duration = min_silence_duration; 26 ans->min_silence_duration = min_silence_duration;
26 ans->min_speech_duration = min_speech_duration; 27 ans->min_speech_duration = min_speech_duration;
27 ans->window_size = window_size; 28 ans->window_size = window_size;
  29 + ans->max_speech_duration = max_speech_duration;
28 30
29 return ans; 31 return ans;
30 }), 32 }),
31 py::arg("model"), py::arg("threshold") = 0.5, 33 py::arg("model"), py::arg("threshold") = 0.5,
32 py::arg("min_silence_duration") = 0.5, 34 py::arg("min_silence_duration") = 0.5,
33 - py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 512) 35 + py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 512,
  36 + py::arg("max_speech_duration") = 20)
34 .def_readwrite("model", &PyClass::model) 37 .def_readwrite("model", &PyClass::model)
35 .def_readwrite("threshold", &PyClass::threshold) 38 .def_readwrite("threshold", &PyClass::threshold)
36 .def_readwrite("min_silence_duration", &PyClass::min_silence_duration) 39 .def_readwrite("min_silence_duration", &PyClass::min_silence_duration)
37 .def_readwrite("min_speech_duration", &PyClass::min_speech_duration) 40 .def_readwrite("min_speech_duration", &PyClass::min_speech_duration)
38 .def_readwrite("window_size", &PyClass::window_size) 41 .def_readwrite("window_size", &PyClass::window_size)
  42 + .def_readwrite("max_speech_duration", &PyClass::max_speech_duration)
39 .def("__str__", &PyClass::ToString) 43 .def("__str__", &PyClass::ToString)
40 .def("validate", &PyClass::Validate); 44 .def("validate", &PyClass::Validate);
41 } 45 }