Fangjun Kuang
Committed by GitHub

Export silero_vad v4 to RKNN (#2067)

.github/workflows/export-silero-vad-to-rknn.yaml (new file)

name: export-silero-vad-to-rknn

on:
  workflow_dispatch:

concurrency:
  group: export-silero-vad-to-rknn-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-silero-vad-to-rknn:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export silero-vad to rknn
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        shell: bash
        run: |
          python3 -m pip install --upgrade \
            pip \
            "numpy<2" \
            torch==2.0.0+cpu -f https://download.pytorch.org/whl/torch \
            onnx \
            onnxruntime==1.17.1 \
            librosa \
            soundfile \
            onnxsim

          curl -SL -O https://huggingface.co/csukuangfj/rknn-toolkit2/resolve/main/rknn_toolkit2-2.1.0%2B708089d1-cp310-cp310-linux_x86_64.whl
          pip install ./*.whl "numpy<=1.26.4"

      - name: Run
        shell: bash
        run: |
          cd scripts/silero_vad/v4
          curl -SL -O https://github.com/snakers4/silero-vad/raw/refs/tags/v4.0/files/silero_vad.jit
          ./export-onnx.py
          ./show.py

          ls -lh m.onnx

          curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
          ./test-onnx.py --model ./m.onnx --wav ./lei-jun-test.wav

          for platform in rk3588 rk3576 rk3568 rk3566 rk3562; do
            echo "Platform: $platform"
            ./export-rknn.py --in-model ./m.onnx --out-model silero-vad-v4-$platform.rknn --target-platform $platform
            ls -lh silero-vad-v4-$platform.rknn
          done

      - name: Collect files
        shell: bash
        run: |
          cd scripts/silero_vad/v4
          ls -lh
          mv *.rknn ../../..

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.rknn
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models

      - name: Upload model to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1

            git clone https://huggingface.co/csukuangfj/sherpa-onnx-rknn-models huggingface
            cd huggingface

            git fetch
            git pull
            git lfs track "*.rknn"
            git merge -m "merge remote" --ff origin main
            dst=vad
            mkdir -p $dst
            cp ../*.rknn $dst/ || true

            ls -lh $dst
            git add .
            git status
            git commit -m "update models"
            git status

            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-rknn-models main || true
            rm -rf huggingface
.gitignore

@@ -136,6 +136,7 @@ kokoro-multi-lang-v1_0
 sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16
 cmake-build-debug
 README-DEV.txt
-
+*.rknn
+*.jit
 ##clion
-.idea
+.idea
scripts/silero_vad/v4/README.md (new file)

# Introduction

This folder contains scripts for exporting
[silero_vad v4](https://github.com/snakers4/silero-vad/tree/v4.0)
to RKNN.

# Steps to run

## 1. Download a jit model

You can download it from <https://github.com/snakers4/silero-vad/blob/v4.0/files/silero_vad.jit>

```bash
wget https://github.com/snakers4/silero-vad/raw/refs/tags/v4.0/files/silero_vad.jit
```

```bash
ls -lh silero_vad.jit
-rw-r--r-- 1 kuangfangjun root 1.4M Mar 30 11:04 silero_vad.jit
```

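If you want a quick sanity check that the downloaded file is a valid
TorchScript model, here is a minimal sketch (assuming `torch` is installed;
this snippet is not part of the scripts in this folder):

```python
import torch

# torch.jit.load fails if the file is corrupt or not a TorchScript model
m = torch.jit.load("./silero_vad.jit")
print(sum(p.numel() for p in m.parameters()), "parameters")
```
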
## 2. Export it to ONNX

```bash
./export-onnx.py
```

It will generate a file `./m.onnx`.

```bash
ls -lh m.onnx
-rw-r--r-- 1 kuangfangjun root 627K Mar 30 11:13 m.onnx
```
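
You can inspect the exported model with `./show.py`. A minimal equivalent
check of the metadata that `./export-onnx.py` writes:

```python
import onnx

model = onnx.load("m.onnx")
for p in model.metadata_props:
    print(p.key, "=", p.value)  # e.g. model_type = silero-vad-v4
```
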
## 3. Test the ONNX model

```bash
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav
./test-onnx.py --model ./m.onnx --wav ./lei-jun-test.wav
```

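`./test-onnx.py` runs the model frame by frame and prints the detected
speech segments. The model's streaming contract (the input/output names and
shapes come from `./export-onnx.py` below) is: feed one 512-sample frame of
16 kHz audio per call and carry the two LSTM states across calls. A minimal
sketch:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("m.onnx", providers=["CPUExecutionProvider"])
h = np.zeros((2, 1, 64), dtype=np.float32)  # initial LSTM hidden state
c = np.zeros((2, 1, 64), dtype=np.float32)  # initial LSTM cell state
x = np.zeros((1, 512), dtype=np.float32)  # one 32 ms frame (here: silence)
prob, h, c = sess.run(["prob", "next_h", "next_c"], {"x": x, "h": h, "c": c})
print(prob)  # speech probability for this frame, shape (1, 1)
```
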
## 4. Convert the ONNX model to RKNN format

We assume you have installed RKNN-Toolkit2 2.1.

```bash
./export-rknn.py --in-model ./m.onnx --out-model m.rknn --target-platform rk3588
```

It will generate a file `./m.rknn`.

```bash
ls -lh m.rknn
-rw-r--r-- 1 kuangfangjun root 2.2M Mar 30 11:19 m.rknn
```
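
To run the converted model on the device itself you would use
rknn-toolkit-lite2 instead of the full toolkit. A minimal sketch, assuming
the `rknnlite` package is installed on the board (this snippet is not part
of the scripts in this folder):

```python
import numpy as np
from rknnlite.api import RKNNLite

rknn = RKNNLite()
rknn.load_rknn("./m.rknn")
rknn.init_runtime()

# Same streaming contract as the onnx model: 512-sample frames at 16 kHz
h = np.zeros((2, 1, 64), dtype=np.float32)
c = np.zeros((2, 1, 64), dtype=np.float32)
x = np.zeros((1, 512), dtype=np.float32)
prob, h, c = rknn.inference(inputs=[x, h, c])
print(prob)

rknn.release()
```
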
scripts/silero_vad/v4/export-onnx.py (new file)

#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)

import onnx
import torch
from onnxsim import simplify


@torch.no_grad()
def main():
    m = torch.jit.load("./silero_vad.jit")
    x = torch.rand((1, 512), dtype=torch.float32)
    h = torch.rand((2, 1, 64), dtype=torch.float32)
    c = torch.rand((2, 1, 64), dtype=torch.float32)

    # Export the inner module, which takes the audio frame and the two
    # LSTM states explicitly
    torch.onnx.export(
        m._model,
        (x, h, c),
        "m.onnx",
        input_names=["x", "h", "c"],
        output_names=["prob", "next_h", "next_c"],
    )

    print("simplifying ...")
    model = onnx.load("m.onnx")

    meta_data = {
        "model_type": "silero-vad-v4",
        "sample_rate": 16000,
        "version": 4,
        "h_shape": "2,1,64",
        "c_shape": "2,1,64",
    }

    # Remove any existing metadata before adding ours
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    print("--------------------")
    print(model.metadata_props)

    model_simp, check = simplify(model)
    onnx.save(model_simp, "m.onnx")


if __name__ == "__main__":
    main()
scripts/silero_vad/v4/export-rknn.py (new file)

#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)

import argparse
import logging
from pathlib import Path

from rknn.api import RKNN

logging.basicConfig(level=logging.WARNING)

g_platforms = [
    #  "rv1103",
    #  "rv1103b",
    #  "rv1106",
    #  "rk2118",
    "rk3562",
    "rk3566",
    "rk3568",
    "rk3576",
    "rk3588",
]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--target-platform",
        type=str,
        required=True,
        help=f"Supported values are: {','.join(g_platforms)}",
    )

    parser.add_argument(
        "--in-model",
        type=str,
        required=True,
        help="Path to the input onnx model",
    )

    parser.add_argument(
        "--out-model",
        type=str,
        required=True,
        help="Path to the output rknn model",
    )

    return parser


def get_meta_data(model: str):
    import onnxruntime

    session_opts = onnxruntime.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    m = onnxruntime.InferenceSession(
        model,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    for i in m.get_inputs():
        print(i)

    print("-----")

    for i in m.get_outputs():
        print(i)
    print()

    # Serialize the onnx metadata as "key1=value1;key2=value2" so that it
    # can be stored in the rknn model's custom string
    meta = m.get_modelmeta().custom_metadata_map
    s = ""
    sep = ""
    for key, value in meta.items():
        s = s + sep + f"{key}={value}"
        sep = ";"
    assert len(s) < 1024

    return s


def export_rknn(rknn, filename):
    ret = rknn.export_rknn(filename)
    if ret != 0:
        exit(f"Export rknn model to {filename} failed!")


def init_model(filename: str, target_platform: str, custom_string=None):
    rknn = RKNN(verbose=False)

    rknn.config(
        optimization_level=0,
        target_platform=target_platform,
        custom_string=custom_string,
    )
    if not Path(filename).is_file():
        exit(f"{filename} does not exist")

    ret = rknn.load_onnx(model=filename)
    if ret != 0:
        exit(f"Load model {filename} failed!")

    ret = rknn.build(do_quantization=False)
    if ret != 0:
        exit(f"Build model {filename} failed!")

    return rknn


class RKNNModel:
    def __init__(
        self,
        model: str,
        target_platform: str,
    ):
        meta = get_meta_data(model)
        print(meta)

        self.model = init_model(
            model,
            target_platform=target_platform,
            custom_string=meta,
        )

    def export_rknn(self, model):
        export_rknn(self.model, model)

    def release(self):
        self.model.release()


def main():
    args = get_parser().parse_args()
    print(vars(args))

    model = RKNNModel(
        model=args.in_model,
        target_platform=args.target_platform,
    )

    model.export_rknn(
        model=args.out_model,
    )

    model.release()


if __name__ == "__main__":
    main()
scripts/silero_vad/v4/show.py (new file)

#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import onnx
import onnxruntime

# Sample output when run on ./m.onnx:
"""
[key: "model_type"
value: "silero-vad-v4"
, key: "sample_rate"
value: "16000"
, key: "version"
value: "4"
, key: "h_shape"
value: "2,1,64"
, key: "c_shape"
value: "2,1,64"
]
NodeArg(name='x', type='tensor(float)', shape=[1, 512])
NodeArg(name='h', type='tensor(float)', shape=[2, 1, 64])
NodeArg(name='c', type='tensor(float)', shape=[2, 1, 64])
-----
NodeArg(name='prob', type='tensor(float)', shape=[1, 1])
NodeArg(name='next_h', type='tensor(float)', shape=[2, 1, 64])
NodeArg(name='next_c', type='tensor(float)', shape=[2, 1, 64])
"""


def show(filename):
    model = onnx.load(filename)
    print(model.metadata_props)

    session_opts = onnxruntime.SessionOptions()
    session_opts.log_severity_level = 3
    sess = onnxruntime.InferenceSession(
        filename, session_opts, providers=["CPUExecutionProvider"]
    )
    for i in sess.get_inputs():
        print(i)

    print("-----")

    for i in sess.get_outputs():
        print(i)


def main():
    show("./m.onnx")


if __name__ == "__main__":
    main()
scripts/silero_vad/v4/test-onnx.py (new file)

#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from typing import Tuple

import numpy as np
import onnxruntime as ort
import soundfile as sf


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the onnx model",
    )

    parser.add_argument(
        "--wav",
        type=str,
        required=True,
        help="Path to the input wav",
    )
    return parser.parse_args()


class OnnxModel:
    def __init__(
        self,
        model: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        self.model = ort.InferenceSession(
            model,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def get_init_states(self):
        h = np.zeros((2, 1, 64), dtype=np.float32)
        c = np.zeros((2, 1, 64), dtype=np.float32)
        return h, c

    def __call__(self, x, h, c):
        """
        Args:
          x: (1, 512)
          h: (2, 1, 64)
          c: (2, 1, 64)
        Returns:
          prob: (1, 1)
          next_h: (2, 1, 64)
          next_c: (2, 1, 64)
        """
        x = x[None]  # (512,) -> (1, 512)
        out, next_h, next_c = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
                self.model.get_outputs()[2].name,
            ],
            {
                self.model.get_inputs()[0].name: x,
                self.model.get_inputs()[1].name: h,
                self.model.get_inputs()[2].name: c,
            },
        )
        return out, next_h, next_c


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel
    samples = np.ascontiguousarray(data)
    return samples, sample_rate


def main():
    args = get_args()

    samples, sample_rate = load_audio(args.wav)
    if sample_rate != 16000:
        import librosa

        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    # Run the model on consecutive 512-sample windows, carrying the LSTM
    # states across windows, and collect one speech probability per window
    model = OnnxModel(args.model)
    probs = []
    h, c = model.get_init_states()
    window_size = 512
    num_windows = samples.shape[0] // window_size
    for i in range(num_windows):
        start = i * window_size
        end = start + window_size
        p, h, c = model(samples[start:end], h, c)
        probs.append(p[0].item())

    threshold = 0.5
    out = np.array(probs) > threshold
    out = out.tolist()

    # Minimum speech/silence lengths, measured in windows (0.25 s each here)
    min_speech_duration = 0.25 * sample_rate / window_size
    min_silence_duration = 0.25 * sample_rate / window_size

    # Collect (start, end) window indexes of speech segments that are long
    # enough, where `last` is the start of the current candidate segment
    result = []
    last = -1
    for k, f in enumerate(out):
        if f:
            if last == -1:
                last = k
        elif last != -1:
            if k - last > min_speech_duration:
                result.append((last, k))
            last = -1

    if last != -1 and k - last > min_speech_duration:
        result.append((last, k))

    if not result:
        print(f"Empty for {args.wav}")
        return

    print(result)

    # Merge segments that are separated by a short silence
    final = [result[0]]
    for r in result[1:]:
        prev = final[-1]
        if r[0] - prev[1] < min_silence_duration:
            final[-1] = (prev[0], r[1])
        else:
            final.append(r)

    for f in final:
        start = f[0] * window_size / sample_rate
        end = f[1] * window_size / sample_rate
        print("{:.3f} -- {:.3f}".format(start, end))


if __name__ == "__main__":
    main()
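
For intuition on the duration parameters in `main()` above: each 512-sample
window at 16 kHz covers 32 ms, so the 0.25 s minimum speech/silence
durations correspond to roughly 8 windows. A quick check:

```python
window_size = 512
sample_rate = 16000
print(window_size / sample_rate)  # 0.032 -> seconds per window
print(0.25 * sample_rate / window_size)  # 7.8125 -> windows per 0.25 s
```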