Committed by GitHub
Support silero-vad v4 exported by k2-fsa (#2372)
正在显示 1 个修改的文件，包含 27 行增加和 4 行删除
| @@ -180,7 +180,8 @@ class SileroVadModel::Impl { | @@ -180,7 +180,8 @@ class SileroVadModel::Impl { | ||
| 180 | GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | 180 | GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); |
| 181 | GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | 181 | GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); |
| 182 | 182 | ||
| 183 | - if (input_names_.size() == 4 && output_names_.size() == 3) { | 183 | + if ((input_names_.size() == 4 && output_names_.size() == 3) || |
| 184 | + IsExportedByK2Fsa()) { | ||
| 184 | is_v5_ = false; | 185 | is_v5_ = false; |
| 185 | } else if (input_names_.size() == 3 && output_names_.size() == 2) { | 186 | } else if (input_names_.size() == 3 && output_names_.size() == 2) { |
| 186 | is_v5_ = true; | 187 | is_v5_ = true; |
| @@ -248,7 +249,23 @@ class SileroVadModel::Impl { | @@ -248,7 +249,23 @@ class SileroVadModel::Impl { | ||
| 248 | } | 249 | } |
| 249 | } | 250 | } |
| 250 | 251 | ||
| 252 | + bool IsExportedByK2Fsa() const { | ||
| 253 | + if (input_names_.size() == 3 && input_names_[0] == "x" && | ||
| 254 | + input_names_[1] == "h" && input_names_[2] == "c" && | ||
| 255 | + output_names_.size() == 3 && output_names_[0] == "prob" && | ||
| 256 | + output_names_[1] == "new_h" && output_names_[2] == "new_c") { | ||
| 257 | + // this version is exported and maintained by us (k2-fsa) | ||
| 258 | + return true; | ||
| 259 | + } | ||
| 260 | + | ||
| 261 | + return false; | ||
| 262 | + } | ||
| 263 | + | ||
| 251 | void CheckV4() const { | 264 | void CheckV4() const { |
| 265 | + if (IsExportedByK2Fsa()) { | ||
| 266 | + return; | ||
| 267 | + } | ||
| 268 | + | ||
| 252 | if (input_names_.size() != 4) { | 269 | if (input_names_.size() != 4) { |
| 253 | SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d", | 270 | SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d", |
| 254 | static_cast<int32_t>(input_names_.size())); | 271 | static_cast<int32_t>(input_names_.size())); |
| @@ -393,9 +410,15 @@ class SileroVadModel::Impl { | @@ -393,9 +410,15 @@ class SileroVadModel::Impl { | ||
| 393 | Ort::Value sr = | 410 | Ort::Value sr = |
| 394 | Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); | 411 | Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); |
| 395 | 412 | ||
| 396 | - std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr), | ||
| 397 | - std::move(states_[0]), | ||
| 398 | - std::move(states_[1])}; | 413 | + std::vector<Ort::Value> inputs; |
| 414 | + inputs.reserve(input_names_.size()); | ||
| 415 | + | ||
| 416 | + inputs.push_back(std::move(x)); | ||
| 417 | + if (input_names_.size() == 4) { | ||
| 418 | + inputs.push_back(std::move(sr)); | ||
| 419 | + } | ||
| 420 | + inputs.push_back(std::move(states_[0])); | ||
| 421 | + inputs.push_back(std::move(states_[1])); | ||
| 399 | 422 | ||
| 400 | auto out = | 423 | auto out = |
| 401 | sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | 424 | sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), |
-
请注册或登录后发表评论