Fangjun Kuang
Committed by GitHub

Support silero-vad v4 exported by k2-fsa (#2372)

@@ -180,7 +180,8 @@ class SileroVadModel::Impl { @@ -180,7 +180,8 @@ class SileroVadModel::Impl {
180 GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); 180 GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
181 GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); 181 GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
182 182
183 - if (input_names_.size() == 4 && output_names_.size() == 3) { 183 + if ((input_names_.size() == 4 && output_names_.size() == 3) ||
  184 + IsExportedByK2Fsa()) {
184 is_v5_ = false; 185 is_v5_ = false;
185 } else if (input_names_.size() == 3 && output_names_.size() == 2) { 186 } else if (input_names_.size() == 3 && output_names_.size() == 2) {
186 is_v5_ = true; 187 is_v5_ = true;
@@ -248,7 +249,23 @@ class SileroVadModel::Impl { @@ -248,7 +249,23 @@ class SileroVadModel::Impl {
248 } 249 }
249 } 250 }
250 251
  252 + bool IsExportedByK2Fsa() const {
  253 + if (input_names_.size() == 3 && input_names_[0] == "x" &&
  254 + input_names_[1] == "h" && input_names_[2] == "c" &&
  255 + output_names_.size() == 3 && output_names_[0] == "prob" &&
  256 + output_names_[1] == "new_h" && output_names_[2] == "new_c") {
  257 + // this version is exported and maintained by us (k2-fsa)
  258 + return true;
  259 + }
  260 +
  261 + return false;
  262 + }
  263 +
251 void CheckV4() const { 264 void CheckV4() const {
  265 + if (IsExportedByK2Fsa()) {
  266 + return;
  267 + }
  268 +
252 if (input_names_.size() != 4) { 269 if (input_names_.size() != 4) {
253 SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d", 270 SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d",
254 static_cast<int32_t>(input_names_.size())); 271 static_cast<int32_t>(input_names_.size()));
@@ -393,9 +410,15 @@ class SileroVadModel::Impl { @@ -393,9 +410,15 @@ class SileroVadModel::Impl {
393 Ort::Value sr = 410 Ort::Value sr =
394 Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); 411 Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
395 412
396 - std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr),  
397 - std::move(states_[0]),  
398 - std::move(states_[1])}; 413 + std::vector<Ort::Value> inputs;
  414 + inputs.reserve(input_names_.size());
  415 +
  416 + inputs.push_back(std::move(x));
  417 + if (input_names_.size() == 4) {
  418 + inputs.push_back(std::move(sr));
  419 + }
  420 + inputs.push_back(std::move(states_[0]));
  421 + inputs.push_back(std::move(states_[1]));
399 422
400 auto out = 423 auto out =
401 sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), 424 sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),