Fangjun Kuang

add a two-pass python example (#303)

#!/usr/bin/env python3

# Two-pass real-time speech recognition from a microphone with the sherpa-onnx
# Python API.
#
# The first pass uses a streaming model, which serves two purposes:
#
# (1) Displaying a temporary result to users
#
# (2) Endpointing
#
# The second pass uses a non-streaming model. It has a higher recognition
# accuracy than the first-pass model, and its result is used as the final result.
#
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
# to download pre-trained models.
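#
# The overall flow of this example is:
#
#   microphone -> 100 ms chunks -> first pass (streaming recognizer)
#   -> temporary results and endpoint detection
#   -> at each endpoint, second pass (non-streaming recognizer) over the
#      buffered samples of the segment -> final result for the segment
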
  19 +"""
  20 +Usage examples:
  21 +
  22 +(1) Chinese: Streaming zipformer (1st pass) + Non-streaming paraformer (2nd pass)
  23 +
  24 +python3 ./python-api-examples/two-pass-speech-recognition-from-microphone.py \
  25 + --first-encoder ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/encoder-epoch-99-avg-1.onnx \
  26 + --first-decoder ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/decoder-epoch-99-avg-1.onnx \
  27 + --first-joiner ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/joiner-epoch-99-avg-1.onnx \
  28 + --first-tokens ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/tokens.txt \
  29 + \
  30 + --second-paraformer ./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx \
  31 + --second-tokens ./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt
  32 +
  33 +(2) English: Streaming zipformer (1st pass) + Non-streaming whisper (2nd pass)
  34 +
  35 +python3 ./python-api-examples/two-pass-speech-recognition-from-microphone.py \
  36 + --first-encoder ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.onnx \
  37 + --first-decoder ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/decoder-epoch-99-avg-1.onnx \
  38 + --first-joiner ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/joiner-epoch-99-avg-1.onnx \
  39 + --first-tokens ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt \
  40 + \
  41 + --second-whisper-encoder ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx \
  42 + --second-whisper-decoder ./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx \
  43 + --second-tokens ./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt
  44 +"""

import argparse
import sys
from pathlib import Path
from typing import List

import numpy as np

try:
    import sounddevice as sd
except ImportError:
    print("Please install sounddevice first. You can use")
    print()
    print("  pip install sounddevice")
    print()
    print("to install it")
    sys.exit(-1)

import sherpa_onnx


def assert_file_exists(filename: str, message: str):
    if not filename:
        raise ValueError(f"Please specify {message}")

    if not Path(filename).is_file():
        raise ValueError(f"{message} {filename} does not exist")


def add_first_pass_streaming_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--first-tokens",
        type=str,
        required=True,
        help="Path to tokens.txt for the first pass",
    )

    parser.add_argument(
        "--first-encoder",
        type=str,
        required=True,
        help="Path to the encoder model for the first pass",
    )

    parser.add_argument(
        "--first-decoder",
        type=str,
        required=True,
        help="Path to the decoder model for the first pass",
    )

    parser.add_argument(
        "--first-joiner",
        type=str,
        required=True,
        help="Path to the joiner model for the first pass",
    )

    parser.add_argument(
        "--first-decoding-method",
        type=str,
        default="greedy_search",
        help="""Decoding method for the first pass. Valid values are
        greedy_search and modified_beam_search""",
    )

    parser.add_argument(
        "--first-max-active-paths",
        type=int,
        default=4,
        help="""Used only when --first-decoding-method is modified_beam_search.
        It specifies the number of active paths to keep during decoding.
        """,
    )


def add_second_pass_transducer_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--second-encoder",
        default="",
        type=str,
        help="Path to the transducer encoder model for the second pass",
    )

    parser.add_argument(
        "--second-decoder",
        default="",
        type=str,
        help="Path to the transducer decoder model for the second pass",
    )

    parser.add_argument(
        "--second-joiner",
        default="",
        type=str,
        help="Path to the transducer joiner model for the second pass",
    )


def add_second_pass_paraformer_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--second-paraformer",
        default="",
        type=str,
        help="Path to model.onnx of the Paraformer model for the second pass",
    )


def add_second_pass_nemo_ctc_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--second-nemo-ctc",
        default="",
        type=str,
        help="Path to model.onnx of the NeMo CTC model for the second pass",
    )


def add_second_pass_whisper_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--second-whisper-encoder",
        default="",
        type=str,
        help="Path to the whisper encoder model for the second pass",
    )

    parser.add_argument(
        "--second-whisper-decoder",
        default="",
        type=str,
        help="Path to the whisper decoder model for the second pass",
    )

    parser.add_argument(
        "--second-whisper-language",
        default="",
        type=str,
        help="""The spoken language in the input audio.
        Example values: en, fr, de, zh, ja.
        Available languages for multilingual models can be found at
        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
        If not specified, the language is inferred from the input audio.
        """,
    )

    parser.add_argument(
        "--second-whisper-task",
        default="transcribe",
        choices=["transcribe", "translate"],
        type=str,
        help="""For multilingual models, if you specify translate, the output
        will be in English.
        """,
    )


def add_second_pass_non_streaming_model_args(parser: argparse.ArgumentParser):
    add_second_pass_transducer_model_args(parser)
    add_second_pass_nemo_ctc_model_args(parser)
    add_second_pass_paraformer_model_args(parser)
    add_second_pass_whisper_model_args(parser)

    parser.add_argument(
        "--second-tokens",
        type=str,
        help="Path to tokens.txt for the second pass",
    )


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    add_first_pass_streaming_model_args(parser)
    add_second_pass_non_streaming_model_args(parser)

    return parser.parse_args()


def check_first_pass_args(args):
    assert_file_exists(args.first_tokens, "--first-tokens")
    assert_file_exists(args.first_encoder, "--first-encoder")
    assert_file_exists(args.first_decoder, "--first-decoder")
    assert_file_exists(args.first_joiner, "--first-joiner")


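# Note: if arguments for more than one second-pass model type are given, the
# first match in the order transducer -> paraformer -> NeMo CTC -> whisper
# wins; the same order is used in create_second_pass_recognizer() below.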
def check_second_pass_args(args):
    assert_file_exists(args.second_tokens, "--second-tokens")

    if args.second_encoder:
        assert_file_exists(args.second_encoder, "--second-encoder")
        assert_file_exists(args.second_decoder, "--second-decoder")
        assert_file_exists(args.second_joiner, "--second-joiner")
    elif args.second_paraformer:
        assert_file_exists(args.second_paraformer, "--second-paraformer")
    elif args.second_nemo_ctc:
        assert_file_exists(args.second_nemo_ctc, "--second-nemo-ctc")
    elif args.second_whisper_encoder:
        assert_file_exists(args.second_whisper_encoder, "--second-whisper-encoder")
        assert_file_exists(args.second_whisper_decoder, "--second-whisper-decoder")
    else:
        raise ValueError("Please specify the model for the second pass")


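# Endpoint detection in the first pass segments the audio stream. Roughly,
# with the settings below an endpoint fires after at least 2.4 s of trailing
# silence (rule1), after at least 1.2 s of trailing silence once some token
# has been decoded (rule2), or when the utterance reaches 20 s (rule3).
# See the sherpa-onnx documentation for the exact semantics of these rules.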
def create_first_pass_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
    # Please replace the model files if needed.
    # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
    # for download links.
    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=args.first_tokens,
        encoder=args.first_encoder,
        decoder=args.first_decoder,
        joiner=args.first_joiner,
        num_threads=1,
        sample_rate=16000,
        feature_dim=80,
        decoding_method=args.first_decoding_method,
        max_active_paths=args.first_max_active_paths,
        provider=args.provider,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        rule3_min_utterance_length=20,
    )
    return recognizer


def create_second_pass_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
    if args.second_encoder:
        recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
            encoder=args.second_encoder,
            decoder=args.second_decoder,
            joiner=args.second_joiner,
            tokens=args.second_tokens,
            sample_rate=16000,
            feature_dim=80,
            decoding_method="greedy_search",
            max_active_paths=4,
        )
    elif args.second_paraformer:
        recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
            paraformer=args.second_paraformer,
            tokens=args.second_tokens,
            num_threads=1,
            sample_rate=16000,
            feature_dim=80,
            decoding_method="greedy_search",
        )
    elif args.second_nemo_ctc:
        recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
            model=args.second_nemo_ctc,
            tokens=args.second_tokens,
            num_threads=1,
            sample_rate=16000,
            feature_dim=80,
            decoding_method="greedy_search",
        )
    elif args.second_whisper_encoder:
        recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
            encoder=args.second_whisper_encoder,
            decoder=args.second_whisper_decoder,
            tokens=args.second_tokens,
            num_threads=1,
            decoding_method="greedy_search",
            language=args.second_whisper_language,
            task=args.second_whisper_task,
        )
    else:
        raise ValueError("Please specify at least one model for the second pass")

    return recognizer


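# The second pass re-decodes the complete segment: all microphone chunks
# collected since the previous endpoint are concatenated into one array and
# decoded in a single call to the non-streaming recognizer.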
def run_second_pass(
    recognizer: sherpa_onnx.OfflineRecognizer,
    sample_buffers: List[np.ndarray],
    sample_rate: int,
):
    stream = recognizer.create_stream()
    samples = np.concatenate(sample_buffers)
    stream.accept_waveform(sample_rate, samples)

    recognizer.decode_stream(stream)

    return stream.result.text


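# Read the microphone in 100 ms chunks, feed each chunk to the first pass,
# and print the partial result after each chunk. When the first pass detects
# an endpoint, run the second pass over the buffered segment, print its text
# as the final result for the segment, and reset the first-pass stream.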
def main():
    args = get_args()
    check_first_pass_args(args)
    check_second_pass_args(args)

    devices = sd.query_devices()
    if len(devices) == 0:
        print("No microphone devices found")
        sys.exit(0)

    print(devices)

    # If you want to select a different input device, please use
    # sd.default.device[0] = xxx
    # where xxx is the device number

    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    print("Creating recognizers. Please wait...")
    first_recognizer = create_first_pass_recognizer(args)
    second_recognizer = create_second_pass_recognizer(args)

    print("Started! Please speak")

    sample_rate = 16000
    samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
    stream = first_recognizer.create_stream()

    last_result = ""
    segment_id = 0

    sample_buffers = []
    with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
        while True:
            samples, _ = s.read(samples_per_read)  # a blocking read
            samples = samples.reshape(-1)
            stream.accept_waveform(sample_rate, samples)

            # Keep the raw samples of the current segment for the second pass
            sample_buffers.append(samples)

            while first_recognizer.is_ready(stream):
                first_recognizer.decode_stream(stream)

            is_endpoint = first_recognizer.is_endpoint(stream)

            result = first_recognizer.get_result(stream)
            result = result.lower().strip()

            if last_result != result:
                # Update the partial result in place: erase the old text
                # with spaces, then print the new text on the same line
                print(
                    "\r{}:{}".format(segment_id, " " * len(last_result)),
                    end="",
                    flush=True,
                )
                last_result = result
                print("\r{}:{}".format(segment_id, result), end="", flush=True)

            if is_endpoint:
                if result:
                    # Replace the first-pass result with the more accurate
                    # second-pass result for the finished segment
                    result = run_second_pass(
                        recognizer=second_recognizer,
                        sample_buffers=sample_buffers,
                        sample_rate=sample_rate,
                    )
                    result = result.lower().strip()

                    sample_buffers = []
                    print(
                        "\r{}:{}".format(segment_id, " " * len(last_result)),
                        end="",
                        flush=True,
                    )
                    print("\r{}:{}".format(segment_id, result), flush=True)
                    segment_id += 1
                else:
                    # Nothing was recognized in this segment; drop its samples
                    sample_buffers = []

                first_recognizer.reset(stream)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")