正在显示
7 个修改的文件
包含
571 行增加
和
2 行删除
| 1 | +name: export-silero-vad-to-rknn | ||
| 2 | + | ||
| 3 | +on: | ||
| 4 | + workflow_dispatch: | ||
| 5 | + | ||
| 6 | +concurrency: | ||
| 7 | + group: export-silero-vad-to-rknn-${{ github.ref }} | ||
| 8 | + cancel-in-progress: true | ||
| 9 | + | ||
| 10 | +jobs: | ||
| 11 | + export-silero-vad-to-rknn: | ||
| 12 | + if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' | ||
| 13 | + name: export silero-vad to rknn | ||
| 14 | + runs-on: ${{ matrix.os }} | ||
| 15 | + strategy: | ||
| 16 | + fail-fast: false | ||
| 17 | + matrix: | ||
| 18 | + os: [ubuntu-latest] | ||
| 19 | + python-version: ["3.10"] | ||
| 20 | + | ||
| 21 | + steps: | ||
| 22 | + - uses: actions/checkout@v4 | ||
| 23 | + | ||
| 24 | + - name: Setup Python ${{ matrix.python-version }} | ||
| 25 | + uses: actions/setup-python@v5 | ||
| 26 | + with: | ||
| 27 | + python-version: ${{ matrix.python-version }} | ||
| 28 | + | ||
| 29 | + - name: Install Python dependencies | ||
| 30 | + shell: bash | ||
| 31 | + run: | | ||
| 32 | + python3 -m pip install --upgrade \ | ||
| 33 | + pip \ | ||
| 34 | + "numpy<2" \ | ||
| 35 | + torch==2.0.0+cpu -f https://download.pytorch.org/whl/torch \ | ||
| 36 | + onnx \ | ||
| 37 | + onnxruntime==1.17.1 \ | ||
| 38 | + librosa \ | ||
| 39 | + soundfile \ | ||
| 40 | + onnxsim | ||
| 41 | + | ||
| 42 | + curl -SL -O https://huggingface.co/csukuangfj/rknn-toolkit2/resolve/main/rknn_toolkit2-2.1.0%2B708089d1-cp310-cp310-linux_x86_64.whl | ||
| 43 | + pip install ./*.whl "numpy<=1.26.4" | ||
| 44 | + | ||
| 45 | + - name: Run | ||
| 46 | + shell: bash | ||
| 47 | + run: | | ||
| 48 | + cd scripts/silero_vad/v4 | ||
| 49 | + curl -SL -O https://github.com/snakers4/silero-vad/raw/refs/tags/v4.0/files/silero_vad.jit | ||
| 50 | + ./export-onnx.py | ||
| 51 | + ./show.py | ||
| 52 | + | ||
| 53 | + ls -lh m.onnx | ||
| 54 | + | ||
| 55 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav | ||
| 56 | + ./test-onnx.py --model ./m.onnx --wav ./lei-jun-test.wav | ||
| 57 | + | ||
| 58 | + for platform in rk3588 rk3576 rk3568 rk3566 rk3562; do | ||
| 59 | + echo "Platform: $platform" | ||
| 60 | + ./export-rknn.py --in-model ./m.onnx --out-model silero-vad-v4-$platform.rknn --target-platform $platform | ||
| 61 | + ls -lh silero-vad-v4-$platform.rknn | ||
| 62 | + done | ||
| 63 | + | ||
| 64 | + - name: Collect files | ||
| 65 | + shell: bash | ||
| 66 | + run: | | ||
| 67 | + cd scripts/silero_vad/v4 | ||
| 68 | + ls -lh | ||
| 69 | + mv *.rknn ../../.. | ||
| 70 | + | ||
| 71 | + - name: Release | ||
| 72 | + uses: svenstaro/upload-release-action@v2 | ||
| 73 | + with: | ||
| 74 | + file_glob: true | ||
| 75 | + file: ./*.rknn | ||
| 76 | + overwrite: true | ||
| 77 | + repo_name: k2-fsa/sherpa-onnx | ||
| 78 | + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} | ||
| 79 | + tag: asr-models | ||
| 80 | + | ||
| 81 | + - name: Upload model to huggingface | ||
| 82 | + env: | ||
| 83 | + HF_TOKEN: ${{ secrets.HF_TOKEN }} | ||
| 84 | + uses: nick-fields/retry@v3 | ||
| 85 | + with: | ||
| 86 | + max_attempts: 20 | ||
| 87 | + timeout_seconds: 200 | ||
| 88 | + shell: bash | ||
| 89 | + command: | | ||
| 90 | + git config --global user.email "csukuangfj@gmail.com" | ||
| 91 | + git config --global user.name "Fangjun Kuang" | ||
| 92 | + | ||
| 93 | + rm -rf huggingface | ||
| 94 | + export GIT_LFS_SKIP_SMUDGE=1 | ||
| 95 | + | ||
| 96 | + git clone https://huggingface.co/csukuangfj/sherpa-onnx-rknn-models huggingface | ||
| 97 | + cd huggingface | ||
| 98 | + | ||
| 99 | + git fetch | ||
| 100 | + git pull | ||
| 101 | + git lfs track "*.rknn" | ||
| 102 | + git merge -m "merge remote" --ff origin main | ||
| 103 | + dst=vad | ||
| 104 | + mkdir -p $dst | ||
| 105 | + cp ../*.rknn $dst/ || true | ||
| 106 | + | ||
| 107 | + ls -lh $dst | ||
| 108 | + git add . | ||
| 109 | + git status | ||
| 110 | + git commit -m "update models" | ||
| 111 | + git status | ||
| 112 | + | ||
| 113 | + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-rknn-models main || true | ||
| 114 | + rm -rf huggingface |
| @@ -136,6 +136,7 @@ kokoro-multi-lang-v1_0 | @@ -136,6 +136,7 @@ kokoro-multi-lang-v1_0 | ||
| 136 | sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16 | 136 | sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16 |
| 137 | cmake-build-debug | 137 | cmake-build-debug |
| 138 | README-DEV.txt | 138 | README-DEV.txt |
| 139 | - | 139 | +*.rknn |
| 140 | +*.jit | ||
| 140 | ##clion | 141 | ##clion |
| 141 | -.idea | ||
| 142 | +.idea |
scripts/silero_vad/v4/README.md
0 → 100644
| 1 | +# Introduction | ||
| 2 | + | ||
| 3 | +This folder contains script for exporting | ||
| 4 | +[silero_vad v4](https://github.com/snakers4/silero-vad/tree/v4.0) | ||
| 5 | +to rknn. | ||
| 6 | + | ||
| 7 | +# Steps to run | ||
| 8 | + | ||
| 9 | +## 1. Download a jit model | ||
| 10 | +You can download it from <https://github.com/snakers4/silero-vad/blob/v4.0/files/silero_vad.jit> | ||
| 11 | + | ||
| 12 | +```bash | ||
| 13 | +wget https://github.com/snakers4/silero-vad/raw/refs/tags/v4.0/files/silero_vad.jit | ||
| 14 | +``` | ||
| 15 | + | ||
| 16 | +```bash | ||
| 17 | +ls -lh silero_vad.jit | ||
| 18 | +-rw-r--r-- 1 kuangfangjun root 1.4M Mar 30 11:04 silero_vad.jit | ||
| 19 | +``` | ||
| 20 | + | ||
| 21 | +## 2. Export it to onnx | ||
| 22 | +```bash | ||
| 23 | +./export-onnx.py | ||
| 24 | +``` | ||
| 25 | + | ||
| 26 | +It will generate a file `./m.onnx` | ||
| 27 | + | ||
| 28 | +```bash | ||
| 29 | + ls -lh m.onnx | ||
| 30 | +-rw-r--r-- 1 kuangfangjun root 627K Mar 30 11:13 m.onnx | ||
| 31 | +``` | ||
| 32 | + | ||
| 33 | +## 3. Test the onnx model | ||
| 34 | + | ||
| 35 | +```bash | ||
| 36 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav | ||
| 37 | +./test-onnx.py --model ./m.onnx --wav ./lei-jun-test.wav | ||
| 38 | +``` | ||
| 39 | + | ||
| 40 | +## 4. Convert the onnx model to RKNN format | ||
| 41 | + | ||
| 42 | +We assume you have installed rknn toolkit 2.1 | ||
| 43 | +```bash | ||
| 44 | +./export-rknn.py --in-model ./m.onnx --out-model m.rknn --target-platform rk3588 | ||
| 45 | +``` | ||
| 46 | + | ||
| 47 | +It will generate a file `./m.rknn` | ||
| 48 | + | ||
| 49 | +```bash | ||
| 50 | +ls -lh m.rknn | ||
| 51 | +-rw-r--r-- 1 kuangfangjun root 2.2M Mar 30 11:19 m.rknn | ||
| 52 | +``` |
scripts/silero_vad/v4/export-onnx.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +import onnx | ||
| 5 | +import torch | ||
| 6 | +from onnxsim import simplify | ||
| 7 | + | ||
| 8 | + | ||
| 9 | +@torch.no_grad() | ||
| 10 | +def main(): | ||
| 11 | + m = torch.jit.load("./silero_vad.jit") | ||
| 12 | + x = torch.rand((1, 512), dtype=torch.float32) | ||
| 13 | + h = torch.rand((2, 1, 64), dtype=torch.float32) | ||
| 14 | + c = torch.rand((2, 1, 64), dtype=torch.float32) | ||
| 15 | + torch.onnx.export( | ||
| 16 | + m._model, | ||
| 17 | + (x, h, c), | ||
| 18 | + "m.onnx", | ||
| 19 | + input_names=["x", "h", "c"], | ||
| 20 | + output_names=["prob", "next_h", "next_c"], | ||
| 21 | + ) | ||
| 22 | + | ||
| 23 | + print("simplifying ...") | ||
| 24 | + model = onnx.load("m.onnx") | ||
| 25 | + | ||
| 26 | + meta_data = { | ||
| 27 | + "model_type": "silero-vad-v4", | ||
| 28 | + "sample_rate": 16000, | ||
| 29 | + "version": 4, | ||
| 30 | + "h_shape": "2,1,64", | ||
| 31 | + "c_shape": "2,1,64", | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + while len(model.metadata_props): | ||
| 35 | + model.metadata_props.pop() | ||
| 36 | + | ||
| 37 | + for key, value in meta_data.items(): | ||
| 38 | + meta = model.metadata_props.add() | ||
| 39 | + meta.key = key | ||
| 40 | + meta.value = str(value) | ||
| 41 | + print("--------------------") | ||
| 42 | + print(model.metadata_props) | ||
| 43 | + | ||
| 44 | + model_simp, check = simplify(model) | ||
| 45 | + onnx.save(model_simp, "m.onnx") | ||
| 46 | + | ||
| 47 | + | ||
| 48 | +if __name__ == "__main__": | ||
| 49 | + main() |
scripts/silero_vad/v4/export-rknn.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +import argparse | ||
| 5 | +import logging | ||
| 6 | +from pathlib import Path | ||
| 7 | + | ||
| 8 | +from rknn.api import RKNN | ||
| 9 | + | ||
| 10 | +logging.basicConfig(level=logging.WARNING) | ||
| 11 | + | ||
| 12 | +g_platforms = [ | ||
| 13 | + # "rv1103", | ||
| 14 | + # "rv1103b", | ||
| 15 | + # "rv1106", | ||
| 16 | + # "rk2118", | ||
| 17 | + "rk3562", | ||
| 18 | + "rk3566", | ||
| 19 | + "rk3568", | ||
| 20 | + "rk3576", | ||
| 21 | + "rk3588", | ||
| 22 | +] | ||
| 23 | + | ||
| 24 | + | ||
| 25 | +def get_parser(): | ||
| 26 | + parser = argparse.ArgumentParser( | ||
| 27 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 28 | + ) | ||
| 29 | + | ||
| 30 | + parser.add_argument( | ||
| 31 | + "--target-platform", | ||
| 32 | + type=str, | ||
| 33 | + required=True, | ||
| 34 | + help=f"Supported values are: {','.join(g_platforms)}", | ||
| 35 | + ) | ||
| 36 | + | ||
| 37 | + parser.add_argument( | ||
| 38 | + "--in-model", | ||
| 39 | + type=str, | ||
| 40 | + required=True, | ||
| 41 | + help="Path to the input onnx model", | ||
| 42 | + ) | ||
| 43 | + | ||
| 44 | + parser.add_argument( | ||
| 45 | + "--out-model", | ||
| 46 | + type=str, | ||
| 47 | + required=True, | ||
| 48 | + help="Path to the output rknn model", | ||
| 49 | + ) | ||
| 50 | + | ||
| 51 | + return parser | ||
| 52 | + | ||
| 53 | + | ||
| 54 | +def get_meta_data(model: str): | ||
| 55 | + import onnxruntime | ||
| 56 | + | ||
| 57 | + session_opts = onnxruntime.SessionOptions() | ||
| 58 | + session_opts.inter_op_num_threads = 1 | ||
| 59 | + session_opts.intra_op_num_threads = 1 | ||
| 60 | + | ||
| 61 | + m = onnxruntime.InferenceSession( | ||
| 62 | + model, | ||
| 63 | + sess_options=session_opts, | ||
| 64 | + providers=["CPUExecutionProvider"], | ||
| 65 | + ) | ||
| 66 | + | ||
| 67 | + for i in m.get_inputs(): | ||
| 68 | + print(i) | ||
| 69 | + | ||
| 70 | + print("-----") | ||
| 71 | + | ||
| 72 | + for i in m.get_outputs(): | ||
| 73 | + print(i) | ||
| 74 | + print() | ||
| 75 | + | ||
| 76 | + meta = m.get_modelmeta().custom_metadata_map | ||
| 77 | + s = "" | ||
| 78 | + sep = "" | ||
| 79 | + for key, value in meta.items(): | ||
| 80 | + s = s + sep + f"{key}={value}" | ||
| 81 | + sep = ";" | ||
| 82 | + assert len(s) < 1024 | ||
| 83 | + | ||
| 84 | + return s | ||
| 85 | + | ||
| 86 | + | ||
| 87 | +def export_rknn(rknn, filename): | ||
| 88 | + ret = rknn.export_rknn(filename) | ||
| 89 | + if ret != 0: | ||
| 90 | + exit("Export rknn model to {filename} failed!") | ||
| 91 | + | ||
| 92 | + | ||
| 93 | +def init_model(filename: str, target_platform: str, custom_string=None): | ||
| 94 | + rknn = RKNN(verbose=False) | ||
| 95 | + | ||
| 96 | + rknn.config( | ||
| 97 | + optimization_level=0, | ||
| 98 | + target_platform=target_platform, | ||
| 99 | + custom_string=custom_string, | ||
| 100 | + ) | ||
| 101 | + if not Path(filename).is_file(): | ||
| 102 | + exit(f"{filename} does not exist") | ||
| 103 | + | ||
| 104 | + ret = rknn.load_onnx(model=filename) | ||
| 105 | + if ret != 0: | ||
| 106 | + exit(f"Load model {filename} failed!") | ||
| 107 | + | ||
| 108 | + ret = rknn.build(do_quantization=False) | ||
| 109 | + if ret != 0: | ||
| 110 | + exit("Build model {filename} failed!") | ||
| 111 | + | ||
| 112 | + return rknn | ||
| 113 | + | ||
| 114 | + | ||
| 115 | +class RKNNModel: | ||
| 116 | + def __init__( | ||
| 117 | + self, | ||
| 118 | + model: str, | ||
| 119 | + target_platform: str, | ||
| 120 | + ): | ||
| 121 | + meta = get_meta_data(model) | ||
| 122 | + print(meta) | ||
| 123 | + | ||
| 124 | + self.model = init_model( | ||
| 125 | + model, | ||
| 126 | + target_platform=target_platform, | ||
| 127 | + custom_string=meta, | ||
| 128 | + ) | ||
| 129 | + | ||
| 130 | + def export_rknn(self, model): | ||
| 131 | + export_rknn(self.model, model) | ||
| 132 | + | ||
| 133 | + def release(self): | ||
| 134 | + self.model.release() | ||
| 135 | + | ||
| 136 | + | ||
| 137 | +def main(): | ||
| 138 | + args = get_parser().parse_args() | ||
| 139 | + print(vars(args)) | ||
| 140 | + | ||
| 141 | + model = RKNNModel( | ||
| 142 | + model=args.in_model, | ||
| 143 | + target_platform=args.target_platform, | ||
| 144 | + ) | ||
| 145 | + | ||
| 146 | + model.export_rknn( | ||
| 147 | + model=args.out_model, | ||
| 148 | + ) | ||
| 149 | + | ||
| 150 | + model.release() | ||
| 151 | + | ||
| 152 | + | ||
| 153 | +if __name__ == "__main__": | ||
| 154 | + main() |
scripts/silero_vad/v4/show.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | + | ||
| 4 | +import onnxruntime | ||
| 5 | +import onnx | ||
| 6 | + | ||
| 7 | +""" | ||
| 8 | +[key: "model_type" | ||
| 9 | +value: "silero-vad-v4" | ||
| 10 | +, key: "sample_rate" | ||
| 11 | +value: "16000" | ||
| 12 | +, key: "version" | ||
| 13 | +value: "4" | ||
| 14 | +, key: "h_shape" | ||
| 15 | +value: "2,1,64" | ||
| 16 | +, key: "c_shape" | ||
| 17 | +value: "2,1,64" | ||
| 18 | +] | ||
| 19 | +NodeArg(name='x', type='tensor(float)', shape=[1, 512]) | ||
| 20 | +NodeArg(name='h', type='tensor(float)', shape=[2, 1, 64]) | ||
| 21 | +NodeArg(name='c', type='tensor(float)', shape=[2, 1, 64]) | ||
| 22 | +----- | ||
| 23 | +NodeArg(name='prob', type='tensor(float)', shape=[1, 1]) | ||
| 24 | +NodeArg(name='next_h', type='tensor(float)', shape=[2, 1, 64]) | ||
| 25 | +NodeArg(name='next_c', type='tensor(float)', shape=[2, 1, 64]) | ||
| 26 | +""" | ||
| 27 | + | ||
| 28 | + | ||
| 29 | +def show(filename): | ||
| 30 | + model = onnx.load(filename) | ||
| 31 | + print(model.metadata_props) | ||
| 32 | + | ||
| 33 | + session_opts = onnxruntime.SessionOptions() | ||
| 34 | + session_opts.log_severity_level = 3 | ||
| 35 | + sess = onnxruntime.InferenceSession( | ||
| 36 | + filename, session_opts, providers=["CPUExecutionProvider"] | ||
| 37 | + ) | ||
| 38 | + for i in sess.get_inputs(): | ||
| 39 | + print(i) | ||
| 40 | + | ||
| 41 | + print("-----") | ||
| 42 | + | ||
| 43 | + for i in sess.get_outputs(): | ||
| 44 | + print(i) | ||
| 45 | + | ||
| 46 | + | ||
| 47 | +def main(): | ||
| 48 | + show("./m.onnx") | ||
| 49 | + | ||
| 50 | + | ||
| 51 | +if __name__ == "__main__": | ||
| 52 | + main() |
scripts/silero_vad/v4/test-onnx.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) | ||
| 3 | +import onnxruntime as ort | ||
| 4 | +import argparse | ||
| 5 | +import soundfile as sf | ||
| 6 | +from typing import Tuple | ||
| 7 | +import numpy as np | ||
| 8 | + | ||
| 9 | + | ||
| 10 | +def get_args(): | ||
| 11 | + parser = argparse.ArgumentParser() | ||
| 12 | + parser.add_argument( | ||
| 13 | + "--model", | ||
| 14 | + type=str, | ||
| 15 | + required=True, | ||
| 16 | + help="Path to the onnx model", | ||
| 17 | + ) | ||
| 18 | + | ||
| 19 | + parser.add_argument( | ||
| 20 | + "--wav", | ||
| 21 | + type=str, | ||
| 22 | + required=True, | ||
| 23 | + help="Path to the input wav", | ||
| 24 | + ) | ||
| 25 | + return parser.parse_args() | ||
| 26 | + | ||
| 27 | + | ||
| 28 | +class OnnxModel: | ||
| 29 | + def __init__( | ||
| 30 | + self, | ||
| 31 | + model: str, | ||
| 32 | + ): | ||
| 33 | + session_opts = ort.SessionOptions() | ||
| 34 | + session_opts.inter_op_num_threads = 1 | ||
| 35 | + session_opts.intra_op_num_threads = 1 | ||
| 36 | + self.model = ort.InferenceSession( | ||
| 37 | + model, | ||
| 38 | + sess_options=session_opts, | ||
| 39 | + providers=["CPUExecutionProvider"], | ||
| 40 | + ) | ||
| 41 | + | ||
| 42 | + def get_init_states(self): | ||
| 43 | + h = np.zeros((2, 1, 64), dtype=np.float32) | ||
| 44 | + c = np.zeros((2, 1, 64), dtype=np.float32) | ||
| 45 | + return h, c | ||
| 46 | + | ||
| 47 | + def __call__(self, x, h, c): | ||
| 48 | + """ | ||
| 49 | + Args: | ||
| 50 | + x: (1, 512) | ||
| 51 | + h: (2, 1, 64) | ||
| 52 | + c: (2, 1, 64) | ||
| 53 | + Returns: | ||
| 54 | + prob: (1, 1) | ||
| 55 | + next_h: (2, 1, 64) | ||
| 56 | + next_c: (2, 1, 64) | ||
| 57 | + """ | ||
| 58 | + x = x[None] | ||
| 59 | + out, next_h, next_c = self.model.run( | ||
| 60 | + [ | ||
| 61 | + self.model.get_outputs()[0].name, | ||
| 62 | + self.model.get_outputs()[1].name, | ||
| 63 | + self.model.get_outputs()[2].name, | ||
| 64 | + ], | ||
| 65 | + { | ||
| 66 | + self.model.get_inputs()[0].name: x, | ||
| 67 | + self.model.get_inputs()[1].name: h, | ||
| 68 | + self.model.get_inputs()[2].name: c, | ||
| 69 | + }, | ||
| 70 | + ) | ||
| 71 | + return out, next_h, next_c | ||
| 72 | + | ||
| 73 | + | ||
| 74 | +def load_audio(filename: str) -> Tuple[np.ndarray, int]: | ||
| 75 | + data, sample_rate = sf.read( | ||
| 76 | + filename, | ||
| 77 | + always_2d=True, | ||
| 78 | + dtype="float32", | ||
| 79 | + ) | ||
| 80 | + data = data[:, 0] # use only the first channel | ||
| 81 | + samples = np.ascontiguousarray(data) | ||
| 82 | + return samples, sample_rate | ||
| 83 | + | ||
| 84 | + | ||
| 85 | +def main(): | ||
| 86 | + args = get_args() | ||
| 87 | + | ||
| 88 | + samples, sample_rate = load_audio(args.wav) | ||
| 89 | + if sample_rate != 16000: | ||
| 90 | + import librosa | ||
| 91 | + | ||
| 92 | + samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000) | ||
| 93 | + sample_rate = 16000 | ||
| 94 | + | ||
| 95 | + model = OnnxModel(args.model) | ||
| 96 | + probs = [] | ||
| 97 | + h, c = model.get_init_states() | ||
| 98 | + window_size = 512 | ||
| 99 | + num_windows = samples.shape[0] // window_size | ||
| 100 | + for i in range(num_windows): | ||
| 101 | + start = i * window_size | ||
| 102 | + end = start + window_size | ||
| 103 | + p, h, c = model(samples[start:end], h, c) | ||
| 104 | + probs.append(p[0].item()) | ||
| 105 | + | ||
| 106 | + threshold = 0.5 | ||
| 107 | + out = np.array(probs) > threshold | ||
| 108 | + out = out.tolist() | ||
| 109 | + min_speech_duration = 0.25 * sample_rate / window_size | ||
| 110 | + min_silence_duration = 0.25 * sample_rate / window_size | ||
| 111 | + | ||
| 112 | + result = [] | ||
| 113 | + last = -1 | ||
| 114 | + for k, f in enumerate(out): | ||
| 115 | + if f >= threshold: | ||
| 116 | + if last == -1: | ||
| 117 | + last = k | ||
| 118 | + elif last != -1: | ||
| 119 | + if k - last > min_speech_duration: | ||
| 120 | + result.append((last, k)) | ||
| 121 | + last = -1 | ||
| 122 | + | ||
| 123 | + if last != -1 and k - last > min_speech_duration: | ||
| 124 | + result.append((last, k)) | ||
| 125 | + | ||
| 126 | + if not result: | ||
| 127 | + print(f"Empty for {args.wav}") | ||
| 128 | + return | ||
| 129 | + | ||
| 130 | + print(result) | ||
| 131 | + | ||
| 132 | + final = [result[0]] | ||
| 133 | + for r in result[1:]: | ||
| 134 | + f = final[-1] | ||
| 135 | + if r[0] - f[1] < min_silence_duration: | ||
| 136 | + final[-1] = (f[0], r[1]) | ||
| 137 | + else: | ||
| 138 | + final.append(r) | ||
| 139 | + | ||
| 140 | + for f in final: | ||
| 141 | + start = f[0] * window_size / sample_rate | ||
| 142 | + end = f[1] * window_size / sample_rate | ||
| 143 | + print("{:.3f} -- {:.3f}".format(start, end)) | ||
| 144 | + | ||
| 145 | + | ||
| 146 | +if __name__ == "__main__": | ||
| 147 | + main() |
-
请 注册 或 登录 后发表评论