Support specifying max speech duration for VAD. (#1348)
Committed by GitHub

Showing 5 changed files with 70 additions and 7 deletions.
```diff
@@ -406,7 +406,14 @@ def main():
 
     config = sherpa_onnx.VadModelConfig()
     config.silero_vad.model = args.silero_vad_model
-    config.silero_vad.min_silence_duration = 0.25
+    config.silero_vad.threshold = 0.5
+    config.silero_vad.min_silence_duration = 0.25  # seconds
+    config.silero_vad.min_speech_duration = 0.25  # seconds
+
+    # If the current segment is larger than this value, then it increases
+    # the threshold to 0.9 internally. After detecting this segment,
+    # it resets the threshold to its original value.
+    config.silero_vad.max_speech_duration = 5  # seconds
     config.sample_rate = args.sample_rate
 
     window_size = config.silero_vad.window_size
```
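The hunk above only fills in the config. The sketch below shows where the new field takes effect in a typical driver loop; it is a hedged illustration, assuming the `sherpa_onnx.VoiceActivityDetector` wrapper with `accept_waveform`/`empty`/`front`/`pop` behaves as in the repository's other VAD examples, and the model path and audio buffer are placeholders, not part of this diff.

```python
import numpy as np
import sherpa_onnx

sample_rate = 16000
samples = np.zeros(10 * sample_rate, dtype=np.float32)  # placeholder audio

config = sherpa_onnx.VadModelConfig()
config.silero_vad.model = "./silero_vad.onnx"  # placeholder model path
config.silero_vad.max_speech_duration = 5  # seconds
config.sample_rate = sample_rate

vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=30)

window_size = config.silero_vad.window_size
for i in range(0, len(samples), window_size):
    vad.accept_waveform(samples[i : i + window_size])
    while not vad.empty():
        segment = vad.front  # has .start (sample index) and .samples
        print(segment.start / sample_rate, len(segment.samples) / sample_rate)
        vad.pop()
```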
```diff
@@ -29,6 +29,12 @@ void SileroVadModelConfig::Register(ParseOptions *po) {
       "--silero-vad-min-speech-duration seconds before separating it");
 
   po->Register(
+      "silero-vad-max-speech-duration", &max_speech_duration,
+      "In seconds. If a speech segment is longer than this value, then we "
+      "increase the threshold to 0.9. After finishing detecting the segment, "
+      "the threshold value is reset to its original value.");
+
+  po->Register(
       "silero-vad-window-size", &window_size,
       "In samples. Audio chunks of --silero-vad-window-size samples are fed "
       "to the silero VAD model. WARNING! Silero VAD models were trained using "
```
```diff
@@ -63,6 +69,33 @@ bool SileroVadModelConfig::Validate() const {
     return false;
   }
 
+  if (min_silence_duration <= 0) {
+    SHERPA_ONNX_LOGE(
+        "Please use a larger value for --silero-vad-min-silence-duration. "
+        "Given: "
+        "%f",
+        min_silence_duration);
+    return false;
+  }
+
+  if (min_speech_duration <= 0) {
+    SHERPA_ONNX_LOGE(
+        "Please use a larger value for --silero-vad-min-speech-duration. "
+        "Given: "
+        "%f",
+        min_speech_duration);
+    return false;
+  }
+
+  if (max_speech_duration <= 0) {
+    SHERPA_ONNX_LOGE(
+        "Please use a larger value for --silero-vad-max-speech-duration. "
+        "Given: "
+        "%f",
+        max_speech_duration);
+    return false;
+  }
+
   return true;
 }
 
```
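The new checks follow the existing pattern: any non-positive duration aborts validation with a log message. A quick sketch of how this surfaces through the Python binding is shown below; the model path is a placeholder, and the model-file check earlier in `Validate()` still has to pass before the duration checks are reached.

```python
import sherpa_onnx

config = sherpa_onnx.VadModelConfig()
config.silero_vad.model = "./silero_vad.onnx"  # placeholder; file must exist
config.silero_vad.max_speech_duration = -1     # rejected by the new check

# Logs "Please use a larger value for --silero-vad-max-speech-duration. ..."
# and returns False.
print(config.silero_vad.validate())
```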
```diff
@@ -74,6 +107,7 @@ std::string SileroVadModelConfig::ToString() const {
   os << "threshold=" << threshold << ", ";
   os << "min_silence_duration=" << min_silence_duration << ", ";
   os << "min_speech_duration=" << min_speech_duration << ", ";
+  os << "max_speech_duration=" << max_speech_duration << ", ";
   os << "window_size=" << window_size << ")";
 
   return os.str();
```
```diff
@@ -27,6 +27,11 @@ struct SileroVadModelConfig {
   // 256, 512, 768 samples for 800 Hz
   int32_t window_size = 512;  // in samples
 
+  // If a speech segment is longer than this value, then we increase
+  // the threshold to 0.9. After finishing detecting the segment,
+  // the threshold value is reset to its original value.
+  float max_speech_duration = 20;  // in seconds
+
   SileroVadModelConfig() = default;
 
   void Register(ParseOptions *po);
```
```diff
@@ -18,14 +18,18 @@ class VoiceActivityDetector::Impl {
   explicit Impl(const VadModelConfig &config, float buffer_size_in_seconds = 60)
       : model_(VadModel::Create(config)),
        config_(config),
-        buffer_(buffer_size_in_seconds * config.sample_rate) {}
+        buffer_(buffer_size_in_seconds * config.sample_rate) {
+    Init();
+  }
 
 #if __ANDROID_API__ >= 9
   Impl(AAssetManager *mgr, const VadModelConfig &config,
        float buffer_size_in_seconds = 60)
       : model_(VadModel::Create(mgr, config)),
         config_(config),
-        buffer_(buffer_size_in_seconds * config.sample_rate) {}
+        buffer_(buffer_size_in_seconds * config.sample_rate) {
+    Init();
+  }
 #endif
 
   void AcceptWaveform(const float *samples, int32_t n) {
```
```diff
@@ -146,6 +150,15 @@ class VoiceActivityDetector::Impl {
   const VadModelConfig &GetConfig() const { return config_; }
 
  private:
+  void Init() {
+    // TODO(fangjun): Currently, we support only one vad model.
+    // If a new vad model is added, we need to change the place
+    // where max_speech_duration is placed.
+    max_utterance_length_ =
+        config_.sample_rate * config_.silero_vad.max_speech_duration;
+  }
+
+ private:
   std::queue<SpeechSegment> segments_;
 
   std::unique_ptr<VadModel> model_;
```
```diff
@@ -153,9 +166,9 @@ class VoiceActivityDetector::Impl {
   CircularBuffer buffer_;
   std::vector<float> last_;
 
-  int max_utterance_length_ = 16000 * 20;  // in samples
+  int max_utterance_length_ = -1;  // in samples
   float new_min_silence_duration_s_ = 0.1;
-  float new_threshold_ = 1.10;
+  float new_threshold_ = 0.90;
 
   int32_t start_ = -1;
 };
```
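`Init()` converts the configured duration from seconds into a sample count, replacing the previous hard-coded limit. The snippet below is only an illustration of that arithmetic, using the defaults visible in this diff (16 kHz from the old constant, 20 s from the new config field); with those values the new limit reproduces the old one.

```python
# Old limit: hard-coded 16000 * 20 samples.
# New limit: derived from the config when the Impl is constructed.
sample_rate = 16000           # as in the old constant
max_speech_duration = 20      # seconds, default in SileroVadModelConfig
max_utterance_length = sample_rate * max_speech_duration  # 320000 samples
```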
```diff
@@ -17,7 +17,8 @@ void PybindSileroVadModelConfig(py::module *m) {
       .def(py::init<>())
       .def(py::init([](const std::string &model, float threshold,
                        float min_silence_duration, float min_speech_duration,
-                       int32_t window_size) -> std::unique_ptr<PyClass> {
+                       int32_t window_size,
+                       float max_speech_duration) -> std::unique_ptr<PyClass> {
         auto ans = std::make_unique<PyClass>();
 
         ans->model = model;
@@ -25,17 +26,20 @@ void PybindSileroVadModelConfig(py::module *m) {
         ans->min_silence_duration = min_silence_duration;
         ans->min_speech_duration = min_speech_duration;
         ans->window_size = window_size;
+        ans->max_speech_duration = max_speech_duration;
 
         return ans;
       }),
       py::arg("model"), py::arg("threshold") = 0.5,
       py::arg("min_silence_duration") = 0.5,
-      py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 512)
+      py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 512,
+      py::arg("max_speech_duration") = 20)
       .def_readwrite("model", &PyClass::model)
       .def_readwrite("threshold", &PyClass::threshold)
       .def_readwrite("min_silence_duration", &PyClass::min_silence_duration)
       .def_readwrite("min_speech_duration", &PyClass::min_speech_duration)
       .def_readwrite("window_size", &PyClass::window_size)
+      .def_readwrite("max_speech_duration", &PyClass::max_speech_duration)
       .def("__str__", &PyClass::ToString)
       .def("validate", &PyClass::Validate);
 }
```
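Because `max_speech_duration` is appended after `window_size` with a default of 20 seconds, existing positional calls to the constructor keep working. The sketch below assumes `SileroVadModelConfig` is re-exported by the `sherpa_onnx` Python package and uses a placeholder model path.

```python
import sherpa_onnx

silero = sherpa_onnx.SileroVadModelConfig(
    model="./silero_vad.onnx",   # placeholder path
    threshold=0.5,
    min_silence_duration=0.25,
    min_speech_duration=0.25,
    window_size=512,
    max_speech_duration=5,       # new keyword argument
)
print(silero)  # __str__ now includes max_speech_duration
```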