Ming-Hsuan-Tu
Committed by GitHub

Expose JNI to compute probability of chunk in VAD (#2433)

... ... @@ -69,6 +69,14 @@ class SileroVadModel::Impl {
min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
}
// Runs the model on `n` samples starting at `samples` and returns the
// resulting float. Dispatches on is_v5_: RunV5() when it is set,
// RunV4() otherwise.
float Run(const float *samples, int32_t n) {
  return is_v5_ ? RunV5(samples, n) : RunV4(samples, n);
}
void Reset() {
if (is_v5_) {
ResetV5();
... ... @@ -361,14 +369,6 @@ class SileroVadModel::Impl {
}
}
// Runs the model on `n` samples starting at `samples` and returns the
// resulting float. Dispatches on is_v5_: RunV5() when it is set,
// RunV4() otherwise.
float Run(const float *samples, int32_t n) {
  return is_v5_ ? RunV5(samples, n) : RunV4(samples, n);
}
float RunV5(const float *samples, int32_t n) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
... ... @@ -496,6 +496,10 @@ void SileroVadModel::SetThreshold(float threshold) {
impl_->SetThreshold(threshold);
}
// Forwards the `n` samples at `samples` to Impl::Run() and returns the float
// it produces.
float SileroVadModel::Compute(const float *samples, int32_t n) {
  const float score = impl_->Run(samples, n);
  return score;
}
#if __ANDROID_API__ >= 9
template SileroVadModel::SileroVadModel(AAssetManager *mgr,
const VadModelConfig &config);
... ...
... ... @@ -31,6 +31,8 @@ class SileroVadModel : public VadModel {
*/
bool IsSpeech(const float *samples, int32_t n) override;
// Runs the model once on the `n` samples at `samples` and returns the float
// produced by the underlying implementation's Run().
float Compute(const float *samples, int32_t n) override;
// For silero vad V4, it is WindowShift().
// For silero vad V5, it is WindowShift()+64 for 16kHz and
// WindowShift()+32 for 8kHz
... ...
... ... @@ -56,6 +56,38 @@ class TenVadModel::Impl {
Init(buf.data(), buf.size());
}
float Run(const float *samples, int32_t n) {
ComputeFeatures(samples, n);
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape = {1, 3, 41};
Ort::Value x = Ort::Value::CreateTensor(memory_info, last_features_.data(),
last_features_.size(),
x_shape.data(), x_shape.size());
std::vector<Ort::Value> inputs;
inputs.reserve(input_names_.size());
inputs.push_back(std::move(x));
for (auto &s : states_) {
inputs.push_back(std::move(s));
}
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
for (int32_t i = 1; i != static_cast<int32_t>(output_names_.size()); ++i) {
states_[i - 1] = std::move(out[i]);
}
float prob = out[0].GetTensorData<float>()[0];
return prob;
}
void Reset() {
triggered_ = false;
current_sample_ = 0;
... ... @@ -363,39 +395,6 @@ class TenVadModel::Impl {
last_features_.begin() + 2 * features_.size());
}
float Run(const float *samples, int32_t n) {
ComputeFeatures(samples, n);
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape = {1, 3, 41};
Ort::Value x = Ort::Value::CreateTensor(memory_info, last_features_.data(),
last_features_.size(),
x_shape.data(), x_shape.size());
std::vector<Ort::Value> inputs;
inputs.reserve(input_names_.size());
inputs.push_back(std::move(x));
for (auto &s : states_) {
inputs.push_back(std::move(s));
}
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
for (int32_t i = 1; i != static_cast<int32_t>(output_names_.size()); ++i) {
states_[i - 1] = std::move(out[i]);
}
float prob = out[0].GetTensorData<float>()[0];
return prob;
}
private:
VadModelConfig config_;
knf::Rfft rfft_;
... ... @@ -469,6 +468,10 @@ void TenVadModel::SetThreshold(float threshold) {
impl_->SetThreshold(threshold);
}
// Forwards the `n` samples at `samples` to Impl::Run() and returns the float
// it produces.
float TenVadModel::Compute(const float *samples, int32_t n) {
  const float score = impl_->Run(samples, n);
  return score;
}
#if __ANDROID_API__ >= 9
template TenVadModel::TenVadModel(AAssetManager *mgr,
const VadModelConfig &config);
... ...
... ... @@ -31,6 +31,8 @@ class TenVadModel : public VadModel {
*/
bool IsSpeech(const float *samples, int32_t n) override;
// Runs the model once on the `n` samples at `samples` and returns the float
// produced by the underlying implementation's Run().
float Compute(const float *samples, int32_t n) override;
// 256 or 160
int32_t WindowSize() const override;
... ...
... ... @@ -32,6 +32,8 @@ class VadModel {
*/
virtual bool IsSpeech(const float *samples, int32_t n) = 0;
// Returns the model's float score for the `n` samples at `samples`.
// NOTE(review): presumably the speech probability of the chunk -- confirm
// against the concrete model implementations.
virtual float Compute(const float *samples, int32_t n) = 0;
virtual int32_t WindowSize() const = 0;
virtual int32_t WindowShift() const = 0;
... ...
... ... @@ -41,6 +41,10 @@ class VoiceActivityDetector::Impl {
Init();
}
// Passes the `n` samples at `samples` to the model's Compute() and returns
// its float result unchanged.
float Compute(const float *samples, int32_t n) {
  const float score = model_->Compute(samples, n);
  return score;
}
void AcceptWaveform(const float *samples, int32_t n) {
if (buffer_.Size() > max_utterance_length_) {
model_->SetMinSilenceDuration(new_min_silence_duration_s_);
... ... @@ -256,6 +260,10 @@ const VadModelConfig &VoiceActivityDetector::GetConfig() const {
return impl_->GetConfig();
}
// Forwards the `n` samples at `samples` to Impl::Compute() and returns the
// float it produces.
float VoiceActivityDetector::Compute(const float *samples, int32_t n) {
  const float score = impl_->Compute(samples, n);
  return score;
}
#if __ANDROID_API__ >= 9
template VoiceActivityDetector::VoiceActivityDetector(
AAssetManager *mgr, const VadModelConfig &config,
... ...
... ... @@ -28,6 +28,8 @@ class VoiceActivityDetector {
~VoiceActivityDetector();
void AcceptWaveform(const float *samples, int32_t n);
// Runs the underlying VAD model on the `n` samples at `samples` and returns
// the model's float score.
float Compute(const float *samples, int32_t n);
bool Empty() const;
void Pop();
void Clear();
... ...
... ... @@ -27,6 +27,10 @@ public class Vad {
acceptWaveform(this.ptr, samples);
}
/**
 * Passes {@code samples} together with this object's native handle to the
 * native {@code compute} method.
 *
 * @param samples the samples to hand to the native method
 * @return the float returned by the native method
 */
public float compute(float[] samples) {
    final float score = compute(this.ptr, samples);
    return score;
}
public boolean empty() {
return empty(this.ptr);
}
... ... @@ -65,6 +69,8 @@ public class Vad {
private native void acceptWaveform(long ptr, float[] samples);
// Native counterpart of compute(float[]); `ptr` is the native handle held by
// this object.
private native float compute(long ptr, float[] samples);
private native boolean empty(long ptr);
private native void pop(long ptr);
... ...
... ... @@ -227,3 +227,26 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_flush(JNIEnv * /*env*/,
auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
model->Flush();
}
// JNI entry point for Vad.compute(long ptr, float[] samples).
//
// Reinterprets `ptr` as a sherpa_onnx::VoiceActivityDetector, calls its
// Compute() on the contents of `samples`, and returns the resulting score.
//
// Returns -1.0f when:
//   - ValidatePointer() rejects `ptr`,
//   - `samples` is null (a NullPointerException is also raised),
//   - the JVM cannot provide the array elements,
//   - or SafeJNI falls back to its default value.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloat JNICALL Java_com_k2fsa_sherpa_onnx_Vad_compute(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
  return SafeJNI(
      env, "Vad_compute",
      [&]() -> jfloat {
        if (!ValidatePointer(env, ptr, "Vad_compute",
                             "VoiceActivityDetector pointer is null.")) {
          return -1.0f;
        }

        // A null array must not be passed to GetFloatArrayElements(); report
        // it to the caller instead of crashing inside the JVM.
        if (samples == nullptr) {
          jclass npe = env->FindClass("java/lang/NullPointerException");
          if (npe != nullptr) {
            env->ThrowNew(npe, "samples is null");
          }
          return -1.0f;
        }

        auto vad = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);

        jfloat *p = env->GetFloatArrayElements(samples, nullptr);
        if (p == nullptr) {
          // The JVM failed to pin/copy the array; nothing to release and
          // nothing to compute on.
          return -1.0f;
        }

        // Releases the elements on every exit path, including when Compute()
        // throws. JNI_ABORT: the samples are only read, so nothing needs to
        // be copied back.
        struct ElementsReleaser {
          JNIEnv *env;
          jfloatArray array;
          jfloat *elems;
          ~ElementsReleaser() {
            env->ReleaseFloatArrayElements(array, elems, JNI_ABORT);
          }
        } releaser{env, samples, p};

        jsize n = env->GetArrayLength(samples);
        float score = vad->Compute(p, n);
        return static_cast<jfloat>(score);
      },
      -1.0f);
}
... ...
... ... @@ -55,6 +55,9 @@ class Vad(
fun release() = finalize()
/**
 * Passes [samples] together with this object's native handle to the external
 * `compute` function and returns the resulting [Float].
 */
fun compute(samples: FloatArray): Float {
    return compute(ptr, samples)
}
fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples)
fun empty(): Boolean = empty(ptr)
... ... @@ -85,6 +88,8 @@ class Vad(
): Long
private external fun acceptWaveform(ptr: Long, samples: FloatArray)
// External counterpart of compute(samples); `ptr` is the native handle held
// by this object.
// NOTE(review): presumably bound to the JNI function
// Java_com_k2fsa_sherpa_onnx_Vad_compute -- confirm the package name.
private external fun compute(ptr: Long, samples: FloatArray): Float
private external fun empty(ptr: Long): Boolean
private external fun pop(ptr: Long)
private external fun clear(ptr: Long)
... ...