Ming-Hsuan-Tu
Committed by GitHub

Expose JNI to compute probability of chunk in VAD (#2433)

@@ -69,6 +69,14 @@ class SileroVadModel::Impl {
69 min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration; 69 min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
70 } 70 }
71 71
  72 + float Run(const float *samples, int32_t n) {
  73 + if (is_v5_) {
  74 + return RunV5(samples, n);
  75 + } else {
  76 + return RunV4(samples, n);
  77 + }
  78 + }
  79 +
72 void Reset() { 80 void Reset() {
73 if (is_v5_) { 81 if (is_v5_) {
74 ResetV5(); 82 ResetV5();
@@ -361,14 +369,6 @@ class SileroVadModel::Impl {
361 } 369 }
362 } 370 }
363 371
364 - float Run(const float *samples, int32_t n) {  
365 - if (is_v5_) {  
366 - return RunV5(samples, n);  
367 - } else {  
368 - return RunV4(samples, n);  
369 - }  
370 - }  
371 -  
372 float RunV5(const float *samples, int32_t n) { 372 float RunV5(const float *samples, int32_t n) {
373 auto memory_info = 373 auto memory_info =
374 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 374 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -496,6 +496,10 @@ void SileroVadModel::SetThreshold(float threshold) {
496 impl_->SetThreshold(threshold); 496 impl_->SetThreshold(threshold);
497 } 497 }
498 498
  499 +float SileroVadModel::Compute(const float *samples, int32_t n) {
  500 + return impl_->Run(samples, n);
  501 +}
  502 +
499 #if __ANDROID_API__ >= 9 503 #if __ANDROID_API__ >= 9
500 template SileroVadModel::SileroVadModel(AAssetManager *mgr, 504 template SileroVadModel::SileroVadModel(AAssetManager *mgr,
501 const VadModelConfig &config); 505 const VadModelConfig &config);
@@ -31,6 +31,8 @@ class SileroVadModel : public VadModel {
31 */ 31 */
32 bool IsSpeech(const float *samples, int32_t n) override; 32 bool IsSpeech(const float *samples, int32_t n) override;
33 33
  34 + float Compute(const float *samples, int32_t n) override;
  35 +
34 // For silero vad V4, it is WindowShift(). 36 // For silero vad V4, it is WindowShift().
35 // For silero vad V5, it is WindowShift()+64 for 16kHz and 37 // For silero vad V5, it is WindowShift()+64 for 16kHz and
36 // WindowShift()+32 for 8kHz 38 // WindowShift()+32 for 8kHz
@@ -56,6 +56,38 @@ class TenVadModel::Impl {
56 Init(buf.data(), buf.size()); 56 Init(buf.data(), buf.size());
57 } 57 }
58 58
  59 + float Run(const float *samples, int32_t n) {
  60 + ComputeFeatures(samples, n);
  61 +
  62 + auto memory_info =
  63 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  64 +
  65 + std::array<int64_t, 3> x_shape = {1, 3, 41};
  66 +
  67 + Ort::Value x = Ort::Value::CreateTensor(memory_info, last_features_.data(),
  68 + last_features_.size(),
  69 + x_shape.data(), x_shape.size());
  70 +
  71 + std::vector<Ort::Value> inputs;
  72 + inputs.reserve(input_names_.size());
  73 +
  74 + inputs.push_back(std::move(x));
  75 + for (auto &s : states_) {
  76 + inputs.push_back(std::move(s));
  77 + }
  78 +
  79 + auto out =
  80 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  81 + output_names_ptr_.data(), output_names_ptr_.size());
  82 +
  83 + for (int32_t i = 1; i != static_cast<int32_t>(output_names_.size()); ++i) {
  84 + states_[i - 1] = std::move(out[i]);
  85 + }
  86 +
  87 + float prob = out[0].GetTensorData<float>()[0];
  88 +
  89 + return prob;
  90 + }
59 void Reset() { 91 void Reset() {
60 triggered_ = false; 92 triggered_ = false;
61 current_sample_ = 0; 93 current_sample_ = 0;
@@ -363,39 +395,6 @@ class TenVadModel::Impl {
363 last_features_.begin() + 2 * features_.size()); 395 last_features_.begin() + 2 * features_.size());
364 } 396 }
365 397
366 - float Run(const float *samples, int32_t n) {  
367 - ComputeFeatures(samples, n);  
368 -  
369 - auto memory_info =  
370 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
371 -  
372 - std::array<int64_t, 3> x_shape = {1, 3, 41};  
373 -  
374 - Ort::Value x = Ort::Value::CreateTensor(memory_info, last_features_.data(),  
375 - last_features_.size(),  
376 - x_shape.data(), x_shape.size());  
377 -  
378 - std::vector<Ort::Value> inputs;  
379 - inputs.reserve(input_names_.size());  
380 -  
381 - inputs.push_back(std::move(x));  
382 - for (auto &s : states_) {  
383 - inputs.push_back(std::move(s));  
384 - }  
385 -  
386 - auto out =  
387 - sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),  
388 - output_names_ptr_.data(), output_names_ptr_.size());  
389 -  
390 - for (int32_t i = 1; i != static_cast<int32_t>(output_names_.size()); ++i) {  
391 - states_[i - 1] = std::move(out[i]);  
392 - }  
393 -  
394 - float prob = out[0].GetTensorData<float>()[0];  
395 -  
396 - return prob;  
397 - }  
398 -  
399 private: 398 private:
400 VadModelConfig config_; 399 VadModelConfig config_;
401 knf::Rfft rfft_; 400 knf::Rfft rfft_;
@@ -469,6 +468,10 @@ void TenVadModel::SetThreshold(float threshold) {
469 impl_->SetThreshold(threshold); 468 impl_->SetThreshold(threshold);
470 } 469 }
471 470
  471 +float TenVadModel::Compute(const float *samples, int32_t n) {
  472 + return impl_->Run(samples, n);
  473 +}
  474 +
472 #if __ANDROID_API__ >= 9 475 #if __ANDROID_API__ >= 9
473 template TenVadModel::TenVadModel(AAssetManager *mgr, 476 template TenVadModel::TenVadModel(AAssetManager *mgr,
474 const VadModelConfig &config); 477 const VadModelConfig &config);
@@ -31,6 +31,8 @@ class TenVadModel : public VadModel {
31 */ 31 */
32 bool IsSpeech(const float *samples, int32_t n) override; 32 bool IsSpeech(const float *samples, int32_t n) override;
33 33
  34 + float Compute(const float *samples, int32_t n) override;
  35 +
34 // 256 or 160 36 // 256 or 160
35 int32_t WindowSize() const override; 37 int32_t WindowSize() const override;
36 38
@@ -32,6 +32,8 @@ class VadModel {
32 */ 32 */
33 virtual bool IsSpeech(const float *samples, int32_t n) = 0; 33 virtual bool IsSpeech(const float *samples, int32_t n) = 0;
34 34
  35 + virtual float Compute(const float *samples, int32_t n) = 0;
  36 +
35 virtual int32_t WindowSize() const = 0; 37 virtual int32_t WindowSize() const = 0;
36 38
37 virtual int32_t WindowShift() const = 0; 39 virtual int32_t WindowShift() const = 0;
@@ -41,6 +41,10 @@ class VoiceActivityDetector::Impl {
41 Init(); 41 Init();
42 } 42 }
43 43
  44 + float Compute(const float *samples, int32_t n) {
  45 + return model_->Compute(samples, n);
  46 + }
  47 +
44 void AcceptWaveform(const float *samples, int32_t n) { 48 void AcceptWaveform(const float *samples, int32_t n) {
45 if (buffer_.Size() > max_utterance_length_) { 49 if (buffer_.Size() > max_utterance_length_) {
46 model_->SetMinSilenceDuration(new_min_silence_duration_s_); 50 model_->SetMinSilenceDuration(new_min_silence_duration_s_);
@@ -256,6 +260,10 @@ const VadModelConfig &VoiceActivityDetector::GetConfig() const {
256 return impl_->GetConfig(); 260 return impl_->GetConfig();
257 } 261 }
258 262
  263 +float VoiceActivityDetector::Compute(const float *samples, int32_t n) {
  264 + return impl_->Compute(samples, n);
  265 +}
  266 +
259 #if __ANDROID_API__ >= 9 267 #if __ANDROID_API__ >= 9
260 template VoiceActivityDetector::VoiceActivityDetector( 268 template VoiceActivityDetector::VoiceActivityDetector(
261 AAssetManager *mgr, const VadModelConfig &config, 269 AAssetManager *mgr, const VadModelConfig &config,
@@ -28,6 +28,8 @@ class VoiceActivityDetector {
28 ~VoiceActivityDetector(); 28 ~VoiceActivityDetector();
29 29
30 void AcceptWaveform(const float *samples, int32_t n); 30 void AcceptWaveform(const float *samples, int32_t n);
  31 + float Compute(const float *samples, int32_t n);
  32 +
31 bool Empty() const; 33 bool Empty() const;
32 void Pop(); 34 void Pop();
33 void Clear(); 35 void Clear();
@@ -27,6 +27,10 @@ public class Vad {
27 acceptWaveform(this.ptr, samples); 27 acceptWaveform(this.ptr, samples);
28 } 28 }
29 29
  30 + public float compute(float[] samples) {
  31 + return compute(this.ptr, samples);
  32 + }
  33 +
30 public boolean empty() { 34 public boolean empty() {
31 return empty(this.ptr); 35 return empty(this.ptr);
32 } 36 }
@@ -65,6 +69,8 @@ public class Vad {
65 69
66 private native void acceptWaveform(long ptr, float[] samples); 70 private native void acceptWaveform(long ptr, float[] samples);
67 71
  72 + private native float compute(long ptr, float[] samples);
  73 +
68 private native boolean empty(long ptr); 74 private native boolean empty(long ptr);
69 75
70 private native void pop(long ptr); 76 private native void pop(long ptr);
@@ -227,3 +227,26 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_flush(JNIEnv * /*env*/,
227 auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr); 227 auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
228 model->Flush(); 228 model->Flush();
229 } 229 }
  230 +
  231 +SHERPA_ONNX_EXTERN_C
  232 +JNIEXPORT jfloat JNICALL Java_com_k2fsa_sherpa_onnx_Vad_compute(
  233 + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
  234 + return SafeJNI(
  235 + env, "Vad_compute",
  236 + [&]() -> jfloat {
  237 + if (!ValidatePointer(env, ptr, "Vad_compute",
  238 + "VoiceActivityDetector pointer is null.")) {
  239 + return -1.0f;
  240 + }
  241 + auto vad = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  242 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  243 + jsize n = env->GetArrayLength(samples);
  244 +
  245 + float score = vad->Compute(p, n);
  246 +
  247 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  248 +
  249 + return static_cast<jfloat>(score);
  250 + },
  251 + -1.0f);
  252 +}
@@ -55,6 +55,9 @@ class Vad(
55 55
56 fun release() = finalize() 56 fun release() = finalize()
57 57
  58 + fun compute(samples: FloatArray): Float = compute(ptr, samples)
  59 +
  60 +
58 fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples) 61 fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples)
59 62
60 fun empty(): Boolean = empty(ptr) 63 fun empty(): Boolean = empty(ptr)
@@ -85,6 +88,8 @@ class Vad(
85 ): Long 88 ): Long
86 89
87 private external fun acceptWaveform(ptr: Long, samples: FloatArray) 90 private external fun acceptWaveform(ptr: Long, samples: FloatArray)
  91 + private external fun compute(ptr: Long, samples: FloatArray): Float
  92 +
88 private external fun empty(ptr: Long): Boolean 93 private external fun empty(ptr: Long): Boolean
89 private external fun pop(ptr: Long) 94 private external fun pop(ptr: Long)
90 private external fun clear(ptr: Long) 95 private external fun clear(ptr: Long)