Fangjun Kuang
Committed by GitHub

Support silero_vad version 5 (#1064)

@@ -8,7 +8,7 @@ project(sherpa-onnx) @@ -8,7 +8,7 @@ project(sherpa-onnx)
8 # ./nodejs-addon-examples 8 # ./nodejs-addon-examples
9 # ./dart-api-examples/ 9 # ./dart-api-examples/
10 # ./sherpa-onnx/flutter/CHANGELOG.md 10 # ./sherpa-onnx/flutter/CHANGELOG.md
11 -set(SHERPA_ONNX_VERSION "1.10.5") 11 +set(SHERPA_ONNX_VERSION "1.10.6")
12 12
13 # Disable warning about 13 # Disable warning about
14 # 14 #
1 { 1 {
2 "dependencies": { 2 "dependencies": {
3 - "sherpa-onnx-node": "^1.10.3" 3 + "sherpa-onnx-node": "^1.10.6"
4 } 4 }
5 } 5 }
@@ -61,25 +61,11 @@ class SileroVadModel::Impl { @@ -61,25 +61,11 @@ class SileroVadModel::Impl {
61 #endif 61 #endif
62 62
63 void Reset() { 63 void Reset() {
64 - // 2 - number of LSTM layer  
65 - // 1 - batch size  
66 - // 64 - hidden dim  
67 - std::array<int64_t, 3> shape{2, 1, 64};  
68 -  
69 - Ort::Value h =  
70 - Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());  
71 -  
72 - Ort::Value c =  
73 - Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());  
74 -  
75 - Fill<float>(&h, 0);  
76 - Fill<float>(&c, 0);  
77 -  
78 - states_.clear();  
79 -  
80 - states_.reserve(2);  
81 - states_.push_back(std::move(h));  
82 - states_.push_back(std::move(c)); 64 + if (is_v5_) {
  65 + ResetV5();
  66 + } else {
  67 + ResetV4();
  68 + }
83 69
84 triggered_ = false; 70 triggered_ = false;
85 current_sample_ = 0; 71 current_sample_ = 0;
@@ -94,31 +80,7 @@ class SileroVadModel::Impl { @@ -94,31 +80,7 @@ class SileroVadModel::Impl {
94 exit(-1); 80 exit(-1);
95 } 81 }
96 82
97 - auto memory_info =  
98 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
99 -  
100 - std::array<int64_t, 2> x_shape = {1, n};  
101 -  
102 - Ort::Value x =  
103 - Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n,  
104 - x_shape.data(), x_shape.size());  
105 -  
106 - int64_t sr_shape = 1;  
107 - Ort::Value sr =  
108 - Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);  
109 -  
110 - std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr),  
111 - std::move(states_[0]),  
112 - std::move(states_[1])};  
113 -  
114 - auto out =  
115 - sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),  
116 - output_names_ptr_.data(), output_names_ptr_.size());  
117 -  
118 - states_[0] = std::move(out[1]);  
119 - states_[1] = std::move(out[2]);  
120 -  
121 - float prob = out[0].GetTensorData<float>()[0]; 83 + float prob = Run(samples, n);
122 84
123 float threshold = config_.silero_vad.threshold; 85 float threshold = config_.silero_vad.threshold;
124 86
@@ -186,6 +148,8 @@ class SileroVadModel::Impl { @@ -186,6 +148,8 @@ class SileroVadModel::Impl {
186 148
187 int32_t WindowSize() const { return config_.silero_vad.window_size; } 149 int32_t WindowSize() const { return config_.silero_vad.window_size; }
188 150
  // Hop, in samples, between the starts of consecutive windows: the window
  // size minus window_shift_. window_shift_ is initialised to 0, so the hop
  // equals WindowSize() unless the v5 detection branch has set it to 64.
  // Despite its name, window_shift_ holds the amount subtracted from the
  // window size, not the hop itself.
  151 + int32_t WindowShift() const { return WindowSize() - window_shift_; }
  152 +
189 int32_t MinSilenceDurationSamples() const { return min_silence_samples_; } 153 int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }
190 154
191 int32_t MinSpeechDurationSamples() const { return min_speech_samples_; } 155 int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }
@@ -205,12 +169,76 @@ class SileroVadModel::Impl { @@ -205,12 +169,76 @@ class SileroVadModel::Impl {
205 169
206 GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); 170 GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
207 GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); 171 GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  172 +
  173 + if (input_names_.size() == 4 && output_names_.size() == 3) {
  174 + is_v5_ = false;
  175 + } else if (input_names_.size() == 3 && output_names_.size() == 2) {
  176 + is_v5_ = true;
  177 +
  178 + // Overlap between consecutive windows (subtracted from WindowSize()
  179 + // in WindowShift()): 64 samples for 16kHz, 32 for 8kHz
  180 + window_shift_ = 64;
  181 +
  182 + if (WindowSize() != 512) {
  183 + SHERPA_ONNX_LOGE(
  184 + "For silero_vad v5, we require window_size to be 512 for 16kHz");
  185 + exit(-1);
  186 + }
  187 + } else {
  188 + SHERPA_ONNX_LOGE("Unsupported silero vad model");
  189 + exit(-1);
  190 + }
  191 +
208 Check(); 192 Check();
209 193
210 Reset(); 194 Reset();
211 } 195 }
212 196
// Resets the recurrent state for a v5 model: states_ is emptied and then
// holds exactly one zero-filled float tensor of shape {2, 1, 128}.
213 - void Check() { 197 + void ResetV5() {
  198 + // 2 - number of LSTM layers
  199 + // 1 - batch size
  200 + // 128 - hidden dim
  201 + std::array<int64_t, 3> shape{2, 1, 128};
  202 +
  203 + Ort::Value s =
  204 + Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
  205 +
  206 + Fill<float>(&s, 0);  // start from an all-zero state
  207 + states_.clear();
  208 + states_.push_back(std::move(s));
  209 + }
  210 +
  // Resets the recurrent state for a v4 model: states_ is emptied and then
  // holds two zero-filled float tensors of shape {2, 1, 64}, h at index 0
  // and c at index 1.
  211 + void ResetV4() {
  212 + // 2 - number of LSTM layers
  213 + // 1 - batch size
  214 + // 64 - hidden dim
  215 + std::array<int64_t, 3> shape{2, 1, 64};
  216 +
  217 + Ort::Value h =
  218 + Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
  219 +
  220 + Ort::Value c =
  221 + Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
  222 +
  223 + Fill<float>(&h, 0);
  224 + Fill<float>(&c, 0);
  225 +
  226 + states_.clear();
  227 +
  228 + states_.reserve(2);
  229 + states_.push_back(std::move(h));  // states_[0] == h
  230 + states_.push_back(std::move(c));  // states_[1] == c
  231 + }
  232 +
  // Dispatches to CheckV5() or CheckV4() depending on is_v5_.
  233 + void Check() const {
  234 + if (is_v5_) {
  235 + CheckV5();
  236 + } else {
  237 + CheckV4();
  238 + }
  239 + }
  240 +
  241 + void CheckV4() const {
214 if (input_names_.size() != 4) { 242 if (input_names_.size() != 4) {
215 SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d", 243 SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d",
216 static_cast<int32_t>(input_names_.size())); 244 static_cast<int32_t>(input_names_.size()));
@@ -262,6 +290,114 @@ class SileroVadModel::Impl { @@ -262,6 +290,114 @@ class SileroVadModel::Impl {
262 } 290 }
263 } 291 }
264 292
  // Verifies that the loaded model has the v5 signature: exactly three
  // inputs named "input", "state", "sr" (in that order) and exactly two
  // outputs named "output", "stateN" (in that order). On the first mismatch
  // it logs the problem and terminates the process with exit(-1).
  293 + void CheckV5() const {
  294 + if (input_names_.size() != 3) {
  295 + SHERPA_ONNX_LOGE("Expect 3 inputs. Given: %d",
  296 + static_cast<int32_t>(input_names_.size()));
  297 + exit(-1);
  298 + }
  299 +
  300 + if (input_names_[0] != "input") {
  301 + SHERPA_ONNX_LOGE("Input[0]: %s. Expected: input",
  302 + input_names_[0].c_str());
  303 + exit(-1);
  304 + }
  305 +
  306 + if (input_names_[1] != "state") {
  307 + SHERPA_ONNX_LOGE("Input[1]: %s. Expected: state",
  308 + input_names_[1].c_str());
  309 + exit(-1);
  310 + }
  311 +
  312 + if (input_names_[2] != "sr") {
  313 + SHERPA_ONNX_LOGE("Input[2]: %s. Expected: sr", input_names_[2].c_str());
  314 + exit(-1);
  315 + }
  316 +
  317 + // Now for outputs
  318 + if (output_names_.size() != 2) {
  319 + SHERPA_ONNX_LOGE("Expect 2 outputs. Given: %d",
  320 + static_cast<int32_t>(output_names_.size()));
  321 + exit(-1);
  322 + }
  323 +
  324 + if (output_names_[0] != "output") {
  325 + SHERPA_ONNX_LOGE("Output[0]: %s. Expected: output",
  326 + output_names_[0].c_str());
  327 + exit(-1);
  328 + }
  329 +
  330 + if (output_names_[1] != "stateN") {
  331 + SHERPA_ONNX_LOGE("Output[1]: %s. Expected: stateN",
  332 + output_names_[1].c_str());
  333 + exit(-1);
  334 + }
  335 + }
  336 +
  // Runs the model on n samples and returns the value produced by the
  // version-specific runner (RunV5() when is_v5_ is set, RunV4() otherwise).
  337 + float Run(const float *samples, int32_t n) {
  338 + if (is_v5_) {
  339 + return RunV5(samples, n);
  340 + } else {
  341 + return RunV4(samples, n);
  342 + }
  343 + }
  344 +
  // Runs one v5 inference step.
  //
  // Feeds three inputs in the order x, state, sr: x is a {1, n} tensor over
  // `samples`, state is states_[0], and sr is a 1-element tensor over
  // sample_rate_. After the call, states_[0] is replaced by the second model
  // output and the first float of the first output is returned.
  //
  // NOTE(review): states_[0] is moved into the input array before
  // sess_->Run(); if Run() were to throw, states_[0] would stay moved-from
  // until the next Reset() - confirm this is acceptable.
  345 + float RunV5(const float *samples, int32_t n) {
  346 + auto memory_info =
  347 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  348 +
  349 + std::array<int64_t, 2> x_shape = {1, n};
  350 +
  351 + Ort::Value x =
  352 + Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n,
  353 + x_shape.data(), x_shape.size());
  354 +
  355 + int64_t sr_shape = 1;
  356 + Ort::Value sr =
  357 + Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
  358 +
  359 + std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states_[0]),
  360 + std::move(sr)};
  361 +
  362 + auto out =
  363 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  364 + output_names_ptr_.data(), output_names_ptr_.size());
  365 +
  366 + states_[0] = std::move(out[1]);  // carry the updated state forward
  367 +
  368 + float prob = out[0].GetTensorData<float>()[0];
  369 + return prob;
  370 + }
  371 +
  // Runs one v4 inference step.
  //
  // Feeds four inputs in the order x, sr, states_[0], states_[1]: x is a
  // {1, n} tensor over `samples` and sr is a 1-element tensor over
  // sample_rate_. After the call, states_[0] and states_[1] are replaced by
  // the second and third model outputs, and the first float of the first
  // output is returned.
  //
  // NOTE(review): both state tensors are moved into the input array before
  // sess_->Run(); if Run() were to throw, they would stay moved-from until
  // the next Reset() - confirm this is acceptable.
  372 + float RunV4(const float *samples, int32_t n) {
  373 + auto memory_info =
  374 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  375 +
  376 + std::array<int64_t, 2> x_shape = {1, n};
  377 +
  378 + Ort::Value x =
  379 + Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n,
  380 + x_shape.data(), x_shape.size());
  381 +
  382 + int64_t sr_shape = 1;
  383 + Ort::Value sr =
  384 + Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
  385 +
  386 + std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr),
  387 + std::move(states_[0]),
  388 + std::move(states_[1])};
  389 +
  390 + auto out =
  391 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  392 + output_names_ptr_.data(), output_names_ptr_.size());
  393 +
  394 + states_[0] = std::move(out[1]);  // carry the updated states forward
  395 + states_[1] = std::move(out[2]);
  396 +
  397 + float prob = out[0].GetTensorData<float>()[0];
  398 + return prob;
  399 + }
  400 +
265 private: 401 private:
266 VadModelConfig config_; 402 VadModelConfig config_;
267 403
@@ -286,6 +422,10 @@ class SileroVadModel::Impl { @@ -286,6 +422,10 @@ class SileroVadModel::Impl {
286 int32_t current_sample_ = 0; 422 int32_t current_sample_ = 0;
287 int32_t temp_start_ = 0; 423 int32_t temp_start_ = 0;
288 int32_t temp_end_ = 0; 424 int32_t temp_end_ = 0;
  425 +
  426 + int32_t window_shift_ = 0;
  427 +
  428 + bool is_v5_ = false;
289 }; 429 };
290 430
291 SileroVadModel::SileroVadModel(const VadModelConfig &config) 431 SileroVadModel::SileroVadModel(const VadModelConfig &config)
@@ -306,6 +446,8 @@ bool SileroVadModel::IsSpeech(const float *samples, int32_t n) { @@ -306,6 +446,8 @@ bool SileroVadModel::IsSpeech(const float *samples, int32_t n) {
306 446
307 int32_t SileroVadModel::WindowSize() const { return impl_->WindowSize(); } 447 int32_t SileroVadModel::WindowSize() const { return impl_->WindowSize(); }
308 448
  // Forwards to the implementation object's WindowShift().
  449 +int32_t SileroVadModel::WindowShift() const { return impl_->WindowShift(); }
  450 +
309 int32_t SileroVadModel::MinSilenceDurationSamples() const { 451 int32_t SileroVadModel::MinSilenceDurationSamples() const {
310 return impl_->MinSilenceDurationSamples(); 452 return impl_->MinSilenceDurationSamples();
311 } 453 }
@@ -39,6 +39,11 @@ class SileroVadModel : public VadModel { @@ -39,6 +39,11 @@ class SileroVadModel : public VadModel {
39 39
40 int32_t WindowSize() const override; 40 int32_t WindowSize() const override;
41 41
  42 + // Number of samples to advance between consecutive windows.
  43 + // V4: WindowSize(). V5: WindowSize()-64 for 16kHz (WindowSize()-32 is
  44 + // intended for 8kHz; NOTE(review): the impl currently always uses 64).
  45 + int32_t WindowShift() const override;
  46 +
42 int32_t MinSilenceDurationSamples() const override; 47 int32_t MinSilenceDurationSamples() const override;
43 int32_t MinSpeechDurationSamples() const override; 48 int32_t MinSpeechDurationSamples() const override;
44 49
@@ -40,6 +40,8 @@ class VadModel { @@ -40,6 +40,8 @@ class VadModel {
40 40
41 virtual int32_t WindowSize() const = 0; 41 virtual int32_t WindowSize() const = 0;
42 42
  43 + virtual int32_t WindowShift() const = 0;
  44 +
43 virtual int32_t MinSilenceDurationSamples() const = 0; 45 virtual int32_t MinSilenceDurationSamples() const = 0;
44 virtual int32_t MinSpeechDurationSamples() const = 0; 46 virtual int32_t MinSpeechDurationSamples() const = 0;
45 virtual void SetMinSilenceDuration(float s) = 0; 47 virtual void SetMinSilenceDuration(float s) = 0;
@@ -38,16 +38,20 @@ class VoiceActivityDetector::Impl { @@ -38,16 +38,20 @@ class VoiceActivityDetector::Impl {
38 } 38 }
39 39
40 int32_t window_size = model_->WindowSize(); 40 int32_t window_size = model_->WindowSize();
  41 + int32_t window_shift = model_->WindowShift();
41 42
42 // note n is usually window_size and there is no need to use 43 // note n is usually window_size and there is no need to use
43 // an extra buffer here 44 // an extra buffer here
44 last_.insert(last_.end(), samples, samples + n); 45 last_.insert(last_.end(), samples, samples + n);
45 - int32_t k = static_cast<int32_t>(last_.size()) / window_size; 46 +
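  47 + // Note: For v4, window_shift == window_size. NOTE(review): when last_.size() < window_size, the truncating division below can still give k == 1 (e.g. 100 samples, window 512), so the loop would read past last_ - confirm a guard exists or return early.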
  48 + int32_t k =
  49 + (static_cast<int32_t>(last_.size()) - window_size) / window_shift + 1;
46 const float *p = last_.data(); 50 const float *p = last_.data();
47 bool is_speech = false; 51 bool is_speech = false;
48 52
49 - for (int32_t i = 0; i != k; ++i, p += window_size) {  
50 - buffer_.Push(p, window_size); 53 + for (int32_t i = 0; i != k; ++i, p += window_shift) {
  54 + buffer_.Push(p, window_shift);
51 // NOTE(fangjun): Please don't use a very large n. 55 // NOTE(fangjun): Please don't use a very large n.
52 bool this_window_is_speech = model_->IsSpeech(p, window_size); 56 bool this_window_is_speech = model_->IsSpeech(p, window_size);
53 is_speech = is_speech || this_window_is_speech; 57 is_speech = is_speech || this_window_is_speech;