Fangjun Kuang
Committed by GitHub

Add VAD + Non-streaming ASR Python example. (#332)

python-api-examples/README.md

@@ -7,3 +7,6 @@
 - [vad-remove-non-speech-segments.py](./vad-remove-non-speech-segments.py) It uses
   [silero-vad](https://github.com/snakers4/silero-vad) to remove non-speech
   segments and concatenate all speech segments into a single one.
+- [vad-with-non-streaming-asr.py](./vad-with-non-streaming-asr.py) It shows
+  how to use VAD with a non-streaming ASR model for speech recognition from
+  a microphone.
python-api-examples/vad-with-non-streaming-asr.py (new file)

+#!/usr/bin/env python3
+#
+# Copyright (c)  2023  Xiaomi Corporation
+
+"""
+This file demonstrates how to use sherpa-onnx Python APIs
+with VAD and non-streaming ASR models for speech recognition
+from a microphone.
+
+Note that you need a non-streaming model for this script.
+
+(1) For paraformer
+
+    ./python-api-examples/vad-with-non-streaming-asr.py \
+      --silero-vad-model=/path/to/silero_vad.onnx \
+      --tokens=/path/to/tokens.txt \
+      --paraformer=/path/to/paraformer.onnx \
+      --num-threads=2 \
+      --decoding-method=greedy_search \
+      --debug=false \
+      --sample-rate=16000 \
+      --feature-dim=80
+
+(2) For transducer models from icefall
+
+    ./python-api-examples/vad-with-non-streaming-asr.py \
+      --silero-vad-model=/path/to/silero_vad.onnx \
+      --tokens=/path/to/tokens.txt \
+      --encoder=/path/to/encoder.onnx \
+      --decoder=/path/to/decoder.onnx \
+      --joiner=/path/to/joiner.onnx \
+      --num-threads=2 \
+      --decoding-method=greedy_search \
+      --debug=false \
+      --sample-rate=16000 \
+      --feature-dim=80
+
+(3) For Whisper models
+
+    ./python-api-examples/vad-with-non-streaming-asr.py \
+      --silero-vad-model=/path/to/silero_vad.onnx \
+      --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+      --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+      --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+      --whisper-task=transcribe \
+      --num-threads=2
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/index.html
+to install sherpa-onnx and to download the non-streaming pre-trained models
+used in this file.
+
+Please visit
+https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
+to download silero_vad.onnx.
+
+For instance,
+
+  wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx
+"""
+import argparse
+import sys
+from pathlib import Path
+
+import numpy as np
+
+try:
+    import sounddevice as sd
+except ImportError:
+    print("Please install sounddevice first. You can use")
+    print()
+    print("  pip install sounddevice")
+    print()
+    print("to install it")
+    sys.exit(-1)
+
+import sherpa_onnx
+
+
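+# Overall flow of this example: read ~100 ms chunks from the microphone,
+# feed them to silero-vad in fixed-size windows, and whenever the VAD
+# reports a finished speech segment, run the non-streaming recognizer on
+# that segment and print the result.
+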
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--silero-vad-model",
+        type=str,
+        required=True,
+        help="Path to silero_vad.onnx",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument(
+        "--encoder",
+        default="",
+        type=str,
+        help="Path to the transducer encoder model",
+    )
+
+    parser.add_argument(
+        "--decoder",
+        default="",
+        type=str,
+        help="Path to the transducer decoder model",
+    )
+
+    parser.add_argument(
+        "--joiner",
+        default="",
+        type=str,
+        help="Path to the transducer joiner model",
+    )
+
+    parser.add_argument(
+        "--paraformer",
+        default="",
+        type=str,
+        help="Path to the model.onnx from Paraformer",
+    )
+
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="Number of threads for neural network computation",
+    )
+
+    parser.add_argument(
+        "--whisper-encoder",
+        default="",
+        type=str,
+        help="Path to whisper encoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-decoder",
+        default="",
+        type=str,
+        help="Path to whisper decoder model",
+    )
+
+    parser.add_argument(
+        "--whisper-language",
+        default="",
+        type=str,
+        help="""It specifies the spoken language in the input audio.
+        Example values: en, fr, de, zh, ja.
+        Available languages for multilingual models can be found at
+        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
+        If not specified, we infer the language from the input audio.
+        """,
+    )
+
+    parser.add_argument(
+        "--whisper-task",
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        type=str,
+        help="""For multilingual models, if you specify translate, the output
+        will be in English.
+        """,
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Valid values are greedy_search and modified_beam_search.
+        modified_beam_search is valid only for transducer models.
+        """,
+    )
+
+    parser.add_argument(
+        "--debug",
+        # argparse's type=bool would treat any non-empty string, including
+        # "false", as True, so parse the value explicitly.
+        type=lambda s: s.strip().lower() in ("true", "t", "yes", "1"),
+        default=False,
+        help="True to show debug messages when loading models.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="""Sample rate of the feature extractor. Must match the one
+        expected by the model.""",
+    )
+
+    parser.add_argument(
+        "--feature-dim",
+        type=int,
+        default=80,
+        help="Feature dimension. Must match the one expected by the model",
+    )
+
+    return parser.parse_args()
+
+
+def assert_file_exists(filename: str):
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    )
+
+
+def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
+    if args.encoder:
+        assert len(args.paraformer) == 0, args.paraformer
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
+        assert_file_exists(args.encoder)
+        assert_file_exists(args.decoder)
+        assert_file_exists(args.joiner)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
+            encoder=args.encoder,
+            decoder=args.decoder,
+            joiner=args.joiner,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feature_dim,
+            decoding_method=args.decoding_method,
+            debug=args.debug,
+        )
+    elif args.paraformer:
+        assert len(args.whisper_encoder) == 0, args.whisper_encoder
+        assert len(args.whisper_decoder) == 0, args.whisper_decoder
+
+        assert_file_exists(args.paraformer)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
+            paraformer=args.paraformer,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            sample_rate=args.sample_rate,
+            feature_dim=args.feature_dim,
+            decoding_method=args.decoding_method,
+            debug=args.debug,
+        )
+    elif args.whisper_encoder:
+        assert_file_exists(args.whisper_encoder)
+        assert_file_exists(args.whisper_decoder)
+
+        recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
+            encoder=args.whisper_encoder,
+            decoder=args.whisper_decoder,
+            tokens=args.tokens,
+            num_threads=args.num_threads,
+            decoding_method=args.decoding_method,
+            debug=args.debug,
+            language=args.whisper_language,
+            task=args.whisper_task,
+        )
+    else:
+        raise ValueError("Please specify at least one model")
+
+    return recognizer
+
+
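+# A quick way to sanity-check the recognizer returned above without a
+# microphone is to decode a wav file directly. This is only a sketch, not
+# part of this example: it assumes a mono "test.wav" and that the
+# soundfile package is installed; it reuses the same create_stream /
+# decode_stream APIs demonstrated in main() below.
+#
+#   import soundfile as sf
+#
+#   samples, sample_rate = sf.read("test.wav", dtype="float32")
+#   stream = recognizer.create_stream()
+#   stream.accept_waveform(sample_rate, samples)
+#   recognizer.decode_stream(stream)
+#   print(stream.result.text)
+
+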
+def main():
+    devices = sd.query_devices()
+    if len(devices) == 0:
+        print("No microphone devices found")
+        sys.exit(0)
+
+    print(devices)
+
+    # If you want to select a different input device, please use
+    # sd.default.device[0] = xxx
+    # where xxx is the device number
+
+    default_input_device_idx = sd.default.device[0]
+    print(f'Use default device: {devices[default_input_device_idx]["name"]}')
+
+    args = get_args()
+    assert_file_exists(args.tokens)
+    assert_file_exists(args.silero_vad_model)
+
+    assert args.num_threads > 0, args.num_threads
+
+    assert (
+        args.sample_rate == 16000
+    ), f"Only sample rate 16000 is supported. Given: {args.sample_rate}"
+
+    print("Creating recognizer. Please wait...")
+    recognizer = create_recognizer(args)
+
+    config = sherpa_onnx.VadModelConfig()
+    config.silero_vad.model = args.silero_vad_model
+    config.silero_vad.min_silence_duration = 0.25
+    config.sample_rate = args.sample_rate
+
+    window_size = config.silero_vad.window_size
+
+    vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=100)
+
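+    # Notes on the configuration above: min_silence_duration = 0.25 means a
+    # pause of at least 0.25 s closes the current speech segment, and
+    # window_size is the fixed number of samples silero-vad consumes per
+    # call, which is why the loop below feeds it window_size samples at a
+    # time.
+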
+    samples_per_read = int(0.1 * args.sample_rate)  # 0.1 second = 100 ms
+
+    print("Started! Please speak")
+
+    buffer = np.array([], dtype=np.float32)  # samples not yet fed to the VAD
+    texts = []
+    with sd.InputStream(channels=1, dtype="float32", samplerate=args.sample_rate) as s:
+        while True:
+            samples, _ = s.read(samples_per_read)  # a blocking read
+            samples = samples.reshape(-1)
+
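+            # silero-vad consumes audio in fixed-size windows; buffer the
+            # incoming chunks and feed it exactly window_size samples at a
+            # time, keeping the remainder for the next read.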
+            buffer = np.concatenate([buffer, samples])
+            while len(buffer) > window_size:
+                vad.accept_waveform(buffer[:window_size])
+                buffer = buffer[window_size:]
+
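+            # Segments whose end has been detected are queued inside the
+            # detector; drain the queue and run the non-streaming
+            # recognizer on each complete segment.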
+            while not vad.empty():
+                stream = recognizer.create_stream()
+                stream.accept_waveform(args.sample_rate, vad.front.samples)
+
+                vad.pop()
+                recognizer.decode_stream(stream)
+
+                text = stream.result.text.strip().lower()
+                if len(text):
+                    idx = len(texts)
+                    texts.append(text)
+                    print(f"{idx}: {text}")
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except KeyboardInterrupt:
+        print("\nCaught Ctrl + C. Exiting")