Fangjun Kuang
Committed by GitHub

Add C++ runtime for silero_vad with RKNN (#2078)

@@ -100,12 +100,11 @@ int32_t main() { @@ -100,12 +100,11 @@ int32_t main() {
100 100
101 while (!is_eof) { 101 while (!is_eof) {
102 if (i + window_size < wave->num_samples) { 102 if (i + window_size < wave->num_samples) {
103 - SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i,  
104 - window_size);  
105 - }  
106 - else {  
107 - SherpaOnnxVoiceActivityDetectorFlush(vad);  
108 - is_eof = 1; 103 + SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i,
  104 + window_size);
  105 + } else {
  106 + SherpaOnnxVoiceActivityDetectorFlush(vad);
  107 + is_eof = 1;
109 } 108 }
110 while (!SherpaOnnxVoiceActivityDetectorEmpty(vad)) { 109 while (!SherpaOnnxVoiceActivityDetectorEmpty(vad)) {
111 const SherpaOnnxSpeechSegment *segment = 110 const SherpaOnnxSpeechSegment *segment =
  1 +#!/usr/bin/env python3
  2 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +import onnxruntime
  5 +import onnx
  6 +
  7 +"""
  8 +[key: "model_type"
  9 +value: "gtcrn"
  10 +, key: "comment"
  11 +value: "gtcrn_simple"
  12 +, key: "version"
  13 +value: "1"
  14 +, key: "sample_rate"
  15 +value: "16000"
  16 +, key: "model_url"
  17 +value: "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx"
  18 +, key: "maintainer"
  19 +value: "k2-fsa"
  20 +, key: "comment2"
  21 +value: "Please see also https://github.com/Xiaobin-Rong/gtcrn"
  22 +, key: "conv_cache_shape"
  23 +value: "2,1,16,16,33"
  24 +, key: "tra_cache_shape"
  25 +value: "2,3,1,1,16"
  26 +, key: "inter_cache_shape"
  27 +value: "2,1,33,16"
  28 +, key: "n_fft"
  29 +value: "512"
  30 +, key: "hop_length"
  31 +value: "256"
  32 +, key: "window_length"
  33 +value: "512"
  34 +, key: "window_type"
  35 +value: "hann_sqrt"
  36 +]
  37 +"""
  38 +
  39 +"""
  40 +NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2])
  41 +NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33])
  42 +NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16])
  43 +NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16])
  44 +-----
  45 +NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2])
  46 +NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33])
  47 +NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16])
  48 +NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16])
  49 +"""
  50 +
  51 +
def show(filename):
    """Print the metadata and the input/output descriptions of an ONNX model.

    Args:
      filename: Path to the ONNX model file.
    """
    print(onnx.load(filename).metadata_props)

    options = onnxruntime.SessionOptions()
    options.log_severity_level = 3
    session = onnxruntime.InferenceSession(
        filename, options, providers=["CPUExecutionProvider"]
    )

    for node in session.get_inputs():
        print(node)

    # Separator between the input list and the output list.
    print("-----")

    for node in session.get_outputs():
        print(node)
  68 +
  69 +
def main():
    """Show the model that is expected in the current directory."""
    model_path = "./gtcrn_simple.onnx"
    show(model_path)
  72 +
  73 +
  74 +if __name__ == "__main__":
  75 + main()
@@ -5,15 +5,94 @@ import onnx @@ -5,15 +5,94 @@ import onnx
5 import torch 5 import torch
6 from onnxsim import simplify 6 from onnxsim import simplify
7 7
  8 +import torch
  9 +from torch import Tensor
  10 +
  11 +
def simple_pad(x: Tensor, pad: int) -> Tensor:
    """Pad ``x`` along dim 2 by mirroring ``pad`` entries on each side.

    The edge entries themselves are not repeated: the left padding is the
    reversed slice ``x[:, :, 1 : 1 + pad]`` and the right padding is the
    reversed slice ``x[:, :, -1 - pad : -1]``.

    Args:
      x: Input tensor; it is sliced and concatenated along dim 2.
      pad: Number of entries to add on each side.
    Returns:
      A tensor whose dim 2 is longer than that of ``x`` by ``2 * pad``.
    """
    head = x[:, :, 1 : 1 + pad]
    tail = x[:, :, (-1 - pad) : -1]

    return torch.cat([torch.flip(head, [-1]), x, torch.flip(tail, [-1])], 2)
  24 +
  25 +
class MyModule(torch.nn.Module):
    """Wrap a loaded silero-vad module and re-implement its forward pass.

    The forward computation is spelled out here so that the adaptive
    normalization step can be rewritten in a form that stays within the
    fp16 value range.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m

    def adaptive_normalization_forward(self, spect):
        norm = self.m._model.adaptive_normalization

        # Note(fangjun): rknn uses fp16 by default, whose max value is 65504,
        # so the original computation
        #   spect0 = torch.log1p(torch.mul(spect, 1048576))
        # is replaced with the form below. 13.86294 is ln(1048576), so the
        # large intermediate product is never formed.
        spect0 = torch.log1p(spect) + 13.86294

        is_2d = torch.eq(len(spect0.shape), 2)
        if is_2d:
            spect1 = torch.unsqueeze(spect0, 0)
        else:
            spect1 = spect0

        mean = torch.mean(spect1, [1], True)
        padded_mean = simple_pad(mean, norm.to_pad)
        smoothed_mean = torch.conv1d(padded_mean, norm.filter_)
        mean_mean = torch.mean(smoothed_mean, [-1], True)

        return torch.add(spect1, torch.neg(mean_mean))

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        model = self.m._model

        feats = model.feature_extractor.forward(x)
        normed = self.adaptive_normalization_forward(feats)
        stacked = torch.cat([feats, normed], 1)

        hidden = model.first_layer.forward(stacked)
        encoded = model.encoder.forward(hidden)
        decoded, next_h, next_c = model.decoder.forward(encoded, h, c)

        averaged = torch.mean(torch.squeeze(decoded, 1), [1])
        prob = torch.unsqueeze(averaged, 1)
        return (prob, next_h, next_c)
  84 +
8 85
9 @torch.no_grad() 86 @torch.no_grad()
10 def main(): 87 def main():
11 m = torch.jit.load("./silero_vad.jit") 88 m = torch.jit.load("./silero_vad.jit")
  89 + m = MyModule(m)
12 x = torch.rand((1, 512), dtype=torch.float32) 90 x = torch.rand((1, 512), dtype=torch.float32)
13 h = torch.rand((2, 1, 64), dtype=torch.float32) 91 h = torch.rand((2, 1, 64), dtype=torch.float32)
14 c = torch.rand((2, 1, 64), dtype=torch.float32) 92 c = torch.rand((2, 1, 64), dtype=torch.float32)
  93 + m = torch.jit.script(m)
15 torch.onnx.export( 94 torch.onnx.export(
16 - m._model, 95 + m,
17 (x, h, c), 96 (x, h, c),
18 "m.onnx", 97 "m.onnx",
19 input_names=["x", "h", "c"], 98 input_names=["x", "h", "c"],
1 #!/usr/bin/env python3 1 #!/usr/bin/env python3
2 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) 2 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
3 3
4 import onnxruntime 4 import onnxruntime
5 import onnx 5 import onnx
  1 +#!/usr/bin/env python3
  2 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
  3 +
  4 +# Please run this file on your rk3588 board
  5 +
  6 +try:
  7 + from rknnlite.api import RKNNLite
  8 +except:
  9 + print("Please run this file on your board (linux + aarch64 + npu)")
  10 + print("You need to install rknn_toolkit_lite2")
  11 + print(
  12 + " from https://github.com/airockchip/rknn-toolkit2/tree/master/rknn-toolkit-lite2/packages"
  13 + )
  14 + print(
  15 + "https://github.com/airockchip/rknn-toolkit2/blob/v2.1.0/rknn-toolkit-lite2/packages/rknn_toolkit_lite2-2.1.0-cp310-cp310-linux_aarch64.whl"
  16 + )
  17 + print("is known to work")
  18 + raise
  19 +
  20 +import time
  21 +from pathlib import Path
  22 +from typing import Tuple
  23 +
  24 +import numpy as np
  25 +import soundfile as sf
  26 +
  27 +
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read an audio file and return its first channel as float32 samples.

    Args:
      filename: Path to the audio file.
    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a contiguous
      1-D float32 array holding the first channel.
    """
    audio, sample_rate = sf.read(filename, always_2d=True, dtype="float32")

    # always_2d=True gives a 2-D array; keep only column 0.
    first_channel = audio[:, 0]

    return np.ascontiguousarray(first_channel), sample_rate
  38 +
  39 +
def init_model(filename, target_platform="rk3588"):
    """Load an .rknn model and initialize the runtime on NPU core 0.

    The process exits with an error message if the file does not exist,
    if loading fails, or if the runtime cannot be initialized.

    Args:
      filename: Path to the .rknn model file.
      target_platform: Not used by this function.
    Returns:
      The initialized RKNNLite object.
    """
    model_path = Path(filename)
    if not model_path.is_file():
        exit(f"{filename} does not exist")

    runtime = RKNNLite(verbose=False)

    if runtime.load_rknn(path=filename) != 0:
        exit(f"Load model {filename} failed!")

    if runtime.init_runtime(core_mask=RKNNLite.NPU_CORE_0) != 0:
        exit(f"Failed to init rknn runtime for {filename}")

    return runtime
  53 +
  54 +
class RKNNModel:
    """Thin wrapper around an initialized RKNNLite silero-vad model."""

    def __init__(self, model: str, target_platform="rk3588"):
        # Note: target_platform is accepted but not forwarded.
        self.model = init_model(model)

    def release(self):
        """Release the underlying runtime object."""
        self.model.release()

    def __call__(self, x: np.ndarray, h: np.ndarray, c: np.ndarray):
        """Run the model on one window of samples.

        Args:
          x: (1, 512), np.float32
          h: (2, 1, 64), np.float32
          c: (2, 1, 64), np.float32
        Returns:
          A tuple ``(prob, next_h, next_c)`` where ``prob`` is a Python
          scalar obtained with ``.item()`` from the first model output and
          ``next_h``/``next_c`` are the second and third model outputs.
        """
        outputs = self.model.inference(inputs=[x, h, c])
        prob, next_h, next_c = outputs
        return prob.item(), next_h, next_c
  75 +
  76 +
def main():
    """Load ./m.rknn, run the VAD test once, and always release the model."""
    model = RKNNModel(model="./m.rknn")
    try:
        test(model)
    finally:
        # Release the runtime even if the test raises.
        model.release()
  81 +
  82 +
def test(model):
    """Run the VAD model over ./lei-jun-test.wav and print speech segments.

    The audio is split into non-overlapping windows of 512 samples. Each
    window is classified as speech when its probability exceeds 0.5.
    Runs of speech windows longer than 0.25 seconds become segments, and
    segments separated by less than 0.25 seconds of silence are merged.
    Segment boundaries are printed in seconds.

    Args:
      model: A callable taking ``(x, h, c)`` and returning
        ``(prob, next_h, next_c)``.
    """
    print("started")
    samples, sample_rate = load_audio("./lei-jun-test.wav")
    assert sample_rate == 16000, sample_rate

    window_size = 512

    h = np.zeros((2, 1, 64), dtype=np.float32)
    c = np.zeros((2, 1, 64), dtype=np.float32)

    threshold = 0.5
    num_windows = samples.shape[0] // window_size

    # One boolean per window: True if the window is classified as speech.
    is_speech = []
    for i in range(num_windows):
        print(i, num_windows)
        window = samples[i * window_size : (i + 1) * window_size]
        prob, h, c = model(window[None], h, c)
        is_speech.append(prob > threshold)

    # Durations are expressed in number of windows.
    min_speech_duration = 0.25 * sample_rate / window_size
    min_silence_duration = 0.25 * sample_rate / window_size

    # Each entry is (first_speech_window, first_window_after_speech).
    result = []
    last = -1
    for k, active in enumerate(is_speech):
        if active:
            if last == -1:
                last = k
        elif last != -1:
            if k - last > min_speech_duration:
                result.append((last, k))
            last = -1

    # Speech that runs until the end of the audio: the exclusive end is
    # num_windows, i.e. one past the last window, consistent with the
    # segments closed inside the loop above.
    if last != -1 and num_windows - last > min_speech_duration:
        result.append((last, num_windows))

    if not result:
        print("Empty for ./lei-jun-test.wav")
        return

    print(result)

    # Merge neighbouring segments separated by a short silence.
    final = [result[0]]
    for seg in result[1:]:
        prev = final[-1]
        if seg[0] - prev[1] < min_silence_duration:
            final[-1] = (prev[0], seg[1])
        else:
            final.append(seg)

    for first, end_exclusive in final:
        start = first * window_size / sample_rate
        end = end_exclusive * window_size / sample_rate
        print("{:.3f} -- {:.3f}".format(start, end))
  138 +
  139 +
  140 +if __name__ == "__main__":
  141 + main()
@@ -97,10 +97,13 @@ def main(): @@ -97,10 +97,13 @@ def main():
97 h, c = model.get_init_states() 97 h, c = model.get_init_states()
98 window_size = 512 98 window_size = 512
99 num_windows = samples.shape[0] // window_size 99 num_windows = samples.shape[0] // window_size
  100 +
100 for i in range(num_windows): 101 for i in range(num_windows):
101 start = i * window_size 102 start = i * window_size
102 end = start + window_size 103 end = start + window_size
  104 +
103 p, h, c = model(samples[start:end], h, c) 105 p, h, c = model(samples[start:end], h, c)
  106 +
104 probs.append(p[0].item()) 107 probs.append(p[0].item())
105 108
106 threshold = 0.5 109 threshold = 0.5
@@ -159,6 +159,7 @@ if(SHERPA_ONNX_ENABLE_RKNN) @@ -159,6 +159,7 @@ if(SHERPA_ONNX_ENABLE_RKNN)
159 ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc 159 ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
160 ./rknn/online-zipformer-ctc-model-rknn.cc 160 ./rknn/online-zipformer-ctc-model-rknn.cc
161 ./rknn/online-zipformer-transducer-model-rknn.cc 161 ./rknn/online-zipformer-transducer-model-rknn.cc
  162 + ./rknn/silero-vad-model-rknn.cc
162 ./rknn/utils.cc 163 ./rknn/utils.cc
163 ) 164 )
164 165
@@ -468,6 +469,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY) @@ -468,6 +469,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
468 microphone.cc 469 microphone.cc
469 ) 470 )
470 471
  472 +
471 add_executable(sherpa-onnx-microphone-offline 473 add_executable(sherpa-onnx-microphone-offline
472 sherpa-onnx-microphone-offline.cc 474 sherpa-onnx-microphone-offline.cc
473 microphone.cc 475 microphone.cc
@@ -498,11 +500,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY) @@ -498,11 +500,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
498 ) 500 )
499 501
500 set(exes 502 set(exes
501 - sherpa-onnx-microphone  
502 sherpa-onnx-keyword-spotter-microphone 503 sherpa-onnx-keyword-spotter-microphone
  504 + sherpa-onnx-microphone
503 sherpa-onnx-microphone-offline 505 sherpa-onnx-microphone-offline
504 - sherpa-onnx-microphone-offline-speaker-identification  
505 sherpa-onnx-microphone-offline-audio-tagging 506 sherpa-onnx-microphone-offline-audio-tagging
  507 + sherpa-onnx-microphone-offline-speaker-identification
506 sherpa-onnx-vad-microphone 508 sherpa-onnx-vad-microphone
507 sherpa-onnx-vad-microphone-offline-asr 509 sherpa-onnx-vad-microphone-offline-asr
508 sherpa-onnx-vad-with-offline-asr 510 sherpa-onnx-vad-with-offline-asr
  1 +// sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h"
  6 +
  7 +#include <string>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
  16 +#if __OHOS__
  17 +#include "rawfile/raw_file_manager.h"
  18 +#endif
  19 +
  20 +#include "sherpa-onnx/csrc/file-utils.h"
  21 +#include "sherpa-onnx/csrc/macros.h"
  22 +#include "sherpa-onnx/csrc/rknn/macros.h"
  23 +#include "sherpa-onnx/csrc/rknn/utils.h"
  24 +#include "sherpa-onnx/csrc/text-utils.h"
  25 +
  26 +namespace sherpa_onnx {
  27 +
  28 +class SileroVadModelRknn::Impl {
  29 + public:
  30 + ~Impl() {
  31 + auto ret = rknn_destroy(ctx_);
  32 + if (ret != RKNN_SUCC) {
  33 + SHERPA_ONNX_LOGE("Failed to destroy the context");
  34 + }
  35 + }
  36 +
  37 + explicit Impl(const VadModelConfig &config)
  38 + : config_(config), sample_rate_(config.sample_rate) {
  39 + auto buf = ReadFile(config.silero_vad.model);
  40 + Init(buf.data(), buf.size());
  41 +
  42 + if (sample_rate_ != 16000) {
  43 + SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
  44 + config.sample_rate);
  45 + SHERPA_ONNX_EXIT(-1);
  46 + }
  47 +
  48 + min_silence_samples_ =
  49 + sample_rate_ * config_.silero_vad.min_silence_duration;
  50 +
  51 + min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
  52 + }
  53 +
  54 + template <typename Manager>
  55 + Impl(Manager *mgr, const VadModelConfig &config)
  56 + : config_(config), sample_rate_(config.sample_rate) {
  57 + auto buf = ReadFile(mgr, config.silero_vad.model);
  58 + Init(buf.data(), buf.size());
  59 +
  60 + if (sample_rate_ != 16000) {
  61 + SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
  62 + config.sample_rate);
  63 + exit(-1);
  64 + }
  65 +
  66 + min_silence_samples_ =
  67 + sample_rate_ * config_.silero_vad.min_silence_duration;
  68 +
  69 + min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
  70 + }
  71 +
  72 + void Reset() {
  73 + for (auto &s : states_) {
  74 + std::fill(s.begin(), s.end(), 0);
  75 + }
  76 +
  77 + triggered_ = false;
  78 + current_sample_ = 0;
  79 + temp_start_ = 0;
  80 + temp_end_ = 0;
  81 + }
  82 +
  83 + bool IsSpeech(const float *samples, int32_t n) {
  84 + if (n != WindowSize()) {
  85 + SHERPA_ONNX_LOGE("n: %d != window_size: %d", n, WindowSize());
  86 + SHERPA_ONNX_EXIT(-1);
  87 + }
  88 +
  89 + float prob = Run(samples, n);
  90 +
  91 + float threshold = config_.silero_vad.threshold;
  92 +
  93 + current_sample_ += config_.silero_vad.window_size;
  94 +
  95 + if (prob > threshold && temp_end_ != 0) {
  96 + temp_end_ = 0;
  97 + }
  98 +
  99 + if (prob > threshold && temp_start_ == 0) {
  100 + // start speaking, but we require that it must satisfy
  101 + // min_speech_duration
  102 + temp_start_ = current_sample_;
  103 + return false;
  104 + }
  105 +
  106 + if (prob > threshold && temp_start_ != 0 && !triggered_) {
  107 + if (current_sample_ - temp_start_ < min_speech_samples_) {
  108 + return false;
  109 + }
  110 +
  111 + triggered_ = true;
  112 +
  113 + return true;
  114 + }
  115 +
  116 + if ((prob < threshold) && !triggered_) {
  117 + // silence
  118 + temp_start_ = 0;
  119 + temp_end_ = 0;
  120 + return false;
  121 + }
  122 +
  123 + if ((prob > threshold - 0.15) && triggered_) {
  124 + // speaking
  125 + return true;
  126 + }
  127 +
  128 + if ((prob > threshold) && !triggered_) {
  129 + // start speaking
  130 + triggered_ = true;
  131 +
  132 + return true;
  133 + }
  134 +
  135 + if ((prob < threshold) && triggered_) {
  136 + // stop to speak
  137 + if (temp_end_ == 0) {
  138 + temp_end_ = current_sample_;
  139 + }
  140 +
  141 + if (current_sample_ - temp_end_ < min_silence_samples_) {
  142 + // continue speaking
  143 + return true;
  144 + }
  145 + // stopped speaking
  146 + temp_start_ = 0;
  147 + temp_end_ = 0;
  148 + triggered_ = false;
  149 + return false;
  150 + }
  151 +
  152 + return false;
  153 + }
  154 +
  155 + int32_t WindowShift() const { return config_.silero_vad.window_size; }
  156 +
  157 + int32_t WindowSize() const {
  158 + return config_.silero_vad.window_size + window_overlap_;
  159 + }
  160 +
  161 + int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }
  162 +
  163 + int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }
  164 +
  165 + void SetMinSilenceDuration(float s) {
  166 + min_silence_samples_ = sample_rate_ * s;
  167 + }
  168 +
  169 + void SetThreshold(float threshold) {
  170 + config_.silero_vad.threshold = threshold;
  171 + }
  172 +
  173 + private:
  174 + void Init(void *model_data, size_t model_data_length) {
  175 + auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
  176 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init silero vad model '%s'",
  177 + config_.silero_vad.model.c_str());
  178 +
  179 + if (config_.debug) {
  180 + rknn_sdk_version v;
  181 + ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
  182 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
  183 +
  184 + SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
  185 + v.drv_version);
  186 + }
  187 +
  188 + rknn_input_output_num io_num;
  189 + ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
  190 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
  191 +
  192 + if (config_.debug) {
  193 + SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
  194 + static_cast<int32_t>(io_num.n_input),
  195 + static_cast<int32_t>(io_num.n_output));
  196 + }
  197 +
  198 + input_attrs_.resize(io_num.n_input);
  199 + output_attrs_.resize(io_num.n_output);
  200 +
  201 + int32_t i = 0;
  202 + for (auto &attr : input_attrs_) {
  203 + memset(&attr, 0, sizeof(attr));
  204 + attr.index = i;
  205 + ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
  206 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
  207 + i += 1;
  208 + }
  209 +
  210 + if (config_.debug) {
  211 + std::ostringstream os;
  212 + std::string sep;
  213 + for (auto &attr : input_attrs_) {
  214 + os << sep << ToString(attr);
  215 + sep = "\n";
  216 + }
  217 + SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
  218 + os.str().c_str());
  219 + }
  220 +
  221 + i = 0;
  222 + for (auto &attr : output_attrs_) {
  223 + memset(&attr, 0, sizeof(attr));
  224 + attr.index = i;
  225 + ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
  226 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
  227 + i += 1;
  228 + }
  229 +
  230 + if (config_.debug) {
  231 + std::ostringstream os;
  232 + std::string sep;
  233 + for (auto &attr : output_attrs_) {
  234 + os << sep << ToString(attr);
  235 + sep = "\n";
  236 + }
  237 + SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
  238 + os.str().c_str());
  239 + }
  240 +
  241 + rknn_custom_string custom_string;
  242 + ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
  243 + sizeof(custom_string));
  244 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
  245 + if (config_.debug) {
  246 + SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
  247 + }
  248 + auto meta = Parse(custom_string);
  249 +
  250 + if (config_.silero_vad.window_size != 512) {
  251 + SHERPA_ONNX_LOGE("we require window_size to be 512. Given: %d",
  252 + config_.silero_vad.window_size);
  253 + SHERPA_ONNX_EXIT(-1);
  254 + }
  255 +
  256 + if (config_.debug) {
  257 + for (const auto &p : meta) {
  258 + SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
  259 + }
  260 + }
  261 +
  262 + if (meta.count("model_type") == 0) {
  263 + SHERPA_ONNX_LOGE("No model type found in '%s'",
  264 + config_.silero_vad.model.c_str());
  265 + SHERPA_ONNX_EXIT(-1);
  266 + }
  267 +
  268 + if (meta.at("model_type") != "silero-vad-v4") {
  269 + SHERPA_ONNX_LOGE("Expect model type silero-vad-v4 in '%s', given: '%s'",
  270 + config_.silero_vad.model.c_str(),
  271 + meta.at("model_type").c_str());
  272 + SHERPA_ONNX_EXIT(-1);
  273 + }
  274 +
  275 + if (meta.count("sample_rate") == 0) {
  276 + SHERPA_ONNX_LOGE("No sample_rate found in '%s'",
  277 + config_.silero_vad.model.c_str());
  278 + SHERPA_ONNX_EXIT(-1);
  279 + }
  280 +
  281 + if (meta.at("sample_rate") != "16000") {
  282 + SHERPA_ONNX_LOGE("Expect sample rate 16000 in '%s', given: '%s'",
  283 + config_.silero_vad.model.c_str(),
  284 + meta.at("sample_rate").c_str());
  285 + SHERPA_ONNX_EXIT(-1);
  286 + }
  287 +
  288 + if (meta.count("version") == 0) {
  289 + SHERPA_ONNX_LOGE("No version found in '%s'",
  290 + config_.silero_vad.model.c_str());
  291 + SHERPA_ONNX_EXIT(-1);
  292 + }
  293 +
  294 + if (meta.at("version") != "4") {
  295 + SHERPA_ONNX_LOGE("Expect version 4 in '%s', given: '%s'",
  296 + config_.silero_vad.model.c_str(),
  297 + meta.at("version").c_str());
  298 + SHERPA_ONNX_EXIT(-1);
  299 + }
  300 +
  301 + if (meta.count("h_shape") == 0) {
  302 + SHERPA_ONNX_LOGE("No h_shape found in '%s'",
  303 + config_.silero_vad.model.c_str());
  304 + SHERPA_ONNX_EXIT(-1);
  305 + }
  306 +
  307 + if (meta.count("c_shape") == 0) {
  308 + SHERPA_ONNX_LOGE("No c_shape found in '%s'",
  309 + config_.silero_vad.model.c_str());
  310 + SHERPA_ONNX_EXIT(-1);
  311 + }
  312 +
  313 + std::vector<int64_t> h_shape;
  314 + std::vector<int64_t> c_shape;
  315 +
  316 + SplitStringToIntegers(meta.at("h_shape"), ",", false, &h_shape);
  317 + SplitStringToIntegers(meta.at("c_shape"), ",", false, &c_shape);
  318 + if (h_shape.size() != 3 || c_shape.size() != 3) {
  319 + SHERPA_ONNX_LOGE("Incorrect shape for h (%d) or c (%d)",
  320 + static_cast<int32_t>(h_shape.size()),
  321 + static_cast<int32_t>(c_shape.size()));
  322 + SHERPA_ONNX_EXIT(-1);
  323 + }
  324 +
  325 + states_.resize(2);
  326 + states_[0].resize(h_shape[0] * h_shape[1] * h_shape[2]);
  327 + states_[1].resize(c_shape[0] * c_shape[1] * c_shape[2]);
  328 +
  329 + Reset();
  330 + }
  331 +
  332 + float Run(const float *samples, int32_t n) {
  333 + std::vector<rknn_input> inputs(input_attrs_.size());
  334 +
  335 + for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
  336 + auto &input = inputs[i];
  337 + auto &attr = input_attrs_[i];
  338 + input.index = attr.index;
  339 +
  340 + if (attr.type == RKNN_TENSOR_FLOAT16) {
  341 + input.type = RKNN_TENSOR_FLOAT32;
  342 + } else if (attr.type == RKNN_TENSOR_INT64) {
  343 + input.type = RKNN_TENSOR_INT64;
  344 + } else {
  345 + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
  346 + get_type_string(attr.type));
  347 + SHERPA_ONNX_EXIT(-1);
  348 + }
  349 +
  350 + input.fmt = attr.fmt;
  351 + if (i == 0) {
  352 + input.buf = reinterpret_cast<void *>(const_cast<float *>(samples));
  353 + input.size = n * sizeof(float);
  354 + } else {
  355 + input.buf = reinterpret_cast<void *>(states_[i - 1].data());
  356 + input.size = states_[i - 1].size() * sizeof(float);
  357 + }
  358 + }
  359 +
  360 + std::vector<float> out(output_attrs_[0].n_elems);
  361 +
  362 + auto &next_states = states_;
  363 +
  364 + std::vector<rknn_output> outputs(output_attrs_.size());
  365 +
  366 + for (int32_t i = 0; i < outputs.size(); ++i) {
  367 + auto &output = outputs[i];
  368 + auto &attr = output_attrs_[i];
  369 + output.index = attr.index;
  370 + output.is_prealloc = 1;
  371 +
  372 + if (attr.type == RKNN_TENSOR_FLOAT16) {
  373 + output.want_float = 1;
  374 + } else if (attr.type == RKNN_TENSOR_INT64) {
  375 + output.want_float = 0;
  376 + } else {
  377 + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
  378 + get_type_string(attr.type));
  379 + SHERPA_ONNX_EXIT(-1);
  380 + }
  381 +
  382 + if (i == 0) {
  383 + output.size = out.size() * sizeof(float);
  384 + output.buf = reinterpret_cast<void *>(out.data());
  385 + } else {
  386 + output.size = next_states[i - 1].size() * sizeof(float);
  387 + output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
  388 + }
  389 + }
  390 +
  391 + auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
  392 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
  393 +
  394 + ret = rknn_run(ctx_, nullptr);
  395 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
  396 +
  397 + ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
  398 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
  399 +
  400 + return out[0];
  401 + }
  402 +
  403 + private:
  404 + VadModelConfig config_;
  405 + rknn_context ctx_ = 0;
  406 +
  407 + std::vector<rknn_tensor_attr> input_attrs_;
  408 + std::vector<rknn_tensor_attr> output_attrs_;
  409 +
  410 + std::vector<std::vector<float>> states_;
  411 +
  412 + int64_t sample_rate_;
  413 + int32_t min_silence_samples_;
  414 + int32_t min_speech_samples_;
  415 +
  416 + bool triggered_ = false;
  417 + int32_t current_sample_ = 0;
  418 + int32_t temp_start_ = 0;
  419 + int32_t temp_end_ = 0;
  420 +
  421 + int32_t window_overlap_ = 0;
  422 +};
  423 +
  424 +SileroVadModelRknn::SileroVadModelRknn(const VadModelConfig &config)
  425 + : impl_(std::make_unique<Impl>(config)) {}
  426 +
  427 +template <typename Manager>
  428 +SileroVadModelRknn::SileroVadModelRknn(Manager *mgr,
  429 + const VadModelConfig &config)
  430 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  431 +
  432 +SileroVadModelRknn::~SileroVadModelRknn() = default;
  433 +
  434 +void SileroVadModelRknn::Reset() { return impl_->Reset(); }
  435 +
  436 +bool SileroVadModelRknn::IsSpeech(const float *samples, int32_t n) {
  437 + return impl_->IsSpeech(samples, n);
  438 +}
  439 +
  440 +int32_t SileroVadModelRknn::WindowSize() const { return impl_->WindowSize(); }
  441 +
  442 +int32_t SileroVadModelRknn::WindowShift() const { return impl_->WindowShift(); }
  443 +
  444 +int32_t SileroVadModelRknn::MinSilenceDurationSamples() const {
  445 + return impl_->MinSilenceDurationSamples();
  446 +}
  447 +
  448 +int32_t SileroVadModelRknn::MinSpeechDurationSamples() const {
  449 + return impl_->MinSpeechDurationSamples();
  450 +}
  451 +
  452 +void SileroVadModelRknn::SetMinSilenceDuration(float s) {
  453 + impl_->SetMinSilenceDuration(s);
  454 +}
  455 +
  456 +void SileroVadModelRknn::SetThreshold(float threshold) {
  457 + impl_->SetThreshold(threshold);
  458 +}
  459 +
  460 +#if __ANDROID_API__ >= 9
  461 +template SileroVadModelRknn::SileroVadModelRknn(AAssetManager *mgr,
  462 + const VadModelConfig &config);
  463 +#endif
  464 +
  465 +#if __OHOS__
  466 +template SileroVadModelRknn::SileroVadModelRknn(NativeResourceManager *mgr,
  467 + const VadModelConfig &config);
  468 +#endif
  469 +
  470 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_
  5 +#define SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_
  6 +
  7 +#include "rknn_api.h" // NOLINT
  8 +#include "sherpa-onnx/csrc/online-model-config.h"
  9 +#include "sherpa-onnx/csrc/vad-model.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +class SileroVadModelRknn : public VadModel {
  14 + public:
  15 + explicit SileroVadModelRknn(const VadModelConfig &config);
  16 +
  17 + template <typename Manager>
  18 + SileroVadModelRknn(Manager *mgr, const VadModelConfig &config);
  19 +
  20 + ~SileroVadModelRknn() override;
  21 +
  22 + // reset the internal model states
  23 + void Reset() override;
  24 +
  25 + /**
  26 + * @param samples Pointer to a 1-d array containing audio samples.
  27 + * Each sample should be normalized to the range [-1, 1].
  28 + * @param n Number of samples.
  29 + *
  30 + * @return Return true if speech is detected. Return false otherwise.
  31 + */
  32 + bool IsSpeech(const float *samples, int32_t n) override;
  33 +
  34 + // For silero vad V4, it is WindowShift().
  35 + int32_t WindowSize() const override;
  36 +
  37 + // 512
  38 + int32_t WindowShift() const override;
  39 +
  40 + int32_t MinSilenceDurationSamples() const override;
  41 + int32_t MinSpeechDurationSamples() const override;
  42 +
  43 + void SetMinSilenceDuration(float s) override;
  44 + void SetThreshold(float threshold) override;
  45 +
  46 + private:
  47 + class Impl;
  48 + std::unique_ptr<Impl> impl_;
  49 +};
  50 +
  51 +} // namespace sherpa_onnx
  52 +
  53 +#endif // SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_
@@ -129,15 +129,13 @@ as the device_name. @@ -129,15 +129,13 @@ as the device_name.
129 exit(-1); 129 exit(-1);
130 } 130 }
131 131
132 - int32_t chunk = 0.1 * alsa.GetActualSampleRate();  
133 -  
134 fprintf(stderr, "Started. Please speak\n"); 132 fprintf(stderr, "Started. Please speak\n");
135 133
136 int32_t window_size = vad_config.silero_vad.window_size; 134 int32_t window_size = vad_config.silero_vad.window_size;
137 int32_t index = 0; 135 int32_t index = 0;
138 136
139 while (!stop) { 137 while (!stop) {
140 - const std::vector<float> &samples = alsa.Read(chunk); 138 + const std::vector<float> &samples = alsa.Read(window_size);
141 vad->AcceptWaveform(samples.data(), samples.size()); 139 vad->AcceptWaveform(samples.data(), samples.size());
142 140
143 while (!vad->Empty()) { 141 while (!vad->Empty()) {
@@ -7,6 +7,9 @@ @@ -7,6 +7,9 @@
7 #include <sstream> 7 #include <sstream>
8 #include <string> 8 #include <string>
9 9
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +#include "sherpa-onnx/csrc/text-utils.h"
  12 +
10 namespace sherpa_onnx { 13 namespace sherpa_onnx {
11 14
12 void VadModelConfig::Register(ParseOptions *po) { 15 void VadModelConfig::Register(ParseOptions *po) {
@@ -26,7 +29,27 @@ void VadModelConfig::Register(ParseOptions *po) { @@ -26,7 +29,27 @@ void VadModelConfig::Register(ParseOptions *po) {
26 "true to display debug information when loading vad models"); 29 "true to display debug information when loading vad models");
27 } 30 }
28 31
29 -bool VadModelConfig::Validate() const { return silero_vad.Validate(); } 32 +bool VadModelConfig::Validate() const {
  33 + if (provider != "rknn") {
  34 + if (!silero_vad.model.empty() && EndsWith(silero_vad.model, ".rknn")) {
  35 + SHERPA_ONNX_LOGE(
  36 + "--provider is %s, which is not rknn, but you pass an rknn model "
  37 + "'%s'",
  38 + provider.c_str(), silero_vad.model.c_str());
  39 + return false;
  40 + }
  41 + }
  42 +
  43 + if (provider == "rknn") {
  44 + if (!silero_vad.model.empty() && EndsWith(silero_vad.model, ".onnx")) {
  45 + SHERPA_ONNX_LOGE("--provider is rknn, but you pass an onnx model '%s'",
  46 + silero_vad.model.c_str());
  47 + return false;
  48 + }
  49 + }
  50 +
  51 + return silero_vad.Validate();
  52 +}
30 53
31 std::string VadModelConfig::ToString() const { 54 std::string VadModelConfig::ToString() const {
32 std::ostringstream os; 55 std::ostringstream os;
@@ -13,19 +13,27 @@ @@ -13,19 +13,27 @@
13 #include "rawfile/raw_file_manager.h" 13 #include "rawfile/raw_file_manager.h"
14 #endif 14 #endif
15 15
#if SHERPA_ONNX_ENABLE_RKNN
#include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/silero-vad-model.h"
17 21
18 namespace sherpa_onnx { 22 namespace sherpa_onnx {
19 23
20 std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) { 24 std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
21 - // TODO(fangjun): Support other VAD models. 25 + if (config.provider == "rknn") {
  26 + return std::make_unique<SileroVadModelRknn>(config);
  27 + }
22 return std::make_unique<SileroVadModel>(config); 28 return std::make_unique<SileroVadModel>(config);
23 } 29 }
24 30
25 template <typename Manager> 31 template <typename Manager>
26 std::unique_ptr<VadModel> VadModel::Create(Manager *mgr, 32 std::unique_ptr<VadModel> VadModel::Create(Manager *mgr,
27 const VadModelConfig &config) { 33 const VadModelConfig &config) {
28 - // TODO(fangjun): Support other VAD models. 34 + if (config.provider == "rknn") {
  35 + return std::make_unique<SileroVadModelRknn>(mgr, config);
  36 + }
29 return std::make_unique<SileroVadModel>(mgr, config); 37 return std::make_unique<SileroVadModel>(mgr, config);
30 } 38 }
31 39