Committed by GitHub
Expose JNI to compute probability of chunk in VAD (#2433)
Showing 10 changed files with 98 additions and 41 deletions.
@@ -69,6 +69,14 @@ class SileroVadModel::Impl {
     min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
   }
 
+  float Run(const float *samples, int32_t n) {
+    if (is_v5_) {
+      return RunV5(samples, n);
+    } else {
+      return RunV4(samples, n);
+    }
+  }
+
   void Reset() {
     if (is_v5_) {
       ResetV5();
@@ -361,14 +369,6 @@ class SileroVadModel::Impl {
     }
   }
 
-  float Run(const float *samples, int32_t n) {
-    if (is_v5_) {
-      return RunV5(samples, n);
-    } else {
-      return RunV4(samples, n);
-    }
-  }
-
   float RunV5(const float *samples, int32_t n) {
     auto memory_info =
         Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -496,6 +496,10 @@ void SileroVadModel::SetThreshold(float threshold) {
   impl_->SetThreshold(threshold);
 }
 
+float SileroVadModel::Compute(const float *samples, int32_t n) {
+  return impl_->Run(samples, n);
+}
+
 #if __ANDROID_API__ >= 9
 template SileroVadModel::SileroVadModel(AAssetManager *mgr,
                                         const VadModelConfig &config);
@@ -31,6 +31,8 @@ class SileroVadModel : public VadModel {
    */
   bool IsSpeech(const float *samples, int32_t n) override;
 
+  float Compute(const float *samples, int32_t n) override;
+
   // For silero vad V4, it is WindowShift().
   // For silero vad V5, it is WindowShift()+64 for 16kHz and
   // WindowShift()+32 for 8kHz
@@ -56,6 +56,38 @@ class TenVadModel::Impl {
     Init(buf.data(), buf.size());
   }
 
+  float Run(const float *samples, int32_t n) {
+    ComputeFeatures(samples, n);
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 3> x_shape = {1, 3, 41};
+
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, last_features_.data(),
+                                            last_features_.size(),
+                                            x_shape.data(), x_shape.size());
+
+    std::vector<Ort::Value> inputs;
+    inputs.reserve(input_names_.size());
+
+    inputs.push_back(std::move(x));
+    for (auto &s : states_) {
+      inputs.push_back(std::move(s));
+    }
+
+    auto out =
+        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+                   output_names_ptr_.data(), output_names_ptr_.size());
+
+    for (int32_t i = 1; i != static_cast<int32_t>(output_names_.size()); ++i) {
+      states_[i - 1] = std::move(out[i]);
+    }
+
+    float prob = out[0].GetTensorData<float>()[0];
+
+    return prob;
+  }
   void Reset() {
     triggered_ = false;
     current_sample_ = 0;
@@ -363,39 +395,6 @@ class TenVadModel::Impl {
               last_features_.begin() + 2 * features_.size());
   }
 
-  float Run(const float *samples, int32_t n) {
-    ComputeFeatures(samples, n);
-
-    auto memory_info =
-        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-
-    std::array<int64_t, 3> x_shape = {1, 3, 41};
-
-    Ort::Value x = Ort::Value::CreateTensor(memory_info, last_features_.data(),
-                                            last_features_.size(),
-                                            x_shape.data(), x_shape.size());
-
-    std::vector<Ort::Value> inputs;
-    inputs.reserve(input_names_.size());
-
-    inputs.push_back(std::move(x));
-    for (auto &s : states_) {
-      inputs.push_back(std::move(s));
-    }
-
-    auto out =
-        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
-                   output_names_ptr_.data(), output_names_ptr_.size());
-
-    for (int32_t i = 1; i != static_cast<int32_t>(output_names_.size()); ++i) {
-      states_[i - 1] = std::move(out[i]);
-    }
-
-    float prob = out[0].GetTensorData<float>()[0];
-
-    return prob;
-  }
-
  private:
   VadModelConfig config_;
   knf::Rfft rfft_;
@@ -469,6 +468,10 @@ void TenVadModel::SetThreshold(float threshold) {
   impl_->SetThreshold(threshold);
 }
 
+float TenVadModel::Compute(const float *samples, int32_t n) {
+  return impl_->Run(samples, n);
+}
+
 #if __ANDROID_API__ >= 9
 template TenVadModel::TenVadModel(AAssetManager *mgr,
                                   const VadModelConfig &config);
@@ -31,6 +31,8 @@ class TenVadModel : public VadModel {
    */
   bool IsSpeech(const float *samples, int32_t n) override;
 
+  float Compute(const float *samples, int32_t n) override;
+
   // 256 or 160
   int32_t WindowSize() const override;
 
@@ -32,6 +32,8 @@ class VadModel {
    */
   virtual bool IsSpeech(const float *samples, int32_t n) = 0;
 
+  virtual float Compute(const float *samples, int32_t n) = 0;
+
   virtual int32_t WindowSize() const = 0;
 
   virtual int32_t WindowShift() const = 0;
@@ -41,6 +41,10 @@ class VoiceActivityDetector::Impl {
     Init();
   }
 
+  float Compute(const float *samples, int32_t n) {
+    return model_->Compute(samples, n);
+  }
+
   void AcceptWaveform(const float *samples, int32_t n) {
     if (buffer_.Size() > max_utterance_length_) {
       model_->SetMinSilenceDuration(new_min_silence_duration_s_);
@@ -256,6 +260,10 @@ const VadModelConfig &VoiceActivityDetector::GetConfig() const {
   return impl_->GetConfig();
 }
 
+float VoiceActivityDetector::Compute(const float *samples, int32_t n) {
+  return impl_->Compute(samples, n);
+}
+
 #if __ANDROID_API__ >= 9
 template VoiceActivityDetector::VoiceActivityDetector(
     AAssetManager *mgr, const VadModelConfig &config,
@@ -28,6 +28,8 @@ class VoiceActivityDetector {
   ~VoiceActivityDetector();
 
   void AcceptWaveform(const float *samples, int32_t n);
+  float Compute(const float *samples, int32_t n);
+
   bool Empty() const;
   void Pop();
   void Clear();
@@ -27,6 +27,10 @@ public class Vad {
     acceptWaveform(this.ptr, samples);
   }
 
+  public float compute(float[] samples) {
+    return compute(this.ptr, samples);
+  }
+
   public boolean empty() {
     return empty(this.ptr);
   }
@@ -65,6 +69,8 @@ public class Vad {
 
   private native void acceptWaveform(long ptr, float[] samples);
 
+  private native float compute(long ptr, float[] samples);
+
   private native boolean empty(long ptr);
 
   private native void pop(long ptr);
@@ -227,3 +227,26 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_flush(JNIEnv * /*env*/,
   auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
   model->Flush();
 }
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jfloat JNICALL Java_com_k2fsa_sherpa_onnx_Vad_compute(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
+  return SafeJNI(
+      env, "Vad_compute",
+      [&]() -> jfloat {
+        if (!ValidatePointer(env, ptr, "Vad_compute",
+                             "VoiceActivityDetector pointer is null.")) {
+          return -1.0f;
+        }
+        auto vad = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
+        jfloat *p = env->GetFloatArrayElements(samples, nullptr);
+        jsize n = env->GetArrayLength(samples);
+
+        float score = vad->Compute(p, n);
+
+        env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
+
+        return static_cast<jfloat>(score);
+      },
+      -1.0f);
+}
@@ -55,6 +55,9 @@ class Vad(
 
     fun release() = finalize()
 
+    fun compute(samples: FloatArray): Float = compute(ptr, samples)
+
+
     fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples)
 
     fun empty(): Boolean = empty(ptr)
@@ -85,6 +88,8 @@ class Vad(
     ): Long
 
     private external fun acceptWaveform(ptr: Long, samples: FloatArray)
+    private external fun compute(ptr: Long, samples: FloatArray): Float
+
     private external fun empty(ptr: Long): Boolean
     private external fun pop(ptr: Long)
     private external fun clear(ptr: Long)
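
As a usage note, the Kotlin binding above now exposes compute(samples) next to acceptWaveform(samples). The sketch below is not part of the commit; it shows one plausible way to read per-chunk speech probabilities from Kotlin. The helper name printChunkProbabilities, the 512-sample chunk size, and the assumption that a Vad instance was already constructed with a suitable model config are all illustrative; the chunk size must match the window size of the VAD model actually loaded (see the WindowSize() comments in the C++ headers above).

    // Illustrative sketch only (not from this commit).
    // Assumes `vad` was constructed elsewhere and `samples` holds 16 kHz mono
    // audio in the range [-1, 1].
    fun printChunkProbabilities(vad: Vad, samples: FloatArray) {
        // Hypothetical chunk size; must match the loaded model's window size.
        val windowSize = 512
        var offset = 0
        while (offset + windowSize <= samples.size) {
            val chunk = samples.copyOfRange(offset, offset + windowSize)
            // compute() returns the raw speech probability for this chunk;
            // unlike acceptWaveform(), it does not drive segment detection.
            val prob = vad.compute(chunk)
            println("offset $offset: speech probability $prob")
            offset += windowSize
        }
    }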