Fangjun Kuang
Committed by GitHub

Support silero-vad v4 exported by k2-fsa (#2372)

... ... @@ -180,7 +180,8 @@ class SileroVadModel::Impl {
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
if (input_names_.size() == 4 && output_names_.size() == 3) {
if ((input_names_.size() == 4 && output_names_.size() == 3) ||
IsExportedByK2Fsa()) {
is_v5_ = false;
} else if (input_names_.size() == 3 && output_names_.size() == 2) {
is_v5_ = true;
... ... @@ -248,7 +249,23 @@ class SileroVadModel::Impl {
}
}
bool IsExportedByK2Fsa() const {
if (input_names_.size() == 3 && input_names_[0] == "x" &&
input_names_[1] == "h" && input_names_[2] == "c" &&
output_names_.size() == 3 && output_names_[0] == "prob" &&
output_names_[1] == "new_h" && output_names_[2] == "new_c") {
// this version is exported and maintained by us (k2-fsa)
return true;
}
return false;
}
void CheckV4() const {
if (IsExportedByK2Fsa()) {
return;
}
if (input_names_.size() != 4) {
SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d",
static_cast<int32_t>(input_names_.size()));
... ... @@ -393,9 +410,15 @@ class SileroVadModel::Impl {
Ort::Value sr =
Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr),
std::move(states_[0]),
std::move(states_[1])};
std::vector<Ort::Value> inputs;
inputs.reserve(input_names_.size());
inputs.push_back(std::move(x));
if (input_names_.size() == 4) {
inputs.push_back(std::move(sr));
}
inputs.push_back(std::move(states_[0]));
inputs.push_back(std::move(states_[1]));
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
... ...