正在显示 6 个修改的文件，包含 203 行增加和 50 行删除
| @@ -8,7 +8,7 @@ project(sherpa-onnx) | @@ -8,7 +8,7 @@ project(sherpa-onnx) | ||
| 8 | # ./nodejs-addon-examples | 8 | # ./nodejs-addon-examples |
| 9 | # ./dart-api-examples/ | 9 | # ./dart-api-examples/ |
| 10 | # ./sherpa-onnx/flutter/CHANGELOG.md | 10 | # ./sherpa-onnx/flutter/CHANGELOG.md |
| 11 | -set(SHERPA_ONNX_VERSION "1.10.5") | 11 | +set(SHERPA_ONNX_VERSION "1.10.6") |
| 12 | 12 | ||
| 13 | # Disable warning about | 13 | # Disable warning about |
| 14 | # | 14 | # |
| @@ -61,25 +61,11 @@ class SileroVadModel::Impl { | @@ -61,25 +61,11 @@ class SileroVadModel::Impl { | ||
| 61 | #endif | 61 | #endif |
| 62 | 62 | ||
| 63 | void Reset() { | 63 | void Reset() { |
| 64 | - // 2 - number of LSTM layer | ||
| 65 | - // 1 - batch size | ||
| 66 | - // 64 - hidden dim | ||
| 67 | - std::array<int64_t, 3> shape{2, 1, 64}; | ||
| 68 | - | ||
| 69 | - Ort::Value h = | ||
| 70 | - Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size()); | ||
| 71 | - | ||
| 72 | - Ort::Value c = | ||
| 73 | - Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size()); | ||
| 74 | - | ||
| 75 | - Fill<float>(&h, 0); | ||
| 76 | - Fill<float>(&c, 0); | ||
| 77 | - | ||
| 78 | - states_.clear(); | ||
| 79 | - | ||
| 80 | - states_.reserve(2); | ||
| 81 | - states_.push_back(std::move(h)); | ||
| 82 | - states_.push_back(std::move(c)); | 64 | + if (is_v5_) { |
| 65 | + ResetV5(); | ||
| 66 | + } else { | ||
| 67 | + ResetV4(); | ||
| 68 | + } | ||
| 83 | 69 | ||
| 84 | triggered_ = false; | 70 | triggered_ = false; |
| 85 | current_sample_ = 0; | 71 | current_sample_ = 0; |
| @@ -94,31 +80,7 @@ class SileroVadModel::Impl { | @@ -94,31 +80,7 @@ class SileroVadModel::Impl { | ||
| 94 | exit(-1); | 80 | exit(-1); |
| 95 | } | 81 | } |
| 96 | 82 | ||
| 97 | - auto memory_info = | ||
| 98 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 99 | - | ||
| 100 | - std::array<int64_t, 2> x_shape = {1, n}; | ||
| 101 | - | ||
| 102 | - Ort::Value x = | ||
| 103 | - Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n, | ||
| 104 | - x_shape.data(), x_shape.size()); | ||
| 105 | - | ||
| 106 | - int64_t sr_shape = 1; | ||
| 107 | - Ort::Value sr = | ||
| 108 | - Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); | ||
| 109 | - | ||
| 110 | - std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr), | ||
| 111 | - std::move(states_[0]), | ||
| 112 | - std::move(states_[1])}; | ||
| 113 | - | ||
| 114 | - auto out = | ||
| 115 | - sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 116 | - output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 117 | - | ||
| 118 | - states_[0] = std::move(out[1]); | ||
| 119 | - states_[1] = std::move(out[2]); | ||
| 120 | - | ||
| 121 | - float prob = out[0].GetTensorData<float>()[0]; | 83 | + float prob = Run(samples, n); |
| 122 | 84 | ||
| 123 | float threshold = config_.silero_vad.threshold; | 85 | float threshold = config_.silero_vad.threshold; |
| 124 | 86 | ||
| @@ -186,6 +148,8 @@ class SileroVadModel::Impl { | @@ -186,6 +148,8 @@ class SileroVadModel::Impl { | ||
| 186 | 148 | ||
| 187 | int32_t WindowSize() const { return config_.silero_vad.window_size; } | 149 | int32_t WindowSize() const { return config_.silero_vad.window_size; } |
| 188 | 150 | ||
| 151 | + int32_t WindowShift() const { return WindowSize() - window_shift_; } | ||
| 152 | + | ||
| 189 | int32_t MinSilenceDurationSamples() const { return min_silence_samples_; } | 153 | int32_t MinSilenceDurationSamples() const { return min_silence_samples_; } |
| 190 | 154 | ||
| 191 | int32_t MinSpeechDurationSamples() const { return min_speech_samples_; } | 155 | int32_t MinSpeechDurationSamples() const { return min_speech_samples_; } |
| @@ -205,12 +169,76 @@ class SileroVadModel::Impl { | @@ -205,12 +169,76 @@ class SileroVadModel::Impl { | ||
| 205 | 169 | ||
| 206 | GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | 170 | GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); |
| 207 | GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | 171 | GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); |
| 172 | + | ||
| 173 | + if (input_names_.size() == 4 && output_names_.size() == 3) { | ||
| 174 | + is_v5_ = false; | ||
| 175 | + } else if (input_names_.size() == 3 && output_names_.size() == 2) { | ||
| 176 | + is_v5_ = true; | ||
| 177 | + | ||
| 178 | + // 64 for 16kHz | ||
| 179 | + // 32 for 8kHz | ||
| 180 | + window_shift_ = 64; | ||
| 181 | + | ||
| 182 | + if (WindowSize() != 512) { | ||
| 183 | + SHERPA_ONNX_LOGE( | ||
| 184 | + "For silero_vad v5, we require window_size to be 512 for 16kHz"); | ||
| 185 | + exit(-1); | ||
| 186 | + } | ||
| 187 | + } else { | ||
| 188 | + SHERPA_ONNX_LOGE("Unsupported silero vad model"); | ||
| 189 | + exit(-1); | ||
| 190 | + } | ||
| 191 | + | ||
| 208 | Check(); | 192 | Check(); |
| 209 | 193 | ||
| 210 | Reset(); | 194 | Reset(); |
| 211 | } | 195 | } |
| 212 | 196 | ||
| 213 | - void Check() { | 197 | + void ResetV5() { |
| 198 | + // 2 - number of LSTM layer | ||
| 199 | + // 1 - batch size | ||
| 200 | + // 128 - hidden dim | ||
| 201 | + std::array<int64_t, 3> shape{2, 1, 128}; | ||
| 202 | + | ||
| 203 | + Ort::Value s = | ||
| 204 | + Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size()); | ||
| 205 | + | ||
| 206 | + Fill<float>(&s, 0); | ||
| 207 | + states_.clear(); | ||
| 208 | + states_.push_back(std::move(s)); | ||
| 209 | + } | ||
| 210 | + | ||
| 211 | + void ResetV4() { | ||
| 212 | + // 2 - number of LSTM layer | ||
| 213 | + // 1 - batch size | ||
| 214 | + // 64 - hidden dim | ||
| 215 | + std::array<int64_t, 3> shape{2, 1, 64}; | ||
| 216 | + | ||
| 217 | + Ort::Value h = | ||
| 218 | + Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size()); | ||
| 219 | + | ||
| 220 | + Ort::Value c = | ||
| 221 | + Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size()); | ||
| 222 | + | ||
| 223 | + Fill<float>(&h, 0); | ||
| 224 | + Fill<float>(&c, 0); | ||
| 225 | + | ||
| 226 | + states_.clear(); | ||
| 227 | + | ||
| 228 | + states_.reserve(2); | ||
| 229 | + states_.push_back(std::move(h)); | ||
| 230 | + states_.push_back(std::move(c)); | ||
| 231 | + } | ||
| 232 | + | ||
| 233 | + void Check() const { | ||
| 234 | + if (is_v5_) { | ||
| 235 | + CheckV5(); | ||
| 236 | + } else { | ||
| 237 | + CheckV4(); | ||
| 238 | + } | ||
| 239 | + } | ||
| 240 | + | ||
| 241 | + void CheckV4() const { | ||
| 214 | if (input_names_.size() != 4) { | 242 | if (input_names_.size() != 4) { |
| 215 | SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d", | 243 | SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d", |
| 216 | static_cast<int32_t>(input_names_.size())); | 244 | static_cast<int32_t>(input_names_.size())); |
| @@ -262,6 +290,114 @@ class SileroVadModel::Impl { | @@ -262,6 +290,114 @@ class SileroVadModel::Impl { | ||
| 262 | } | 290 | } |
| 263 | } | 291 | } |
| 264 | 292 | ||
| 293 | + void CheckV5() const { | ||
| 294 | + if (input_names_.size() != 3) { | ||
| 295 | + SHERPA_ONNX_LOGE("Expect 3 inputs. Given: %d", | ||
| 296 | + static_cast<int32_t>(input_names_.size())); | ||
| 297 | + exit(-1); | ||
| 298 | + } | ||
| 299 | + | ||
| 300 | + if (input_names_[0] != "input") { | ||
| 301 | + SHERPA_ONNX_LOGE("Input[0]: %s. Expected: input", | ||
| 302 | + input_names_[0].c_str()); | ||
| 303 | + exit(-1); | ||
| 304 | + } | ||
| 305 | + | ||
| 306 | + if (input_names_[1] != "state") { | ||
| 307 | + SHERPA_ONNX_LOGE("Input[1]: %s. Expected: state", | ||
| 308 | + input_names_[1].c_str()); | ||
| 309 | + exit(-1); | ||
| 310 | + } | ||
| 311 | + | ||
| 312 | + if (input_names_[2] != "sr") { | ||
| 313 | + SHERPA_ONNX_LOGE("Input[2]: %s. Expected: sr", input_names_[2].c_str()); | ||
| 314 | + exit(-1); | ||
| 315 | + } | ||
| 316 | + | ||
| 317 | + // Now for outputs | ||
| 318 | + if (output_names_.size() != 2) { | ||
| 319 | + SHERPA_ONNX_LOGE("Expect 2 outputs. Given: %d", | ||
| 320 | + static_cast<int32_t>(output_names_.size())); | ||
| 321 | + exit(-1); | ||
| 322 | + } | ||
| 323 | + | ||
| 324 | + if (output_names_[0] != "output") { | ||
| 325 | + SHERPA_ONNX_LOGE("Output[0]: %s. Expected: output", | ||
| 326 | + output_names_[0].c_str()); | ||
| 327 | + exit(-1); | ||
| 328 | + } | ||
| 329 | + | ||
| 330 | + if (output_names_[1] != "stateN") { | ||
| 331 | + SHERPA_ONNX_LOGE("Output[1]: %s. Expected: stateN", | ||
| 332 | + output_names_[1].c_str()); | ||
| 333 | + exit(-1); | ||
| 334 | + } | ||
| 335 | + } | ||
| 336 | + | ||
| 337 | + float Run(const float *samples, int32_t n) { | ||
| 338 | + if (is_v5_) { | ||
| 339 | + return RunV5(samples, n); | ||
| 340 | + } else { | ||
| 341 | + return RunV4(samples, n); | ||
| 342 | + } | ||
| 343 | + } | ||
| 344 | + | ||
| 345 | + float RunV5(const float *samples, int32_t n) { | ||
| 346 | + auto memory_info = | ||
| 347 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 348 | + | ||
| 349 | + std::array<int64_t, 2> x_shape = {1, n}; | ||
| 350 | + | ||
| 351 | + Ort::Value x = | ||
| 352 | + Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n, | ||
| 353 | + x_shape.data(), x_shape.size()); | ||
| 354 | + | ||
| 355 | + int64_t sr_shape = 1; | ||
| 356 | + Ort::Value sr = | ||
| 357 | + Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); | ||
| 358 | + | ||
| 359 | + std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states_[0]), | ||
| 360 | + std::move(sr)}; | ||
| 361 | + | ||
| 362 | + auto out = | ||
| 363 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 364 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 365 | + | ||
| 366 | + states_[0] = std::move(out[1]); | ||
| 367 | + | ||
| 368 | + float prob = out[0].GetTensorData<float>()[0]; | ||
| 369 | + return prob; | ||
| 370 | + } | ||
| 371 | + | ||
| 372 | + float RunV4(const float *samples, int32_t n) { | ||
| 373 | + auto memory_info = | ||
| 374 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 375 | + | ||
| 376 | + std::array<int64_t, 2> x_shape = {1, n}; | ||
| 377 | + | ||
| 378 | + Ort::Value x = | ||
| 379 | + Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n, | ||
| 380 | + x_shape.data(), x_shape.size()); | ||
| 381 | + | ||
| 382 | + int64_t sr_shape = 1; | ||
| 383 | + Ort::Value sr = | ||
| 384 | + Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); | ||
| 385 | + | ||
| 386 | + std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr), | ||
| 387 | + std::move(states_[0]), | ||
| 388 | + std::move(states_[1])}; | ||
| 389 | + | ||
| 390 | + auto out = | ||
| 391 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 392 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 393 | + | ||
| 394 | + states_[0] = std::move(out[1]); | ||
| 395 | + states_[1] = std::move(out[2]); | ||
| 396 | + | ||
| 397 | + float prob = out[0].GetTensorData<float>()[0]; | ||
| 398 | + return prob; | ||
| 399 | + } | ||
| 400 | + | ||
| 265 | private: | 401 | private: |
| 266 | VadModelConfig config_; | 402 | VadModelConfig config_; |
| 267 | 403 | ||
| @@ -286,6 +422,10 @@ class SileroVadModel::Impl { | @@ -286,6 +422,10 @@ class SileroVadModel::Impl { | ||
| 286 | int32_t current_sample_ = 0; | 422 | int32_t current_sample_ = 0; |
| 287 | int32_t temp_start_ = 0; | 423 | int32_t temp_start_ = 0; |
| 288 | int32_t temp_end_ = 0; | 424 | int32_t temp_end_ = 0; |
| 425 | + | ||
| 426 | + int32_t window_shift_ = 0; | ||
| 427 | + | ||
| 428 | + bool is_v5_ = false; | ||
| 289 | }; | 429 | }; |
| 290 | 430 | ||
| 291 | SileroVadModel::SileroVadModel(const VadModelConfig &config) | 431 | SileroVadModel::SileroVadModel(const VadModelConfig &config) |
| @@ -306,6 +446,8 @@ bool SileroVadModel::IsSpeech(const float *samples, int32_t n) { | @@ -306,6 +446,8 @@ bool SileroVadModel::IsSpeech(const float *samples, int32_t n) { | ||
| 306 | 446 | ||
| 307 | int32_t SileroVadModel::WindowSize() const { return impl_->WindowSize(); } | 447 | int32_t SileroVadModel::WindowSize() const { return impl_->WindowSize(); } |
| 308 | 448 | ||
| 449 | +int32_t SileroVadModel::WindowShift() const { return impl_->WindowShift(); } | ||
| 450 | + | ||
| 309 | int32_t SileroVadModel::MinSilenceDurationSamples() const { | 451 | int32_t SileroVadModel::MinSilenceDurationSamples() const { |
| 310 | return impl_->MinSilenceDurationSamples(); | 452 | return impl_->MinSilenceDurationSamples(); |
| 311 | } | 453 | } |
| @@ -39,6 +39,11 @@ class SileroVadModel : public VadModel { | @@ -39,6 +39,11 @@ class SileroVadModel : public VadModel { | ||
| 39 | 39 | ||
| 40 | int32_t WindowSize() const override; | 40 | int32_t WindowSize() const override; |
| 41 | 41 | ||
| 42 | + // For silero vad V4, it is WindowSize(). | ||
| 43 | + // For silero vad V5, it is WindowSize()-64 for 16kHz and | ||
| 44 | + // WindowSize()-32 for 8kHz | ||
| 45 | + int32_t WindowShift() const override; | ||
| 46 | + | ||
| 42 | int32_t MinSilenceDurationSamples() const override; | 47 | int32_t MinSilenceDurationSamples() const override; |
| 43 | int32_t MinSpeechDurationSamples() const override; | 48 | int32_t MinSpeechDurationSamples() const override; |
| 44 | 49 |
| @@ -40,6 +40,8 @@ class VadModel { | @@ -40,6 +40,8 @@ class VadModel { | ||
| 40 | 40 | ||
| 41 | virtual int32_t WindowSize() const = 0; | 41 | virtual int32_t WindowSize() const = 0; |
| 42 | 42 | ||
| 43 | + virtual int32_t WindowShift() const = 0; | ||
| 44 | + | ||
| 43 | virtual int32_t MinSilenceDurationSamples() const = 0; | 45 | virtual int32_t MinSilenceDurationSamples() const = 0; |
| 44 | virtual int32_t MinSpeechDurationSamples() const = 0; | 46 | virtual int32_t MinSpeechDurationSamples() const = 0; |
| 45 | virtual void SetMinSilenceDuration(float s) = 0; | 47 | virtual void SetMinSilenceDuration(float s) = 0; |
| @@ -38,16 +38,20 @@ class VoiceActivityDetector::Impl { | @@ -38,16 +38,20 @@ class VoiceActivityDetector::Impl { | ||
| 38 | } | 38 | } |
| 39 | 39 | ||
| 40 | int32_t window_size = model_->WindowSize(); | 40 | int32_t window_size = model_->WindowSize(); |
| 41 | + int32_t window_shift = model_->WindowShift(); | ||
| 41 | 42 | ||
| 42 | // note n is usually window_size and there is no need to use | 43 | // note n is usually window_size and there is no need to use |
| 43 | // an extra buffer here | 44 | // an extra buffer here |
| 44 | last_.insert(last_.end(), samples, samples + n); | 45 | last_.insert(last_.end(), samples, samples + n); |
| 45 | - int32_t k = static_cast<int32_t>(last_.size()) / window_size; | 46 | + |
| 47 | + // Note: For v4, window_shift == window_size | ||
| 48 | + int32_t k = | ||
| 49 | + (static_cast<int32_t>(last_.size()) - window_size) / window_shift + 1; | ||
| 46 | const float *p = last_.data(); | 50 | const float *p = last_.data(); |
| 47 | bool is_speech = false; | 51 | bool is_speech = false; |
| 48 | 52 | ||
| 49 | - for (int32_t i = 0; i != k; ++i, p += window_size) { | ||
| 50 | - buffer_.Push(p, window_size); | 53 | + for (int32_t i = 0; i != k; ++i, p += window_shift) { |
| 54 | + buffer_.Push(p, window_shift); | ||
| 51 | // NOTE(fangjun): Please don't use a very large n. | 55 | // NOTE(fangjun): Please don't use a very large n. |
| 52 | bool this_window_is_speech = model_->IsSpeech(p, window_size); | 56 | bool this_window_is_speech = model_->IsSpeech(p, window_size); |
| 53 | is_speech = is_speech || this_window_is_speech; | 57 | is_speech = is_speech || this_window_is_speech; |
-
请注册或登录后发表评论