Fangjun Kuang

Add C++ runtime for silero_vad with RKNN (#2078)

... ... @@ -102,8 +102,7 @@ int32_t main() {
if (i + window_size < wave->num_samples) {
SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i,
window_size);
} else {
SherpaOnnxVoiceActivityDetectorFlush(vad);
is_eof = 1;
}
... ...
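To make the loop above concrete, here is a minimal pure-Python sketch of the same feed pattern. DummyVad is a hypothetical stand-in added only for illustration; the real detector is driven through the C API calls shown above.

import numpy as np


class DummyVad:
    # Hypothetical stand-in for the real detector; it only records calls.
    def accept_waveform(self, s):
        print("fed", len(s), "samples")

    def flush(self):
        print("flushed")


samples = np.zeros(1600, dtype=np.float32)
window_size = 512
vad = DummyVad()
i = 0
is_eof = False
while not is_eof:
    if i + window_size < len(samples):
        vad.accept_waveform(samples[i : i + window_size])
    else:
        # end of stream: flush so any buffered audio is processed
        vad.flush()
        is_eof = True
    i += window_size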
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import onnxruntime
import onnx
"""
[key: "model_type"
value: "gtcrn"
, key: "comment"
value: "gtcrn_simple"
, key: "version"
value: "1"
, key: "sample_rate"
value: "16000"
, key: "model_url"
value: "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx"
, key: "maintainer"
value: "k2-fsa"
, key: "comment2"
value: "Please see also https://github.com/Xiaobin-Rong/gtcrn"
, key: "conv_cache_shape"
value: "2,1,16,16,33"
, key: "tra_cache_shape"
value: "2,3,1,1,16"
, key: "inter_cache_shape"
value: "2,1,33,16"
, key: "n_fft"
value: "512"
, key: "hop_length"
value: "256"
, key: "window_length"
value: "512"
, key: "window_type"
value: "hann_sqrt"
]
"""
"""
NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16])
-----
NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2])
NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33])
NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16])
NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16])
"""
def show(filename):
    model = onnx.load(filename)
    print(model.metadata_props)

    session_opts = onnxruntime.SessionOptions()
    session_opts.log_severity_level = 3
    sess = onnxruntime.InferenceSession(
        filename, session_opts, providers=["CPUExecutionProvider"]
    )

    for i in sess.get_inputs():
        print(i)

    print("-----")

    for i in sess.get_outputs():
        print(i)


def main():
    show("./gtcrn_simple.onnx")


if __name__ == "__main__":
    main()
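The two triple-quoted strings above record this script's output for gtcrn_simple.onnx: the metadata_props dump and the input/output NodeArg listings.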
... ...
... ... @@ -5,15 +5,94 @@ import onnx
import torch
from onnxsim import simplify
from torch import Tensor
def simple_pad(x: Tensor, pad: int) -> Tensor:
    # _0 = torch.slice(torch.slice(torch.slice(x), 1), 2, 1, torch.add(1, pad))
    _0 = x[:, :, 1 : 1 + pad]
    left_pad = torch.flip(_0, [-1])
    # _1 = torch.slice(torch.slice(torch.slice(x), 1), 2, torch.sub(-1, pad), -1)
    _1 = x[:, :, (-1 - pad) : -1]
    right_pad = torch.flip(_1, [-1])
    _2 = torch.cat([left_pad, x, right_pad], 2)
    return _2
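A quick sanity check (my addition, not part of the PR) that simple_pad reproduces PyTorch's built-in reflect padding using slice-and-flip operations:

import torch.nn.functional as F

x = torch.arange(10, dtype=torch.float32).reshape(1, 1, 10)
pad = 3
# simple_pad flips a slice on each side, which is exactly reflect padding
assert torch.equal(simple_pad(x, pad), F.pad(x, (pad, pad), mode="reflect"))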
class MyModule(torch.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m

    def adaptive_normalization_forward(self, spect):
        m = self.m._model.adaptive_normalization
        _0 = simple_pad
        # Note(fangjun): rknn uses fp16 by default, whose max value is 65504,
        # so we need to re-write the computation for spect0
        # spect0 = torch.log1p(torch.mul(spect, 1048576))
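        # (13.86294 is ln(1048576), i.e., 20 * ln(2), to 5 decimal places)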
        spect0 = torch.log1p(spect) + 13.86294
        _1 = torch.eq(len(spect0.shape), 2)
        if _1:
            _2 = torch.unsqueeze(spect0, 0)
            spect1 = _2
        else:
            spect1 = spect0
        mean = torch.mean(spect1, [1], True)
        to_pad = m.to_pad
        mean0 = _0(mean, to_pad)
        filter_ = m.filter_
        mean1 = torch.conv1d(mean0, filter_)
        mean_mean = torch.mean(mean1, [-1], True)
        spect2 = torch.add(spect1, torch.neg(mean_mean))
        return spect2
    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        m = self.m._model
        feature_extractor = m.feature_extractor
        x0 = feature_extractor.forward(x)
        norm = self.adaptive_normalization_forward(x0)
        x1 = torch.cat([x0, norm], 1)
        first_layer = m.first_layer
        x2 = first_layer.forward(x1)
        encoder = m.encoder
        x3 = encoder.forward(x2)
        decoder = m.decoder
        x4, h0, c0 = decoder.forward(x3, h, c)
        _0 = torch.mean(torch.squeeze(x4, 1), [1])
        out = torch.unsqueeze(_0, 1)
        return (out, h0, c0)
@torch.no_grad()
def main():
    m = torch.jit.load("./silero_vad.jit")
    m = MyModule(m)

    x = torch.rand((1, 512), dtype=torch.float32)
    h = torch.rand((2, 1, 64), dtype=torch.float32)
    c = torch.rand((2, 1, 64), dtype=torch.float32)

    m = torch.jit.script(m)
    torch.onnx.export(
        m,
        (x, h, c),
        "m.onnx",
        input_names=["x", "h", "c"],
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import onnxruntime
import onnx
... ...
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
# Please run this file on your rk3588 board
try:
    from rknnlite.api import RKNNLite
except ImportError:
    print("Please run this file on your board (linux + aarch64 + npu)")
    print("You need to install rknn_toolkit_lite2")
    print(
        " from https://github.com/airockchip/rknn-toolkit2/tree/master/rknn-toolkit-lite2/packages"
    )
    print(
        "https://github.com/airockchip/rknn-toolkit2/blob/v2.1.0/rknn-toolkit-lite2/packages/rknn_toolkit_lite2-2.1.0-cp310-cp310-linux_aarch64.whl"
    )
    print("is known to work")
    raise
import time
from pathlib import Path
from typing import Tuple
import numpy as np
import soundfile as sf
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel
    samples = np.ascontiguousarray(data)
    return samples, sample_rate
def init_model(filename, target_platform="rk3588"):
    if not Path(filename).is_file():
        exit(f"{filename} does not exist")

    rknn_lite = RKNNLite(verbose=False)
    ret = rknn_lite.load_rknn(path=filename)
    if ret != 0:
        exit(f"Load model {filename} failed!")

    ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
    if ret != 0:
        exit(f"Failed to init rknn runtime for {filename}")
    return rknn_lite
class RKNNModel:
    def __init__(self, model: str, target_platform="rk3588"):
        self.model = init_model(model)

    def release(self):
        self.model.release()

    def __call__(self, x: np.ndarray, h: np.ndarray, c: np.ndarray):
        """
        Args:
          x: (1, 512), np.float32
          h: (2, 1, 64), np.float32
          c: (2, 1, 64), np.float32
        Returns:
          prob: speech probability, a float
          next_h: next h state, (2, 1, 64), np.float32
          next_c: next c state, (2, 1, 64), np.float32
        """
        out, next_h, next_c = self.model.inference(inputs=[x, h, c])
        return out.item(), next_h, next_c
def main():
    model = RKNNModel(model="./m.rknn")
    for i in range(1):
        test(model)


def test(model):
    print("started")
    start = time.time()
    samples, sample_rate = load_audio("./lei-jun-test.wav")
    assert sample_rate == 16000, sample_rate

    window_size = 512
    h = np.zeros((2, 1, 64), dtype=np.float32)
    c = np.zeros((2, 1, 64), dtype=np.float32)
    threshold = 0.5

    num_windows = samples.shape[0] // window_size
    out = []
    for i in range(num_windows):
        print(i, num_windows)
        this_samples = samples[i * window_size : (i + 1) * window_size]
        prob, h, c = model(this_samples[None], h, c)
        out.append(prob > threshold)

    min_speech_duration = 0.25 * sample_rate / window_size
    min_silence_duration = 0.25 * sample_rate / window_size
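    # note: these durations are in units of windows, not samples:
    # 0.25 * 16000 / 512 = 7.8125 windows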
    result = []
    last = -1
    for k, f in enumerate(out):
        if f >= threshold:
            if last == -1:
                last = k
        elif last != -1:
            if k - last > min_speech_duration:
                result.append((last, k))
            last = -1

    if last != -1 and k - last > min_speech_duration:
        result.append((last, k))

    if not result:
        print("Empty for ./lei-jun-test.wav")
        return

    print(result)

    final = [result[0]]
    for r in result[1:]:
        f = final[-1]
        if r[0] - f[1] < min_silence_duration:
            final[-1] = (f[0], r[1])
        else:
            final.append(r)

    for f in final:
        start = f[0] * window_size / sample_rate
        end = f[1] * window_size / sample_rate
        print("{:.3f} -- {:.3f}".format(start, end))


if __name__ == "__main__":
    main()
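For reference: each 512-sample window at 16 kHz covers 512 / 16000 = 0.032 s, so a segment spanning windows (100, 200) prints as 3.200 -- 6.400.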
... ...
... ... @@ -97,10 +97,13 @@ def main():
    h, c = model.get_init_states()

    window_size = 512
    num_windows = samples.shape[0] // window_size
    for i in range(num_windows):
        start = i * window_size
        end = start + window_size
        p, h, c = model(samples[start:end], h, c)
        probs.append(p[0].item())

    threshold = 0.5
... ...
... ... @@ -159,6 +159,7 @@ if(SHERPA_ONNX_ENABLE_RKNN)
./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
./rknn/online-zipformer-ctc-model-rknn.cc
./rknn/online-zipformer-transducer-model-rknn.cc
./rknn/silero-vad-model-rknn.cc
./rknn/utils.cc
)
... ... @@ -468,6 +469,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
microphone.cc
)
add_executable(sherpa-onnx-microphone-offline
sherpa-onnx-microphone-offline.cc
microphone.cc
... ... @@ -498,11 +500,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
)
set(exes
sherpa-onnx-keyword-spotter-microphone
sherpa-onnx-microphone
sherpa-onnx-microphone-offline
sherpa-onnx-microphone-offline-audio-tagging
sherpa-onnx-microphone-offline-speaker-identification
sherpa-onnx-vad-microphone
sherpa-onnx-vad-microphone-offline-asr
sherpa-onnx-vad-with-offline-asr
... ...
// sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h"
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/rknn/utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class SileroVadModelRknn::Impl {
public:
~Impl() {
auto ret = rknn_destroy(ctx_);
if (ret != RKNN_SUCC) {
SHERPA_ONNX_LOGE("Failed to destroy the context");
}
}
explicit Impl(const VadModelConfig &config)
: config_(config), sample_rate_(config.sample_rate) {
auto buf = ReadFile(config.silero_vad.model);
Init(buf.data(), buf.size());
if (sample_rate_ != 16000) {
SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
config.sample_rate);
SHERPA_ONNX_EXIT(-1);
}
min_silence_samples_ =
sample_rate_ * config_.silero_vad.min_silence_duration;
min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
}
template <typename Manager>
Impl(Manager *mgr, const VadModelConfig &config)
: config_(config), sample_rate_(config.sample_rate) {
auto buf = ReadFile(mgr, config.silero_vad.model);
Init(buf.data(), buf.size());
if (sample_rate_ != 16000) {
SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
config.sample_rate);
SHERPA_ONNX_EXIT(-1);
}
min_silence_samples_ =
sample_rate_ * config_.silero_vad.min_silence_duration;
min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
}
void Reset() {
for (auto &s : states_) {
std::fill(s.begin(), s.end(), 0);
}
triggered_ = false;
current_sample_ = 0;
temp_start_ = 0;
temp_end_ = 0;
}
bool IsSpeech(const float *samples, int32_t n) {
if (n != WindowSize()) {
SHERPA_ONNX_LOGE("n: %d != window_size: %d", n, WindowSize());
SHERPA_ONNX_EXIT(-1);
}
float prob = Run(samples, n);
float threshold = config_.silero_vad.threshold;
current_sample_ += config_.silero_vad.window_size;
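// Hysteresis summary: entering speech requires prob > threshold sustained for
// min_speech_samples_; staying in speech only needs prob > threshold - 0.15;
// leaving speech requires prob < threshold sustained for min_silence_samples_.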
if (prob > threshold && temp_end_ != 0) {
temp_end_ = 0;
}
if (prob > threshold && temp_start_ == 0) {
// speech may be starting, but we require it to last at least
// min_speech_duration before reporting it
temp_start_ = current_sample_;
return false;
}
if (prob > threshold && temp_start_ != 0 && !triggered_) {
if (current_sample_ - temp_start_ < min_speech_samples_) {
return false;
}
triggered_ = true;
return true;
}
if ((prob < threshold) && !triggered_) {
// silence
temp_start_ = 0;
temp_end_ = 0;
return false;
}
if ((prob > threshold - 0.15) && triggered_) {
// speaking
return true;
}
if ((prob > threshold) && !triggered_) {
// start speaking
triggered_ = true;
return true;
}
if ((prob < threshold) && triggered_) {
// possibly stopping speaking
if (temp_end_ == 0) {
temp_end_ = current_sample_;
}
if (current_sample_ - temp_end_ < min_silence_samples_) {
// continue speaking
return true;
}
// stopped speaking
temp_start_ = 0;
temp_end_ = 0;
triggered_ = false;
return false;
}
return false;
}
int32_t WindowShift() const { return config_.silero_vad.window_size; }
int32_t WindowSize() const {
return config_.silero_vad.window_size + window_overlap_;
}
int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }
int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }
void SetMinSilenceDuration(float s) {
min_silence_samples_ = sample_rate_ * s;
}
void SetThreshold(float threshold) {
config_.silero_vad.threshold = threshold;
}
private:
void Init(void *model_data, size_t model_data_length) {
auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init silero vad model '%s'",
config_.silero_vad.model.c_str());
if (config_.debug) {
rknn_sdk_version v;
ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
v.drv_version);
}
rknn_input_output_num io_num;
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
if (config_.debug) {
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
static_cast<int32_t>(io_num.n_input),
static_cast<int32_t>(io_num.n_output));
}
input_attrs_.resize(io_num.n_input);
output_attrs_.resize(io_num.n_output);
int32_t i = 0;
for (auto &attr : input_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : input_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
os.str().c_str());
}
i = 0;
for (auto &attr : output_attrs_) {
memset(&attr, 0, sizeof(attr));
attr.index = i;
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
i += 1;
}
if (config_.debug) {
std::ostringstream os;
std::string sep;
for (auto &attr : output_attrs_) {
os << sep << ToString(attr);
sep = "\n";
}
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
os.str().c_str());
}
rknn_custom_string custom_string;
ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
sizeof(custom_string));
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
if (config_.debug) {
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
}
auto meta = Parse(custom_string);
if (config_.silero_vad.window_size != 512) {
SHERPA_ONNX_LOGE("we require window_size to be 512. Given: %d",
config_.silero_vad.window_size);
SHERPA_ONNX_EXIT(-1);
}
if (config_.debug) {
for (const auto &p : meta) {
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
}
}
if (meta.count("model_type") == 0) {
SHERPA_ONNX_LOGE("No model type found in '%s'",
config_.silero_vad.model.c_str());
SHERPA_ONNX_EXIT(-1);
}
if (meta.at("model_type") != "silero-vad-v4") {
SHERPA_ONNX_LOGE("Expect model type silero-vad-v4 in '%s', given: '%s'",
config_.silero_vad.model.c_str(),
meta.at("model_type").c_str());
SHERPA_ONNX_EXIT(-1);
}
if (meta.count("sample_rate") == 0) {
SHERPA_ONNX_LOGE("No sample_rate found in '%s'",
config_.silero_vad.model.c_str());
SHERPA_ONNX_EXIT(-1);
}
if (meta.at("sample_rate") != "16000") {
SHERPA_ONNX_LOGE("Expect sample rate 16000 in '%s', given: '%s'",
config_.silero_vad.model.c_str(),
meta.at("sample_rate").c_str());
SHERPA_ONNX_EXIT(-1);
}
if (meta.count("version") == 0) {
SHERPA_ONNX_LOGE("No version found in '%s'",
config_.silero_vad.model.c_str());
SHERPA_ONNX_EXIT(-1);
}
if (meta.at("version") != "4") {
SHERPA_ONNX_LOGE("Expect version 4 in '%s', given: '%s'",
config_.silero_vad.model.c_str(),
meta.at("version").c_str());
SHERPA_ONNX_EXIT(-1);
}
if (meta.count("h_shape") == 0) {
SHERPA_ONNX_LOGE("No h_shape found in '%s'",
config_.silero_vad.model.c_str());
SHERPA_ONNX_EXIT(-1);
}
if (meta.count("c_shape") == 0) {
SHERPA_ONNX_LOGE("No c_shape found in '%s'",
config_.silero_vad.model.c_str());
SHERPA_ONNX_EXIT(-1);
}
std::vector<int64_t> h_shape;
std::vector<int64_t> c_shape;
SplitStringToIntegers(meta.at("h_shape"), ",", false, &h_shape);
SplitStringToIntegers(meta.at("c_shape"), ",", false, &c_shape);
if (h_shape.size() != 3 || c_shape.size() != 3) {
SHERPA_ONNX_LOGE("Incorrect shape for h (%d) or c (%d)",
static_cast<int32_t>(h_shape.size()),
static_cast<int32_t>(c_shape.size()));
SHERPA_ONNX_EXIT(-1);
}
states_.resize(2);
states_[0].resize(h_shape[0] * h_shape[1] * h_shape[2]);
states_[1].resize(c_shape[0] * c_shape[1] * c_shape[2]);
Reset();
}
float Run(const float *samples, int32_t n) {
std::vector<rknn_input> inputs(input_attrs_.size());
for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
auto &input = inputs[i];
auto &attr = input_attrs_[i];
input.index = attr.index;
if (attr.type == RKNN_TENSOR_FLOAT16) {
input.type = RKNN_TENSOR_FLOAT32;
} else if (attr.type == RKNN_TENSOR_INT64) {
input.type = RKNN_TENSOR_INT64;
} else {
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
get_type_string(attr.type));
SHERPA_ONNX_EXIT(-1);
}
input.fmt = attr.fmt;
if (i == 0) {
input.buf = reinterpret_cast<void *>(const_cast<float *>(samples));
input.size = n * sizeof(float);
} else {
input.buf = reinterpret_cast<void *>(states_[i - 1].data());
input.size = states_[i - 1].size() * sizeof(float);
}
}
std::vector<float> out(output_attrs_[0].n_elems);
auto &next_states = states_;
std::vector<rknn_output> outputs(output_attrs_.size());
for (int32_t i = 0; i < outputs.size(); ++i) {
auto &output = outputs[i];
auto &attr = output_attrs_[i];
output.index = attr.index;
output.is_prealloc = 1;
if (attr.type == RKNN_TENSOR_FLOAT16) {
output.want_float = 1;
} else if (attr.type == RKNN_TENSOR_INT64) {
output.want_float = 0;
} else {
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
get_type_string(attr.type));
SHERPA_ONNX_EXIT(-1);
}
if (i == 0) {
output.size = out.size() * sizeof(float);
output.buf = reinterpret_cast<void *>(out.data());
} else {
output.size = next_states[i - 1].size() * sizeof(float);
output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
}
}
auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
ret = rknn_run(ctx_, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
return out[0];
}
private:
VadModelConfig config_;
rknn_context ctx_ = 0;
std::vector<rknn_tensor_attr> input_attrs_;
std::vector<rknn_tensor_attr> output_attrs_;
std::vector<std::vector<float>> states_;
int64_t sample_rate_;
int32_t min_silence_samples_;
int32_t min_speech_samples_;
bool triggered_ = false;
int32_t current_sample_ = 0;
int32_t temp_start_ = 0;
int32_t temp_end_ = 0;
int32_t window_overlap_ = 0;
};
SileroVadModelRknn::SileroVadModelRknn(const VadModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
SileroVadModelRknn::SileroVadModelRknn(Manager *mgr,
const VadModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
SileroVadModelRknn::~SileroVadModelRknn() = default;
void SileroVadModelRknn::Reset() { return impl_->Reset(); }
bool SileroVadModelRknn::IsSpeech(const float *samples, int32_t n) {
return impl_->IsSpeech(samples, n);
}
int32_t SileroVadModelRknn::WindowSize() const { return impl_->WindowSize(); }
int32_t SileroVadModelRknn::WindowShift() const { return impl_->WindowShift(); }
int32_t SileroVadModelRknn::MinSilenceDurationSamples() const {
return impl_->MinSilenceDurationSamples();
}
int32_t SileroVadModelRknn::MinSpeechDurationSamples() const {
return impl_->MinSpeechDurationSamples();
}
void SileroVadModelRknn::SetMinSilenceDuration(float s) {
impl_->SetMinSilenceDuration(s);
}
void SileroVadModelRknn::SetThreshold(float threshold) {
impl_->SetThreshold(threshold);
}
#if __ANDROID_API__ >= 9
template SileroVadModelRknn::SileroVadModelRknn(AAssetManager *mgr,
const VadModelConfig &config);
#endif
#if __OHOS__
template SileroVadModelRknn::SileroVadModelRknn(NativeResourceManager *mgr,
const VadModelConfig &config);
#endif
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_
#include "rknn_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/vad-model.h"
namespace sherpa_onnx {
class SileroVadModelRknn : public VadModel {
public:
explicit SileroVadModelRknn(const VadModelConfig &config);
template <typename Manager>
SileroVadModelRknn(Manager *mgr, const VadModelConfig &config);
~SileroVadModelRknn() override;
// reset the internal model states
void Reset() override;
/**
* @param samples Pointer to a 1-d array containing audio samples.
* Each sample should be normalized to the range [-1, 1].
* @param n Number of samples.
*
* @return Return true if speech is detected. Return false otherwise.
*/
bool IsSpeech(const float *samples, int32_t n) override;
// For silero VAD v4 the window overlap is 0, so this equals WindowShift().
int32_t WindowSize() const override;
// 512
int32_t WindowShift() const override;
int32_t MinSilenceDurationSamples() const override;
int32_t MinSpeechDurationSamples() const override;
void SetMinSilenceDuration(float s) override;
void SetThreshold(float threshold) override;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_
... ...
... ... @@ -129,15 +129,13 @@ as the device_name.
exit(-1);
}
fprintf(stderr, "Started. Please speak\n");
int32_t window_size = vad_config.silero_vad.window_size;
int32_t index = 0;
while (!stop) {
const std::vector<float> &samples = alsa.Read(chunk);
const std::vector<float> &samples = alsa.Read(window_size);
vad->AcceptWaveform(samples.data(), samples.size());
while (!vad->Empty()) {
... ...
... ... @@ -7,6 +7,9 @@
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
void VadModelConfig::Register(ParseOptions *po) {
... ... @@ -26,7 +29,27 @@ void VadModelConfig::Register(ParseOptions *po) {
"true to display debug information when loading vad models");
}
bool VadModelConfig::Validate() const {
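// The model file extension must match the provider: an .rknn model requires
// --provider=rknn, and an .onnx model must not be paired with the rknn
// provider. The two checks below enforce that pairing.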
if (provider != "rknn") {
if (!silero_vad.model.empty() && EndsWith(silero_vad.model, ".rknn")) {
SHERPA_ONNX_LOGE(
"--provider is %s, which is not rknn, but you pass an rknn model "
"'%s'",
provider.c_str(), silero_vad.model.c_str());
return false;
}
}
if (provider == "rknn") {
if (!silero_vad.model.empty() && EndsWith(silero_vad.model, ".onnx")) {
SHERPA_ONNX_LOGE("--provider is rknn, but you pass an onnx model '%s'",
silero_vad.model.c_str());
return false;
}
}
return silero_vad.Validate();
}
std::string VadModelConfig::ToString() const {
std::ostringstream os;
... ...
... ... @@ -13,19 +13,27 @@
#include "rawfile/raw_file_manager.h"
#endif
#if SHERPA_ONNX_ENABLE_RKNN
#include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h"
#endif
#include "sherpa-onnx/csrc/silero-vad-model.h"
namespace sherpa_onnx {
std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
// TODO(fangjun): Support other VAD models.
if (config.provider == "rknn") {
return std::make_unique<SileroVadModelRknn>(config);
}
return std::make_unique<SileroVadModel>(config);
}
template <typename Manager>
std::unique_ptr<VadModel> VadModel::Create(Manager *mgr,
const VadModelConfig &config) {
// TODO(fangjun): Support other VAD models.
if (config.provider == "rknn") {
return std::make_unique<SileroVadModelRknn>(mgr, config);
}
return std::make_unique<SileroVadModel>(mgr, config);
}
... ...