AlexWhisper
Committed by GitHub

Alex/feat add python example (#2490)

Co-authored-by: wangmh <minghu.wang@shenyiai.net>
@@ -39,6 +39,19 @@ from pathlib import Path
 
 import sherpa_onnx
 import soundfile as sf
+import librosa
+
+
+def resample_audio(audio, sample_rate, target_sample_rate):
+    """
+    Resample audio to target sample rate using librosa
+    """
+    if sample_rate != target_sample_rate:
+        print(f"Resampling audio from {sample_rate}Hz to {target_sample_rate}Hz...")
+        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=target_sample_rate)
+        print(f"Resampling completed. New audio shape: {audio.shape}")
+        return audio, target_sample_rate
+    return audio, sample_rate
 
 
 def init_speaker_diarization(num_speakers: int = -1, cluster_threshold: float = 0.5):
@@ -97,6 +110,11 @@ def main():
     # Since we know there are 4 speakers in the above test wave file, we use
     # num_speakers 4 here
     sd = init_speaker_diarization(num_speakers=4)
+
+    # Resample audio to match the expected sample rate
+    target_sample_rate = sd.sample_rate
+    audio, sample_rate = resample_audio(audio, sample_rate, target_sample_rate)
+
     if sample_rate != sd.sample_rate:
         raise RuntimeError(
             f"Expected samples rate: {sd.sample_rate}, given: {sample_rate}"
python-api-examples/two-pass-wss.py (new file):

#!/usr/bin/env python3
# Copyright (c) 2025 Minghu Wang
"""
A two-pass streaming ASR server with WebSocket support. This server implements
a two-pass recognition strategy: the first pass uses a fast streaming model for
real-time recognition, and the second pass re-processes each complete utterance
with a more accurate offline model to refine the result.

The first pass provides immediate feedback to users, while the second pass
improves accuracy.

It supports multiple clients sending audio simultaneously and provides
real-time transcription results.
  16 +
Usage:
    ./two-pass-wss.py --help

Example:

(1) Without a certificate

python3 ./python-api-examples/two-pass-wss.py \
  --paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \
  --paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \
  --tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
  --second-sense-voice ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx \
  --second-tokens ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt

(2) With a certificate

(a) Generate a certificate first:

    cd python-api-examples/web
    ./generate-certificate.py
    cd ../..

(b) Start the server

python3 ./python-api-examples/two-pass-wss.py \
  --paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \
  --paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \
  --tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \
  --second-sense-voice ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx \
  --second-tokens ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt \
  --certificate ./python-api-examples/web/cert.pem

Please refer to
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
to download pre-trained models.
"""
  54 +
import argparse
import asyncio
import http
import json
import logging
import ssl
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import sherpa_onnx
import websockets


def setup_logger(
    log_filename: str,
    log_level: str = "info",
    use_console: bool = True,
) -> None:
    """Set up the log level and log file.

    Args:
      log_filename:
        The filename to save the log.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical".
      use_console:
        True to also print logs to the console.
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    log_filename = f"{log_filename}-{date_time}.txt"

    Path(log_filename).parent.mkdir(parents=True, exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename,
        format=formatter,
        level=level,
        filemode="w",
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)

  116 +
def add_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--encoder",
        type=str,
        default="",
        help="Path to the transducer encoder model",
    )

    parser.add_argument(
        "--decoder",
        type=str,
        default="",
        help="Path to the transducer decoder model",
    )

    parser.add_argument(
        "--second-tokens",
        type=str,
        default="",
        help="Path to the second-pass tokens.txt",
    )

    parser.add_argument(
        "--second-sense-voice",
        type=str,
        default="",
        help="Path to the second-pass SenseVoice model",
    )

    parser.add_argument(
        "--paraformer-encoder",
        type=str,
        default="",
        help="Path to the paraformer encoder model",
    )

    parser.add_argument(
        "--paraformer-decoder",
        type=str,
        default="",
        help="Path to the paraformer decoder model",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="",
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="Sample rate of the data used to train the model. "
        "Caution: the server currently assumes that incoming audio "
        "already uses this sample rate; no resampling is done inside",
    )

    parser.add_argument(
        "--feat-dim",
        type=int,
        default=80,
        help="Feature dimension of the model",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

  191 +
def add_decoding_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Decoding method to use. Currently supported methods are:
        - greedy_search
        - modified_beam_search
        """,
    )

    add_modified_beam_search_args(parser)

  205 +
def add_hotwords_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--hotwords-file",
        type=str,
        default="",
        help="""
        The file containing hotwords, one word/phrase per line; within each
        phrase, the bpe/cjkchar tokens are separated by spaces. For example:

        ▁HE LL O ▁WORLD
        你 好 世 界
        """,
    )

    parser.add_argument(
        "--hotwords-score",
        type=float,
        default=1.5,
        help="""
        The hotword score of each token for biasing words/phrases. Used only
        if --hotwords-file is given.
        """,
    )

    parser.add_argument(
        "--modeling-unit",
        type=str,
        default="cjkchar",
        help="""
        The modeling unit of the model in use. Currently supported units are:
        - cjkchar (for Chinese)
        - bpe (for English-like languages)
        - cjkchar+bpe (for multilingual models)
        """,
    )

    parser.add_argument(
        "--bpe-vocab",
        type=str,
        default="",
        help="""
        The bpe vocabulary generated by the sentencepiece toolkit.
        It is used only when the modeling unit is bpe or cjkchar+bpe.
        If you cannot find bpe.vocab in the model directory, please run:
        python script/export_bpe_vocab.py --bpe-model exp/bpe.model
        """,
    )
  251 +
  252 +
def add_modified_beam_search_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--num-active-paths",
        type=int,
        default=4,
        help="""Used only when --decoding-method is modified_beam_search.
        It specifies the number of active paths to keep during decoding.
        """,
    )

  262 +
def add_blank_penalty_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied to the blank symbol during decoding.
        Note: It is a positive value applied to the logits like
        `logits[:, 0] -= blank_penalty` (assuming logits.shape is
        [batch_size, vocab] and the blank id is 0).
        """,
    )

  275 +
def add_endpointing_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--rule1-min-trailing-silence",
        type=float,
        default=2.4,
        help="""Endpointing rule 1 requires the duration of trailing silence
        (in seconds) to be >= this value.""",
    )

    parser.add_argument(
        "--rule2-min-trailing-silence",
        type=float,
        default=1.2,
        help="""Endpointing rule 2 requires the duration of trailing silence
        (in seconds) to be >= this value.""",
    )

    parser.add_argument(
        "--rule3-min-utterance-length",
        type=float,
        default=20,
        help="""Endpointing rule 3 requires the utterance length (in seconds)
        to be >= this value.""",
    )
  300 +
  301 +
def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    add_model_args(parser)
    add_decoding_args(parser)
    add_endpointing_args(parser)
    add_hotwords_args(parser)
    add_blank_penalty_args(parser)

    parser.add_argument(
        "--port",
        type=int,
        default=6006,
        help="The server will listen on this port",
    )

    parser.add_argument(
        "--nn-pool-size",
        type=int,
        default=1,
        help="Number of threads for NN computation and decoding.",
    )

    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=3,
        help="""Max batch size for computation. Note that if there are not
        enough requests in the queue, it waits for up to max_wait_ms. After
        that, even if there are not enough requests, it still sends the
        available requests in the queue for computation.
        """,
    )

    parser.add_argument(
        "--max-wait-ms",
        type=float,
        default=10,
        help="""Max time in milliseconds to wait to build batches for
        inference. If there are not enough requests in the stream queue to
        build a batch of max_batch_size, it waits up to this time before
        fetching the available requests for computation.
        """,
    )

    parser.add_argument(
        "--max-message-size",
        type=int,
        default=(1 << 20),
        help="""Max message size in bytes.
        The max size per message cannot exceed this limit.
        """,
    )

    parser.add_argument(
        "--max-queue-size",
        type=int,
        default=32,
        help="Max number of messages in the queue for each connection.",
    )

    parser.add_argument(
        "--max-active-connections",
        type=int,
        default=200,
        help="""Maximum number of active connections. The server refuses to
        accept new connections once the current number of active connections
        equals this limit.
        """,
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=2,
        help="Number of threads to run the neural network model",
    )

    parser.add_argument(
        "--second-pass-threads",
        type=int,
        default=2,
        help="Number of threads for second-pass processing",
    )

    parser.add_argument(
        "--certificate",
        type=str,
        help="""Path to the X.509 certificate. You need it only if you want to
        use a secure websocket connection, i.e., use wss:// instead of ws://.
        You can use ./web/generate-certificate.py
        to generate the certificate `cert.pem`.
        Note: ./web/generate-certificate.py generates three files, but you
        only need to pass the generated cert.pem to this option.
        """,
    )

    return parser.parse_args()

  402 +
def run_second_pass(
    recognizer: sherpa_onnx.OfflineRecognizer,
    samples: np.ndarray,
    sample_rate: int,
):
    # Decode the complete utterance with the offline (second-pass) recognizer
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)

    recognizer.decode_stream(stream)

    return stream.result.text

  414 +
def create_first_pass_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
    recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
        tokens=args.tokens,
        encoder=args.paraformer_encoder,
        decoder=args.paraformer_decoder,
        num_threads=args.num_threads,
        sample_rate=args.sample_rate,
        feature_dim=args.feat_dim,
        decoding_method=args.decoding_method,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=args.rule1_min_trailing_silence,
        rule2_min_trailing_silence=args.rule2_min_trailing_silence,
        rule3_min_utterance_length=args.rule3_min_utterance_length,
        provider=args.provider,
    )
    return recognizer
  431 +
  432 +
def create_second_pass_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=args.second_sense_voice,
        tokens=args.second_tokens,
        num_threads=1,
        sample_rate=16000,
        feature_dim=80,
        use_itn=True,
        decoding_method="greedy_search",
    )
    return recognizer
  444 +
  445 +
def format_timestamps(timestamps: List[float]) -> List[str]:
    return ["{:.3f}".format(t) for t in timestamps]
  448 +
  449 +
class StreamingServer(object):
    def __init__(
        self,
        first_pass_recognizer: sherpa_onnx.OnlineRecognizer,
        second_pass_recognizer: sherpa_onnx.OfflineRecognizer,
        nn_pool_size: int,
        max_wait_ms: float,
        max_batch_size: int,
        max_message_size: int,
        max_queue_size: int,
        max_active_connections: int,
        second_pass_threads: int = 2,
        certificate: Optional[str] = None,
    ):
        """
        Args:
          first_pass_recognizer:
            An instance of an online recognizer for the first pass.
          second_pass_recognizer:
            An instance of an offline recognizer for the second pass.
          nn_pool_size:
            Number of threads for the thread pool that is responsible for
            neural network computation and decoding.
          max_wait_ms:
            Max wait time in milliseconds in order to build a batch of
            `max_batch_size`.
          max_batch_size:
            Max batch size for inference.
          max_message_size:
            Max size in bytes per message.
          max_queue_size:
            Max number of messages in the queue for each connection.
          max_active_connections:
            Max number of active connections. Once the number of active
            clients equals this limit, the server refuses to accept new
            connections.
          second_pass_threads:
            Number of threads for second-pass processing.
          certificate:
            Optional. If not None, the server uses a secure websocket.
            You can use ./web/generate-certificate.py to generate
            it (the default generated filename is `cert.pem`).
        """
        self.first_pass_recognizer = first_pass_recognizer
        self.second_pass_recognizer = second_pass_recognizer

        self.certificate = certificate

        self.nn_pool_size = nn_pool_size
        self.nn_pool = ThreadPoolExecutor(
            max_workers=nn_pool_size,
            thread_name_prefix="nn",
        )

        self.second_pass_pool = ThreadPoolExecutor(
            max_workers=second_pass_threads,
            thread_name_prefix="second_pass",
        )

        self.stream_queue = asyncio.Queue()

        self.max_wait_ms = max_wait_ms
        self.max_batch_size = max_batch_size
        self.max_message_size = max_message_size
        self.max_queue_size = max_queue_size
        self.max_active_connections = max_active_connections

        self.current_active_connections = 0

        self.sample_rate = int(
            self.first_pass_recognizer.config.feat_config.sampling_rate
        )
  517 +
    async def stream_consumer_task(self):
        """This function extracts streams from the queue, batches them up,
        and sends them to the neural network model for computation and
        decoding.
        """
        while True:
            if self.stream_queue.empty():
                await asyncio.sleep(self.max_wait_ms / 1000)
                continue

            batch = []
            try:
                while len(batch) < self.max_batch_size:
                    item = self.stream_queue.get_nowait()

                    assert self.first_pass_recognizer.is_ready(item[0])

                    batch.append(item)
            except asyncio.QueueEmpty:
                pass
            stream_list = [b[0] for b in batch]
            future_list = [b[1] for b in batch]

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                self.nn_pool,
                self.first_pass_recognizer.decode_streams,
                stream_list,
            )

            for f in future_list:
                self.stream_queue.task_done()
                f.set_result(None)
  550 +
    async def compute_and_decode(
        self,
        stream: sherpa_onnx.OnlineStream,
    ) -> None:
        """Put the stream into the queue and wait for it to be processed by
        the consumer task.

        Args:
          stream:
            The stream to be processed. Note: It is changed in-place.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        await self.stream_queue.put((stream, future))
        await future
  566 +
    async def run_second_pass_async(
        self,
        samples: np.ndarray,
        sample_rate: int,
    ) -> str:
        """Run second-pass recognition asynchronously to avoid blocking.

        Args:
          samples: Audio samples.
          sample_rate: Sampling rate.

        Returns:
          Text result from the second-pass recognition.
        """
        start_time = time.time()

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            self.second_pass_pool,
            run_second_pass,
            self.second_pass_recognizer,
            samples,
            sample_rate,
        )

        duration = time.time() - start_time
        logging.info(
            f"Second pass processing completed in {duration:.3f}s "
            f"for {len(samples) / sample_rate:.2f}s audio"
        )

        return result.lower().strip()
  598 +
    async def process_request(
        self,
        path: str,
        request_headers: websockets.Headers,
    ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]:
        if self.current_active_connections < self.max_active_connections:
            self.current_active_connections += 1
            return None

        # Refuse new connections
        status = http.HTTPStatus.SERVICE_UNAVAILABLE  # 503
        header = {"Hint": "The server is overloaded. Please retry later."}
        response = b"The server is busy. Please retry later."

        return status, header, response
  614 +
    async def run(self, port: int):
        tasks = []
        for i in range(self.nn_pool_size):
            tasks.append(asyncio.create_task(self.stream_consumer_task()))

        if self.certificate:
            logging.info(f"Using certificate: {self.certificate}")
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            ssl_context.load_cert_chain(self.certificate)
        else:
            ssl_context = None
            logging.info("No certificate provided")

        try:
            async with websockets.serve(
                self.handle_connection,
                host="",
                port=port,
                max_size=self.max_message_size,
                max_queue=self.max_queue_size,
                process_request=self.process_request,
                ssl=ssl_context,
            ):
                logging.info(f"Started server on port {port}")
                await asyncio.Future()  # run forever
        finally:
            logging.info("Shutting down thread pools...")
            self.nn_pool.shutdown(wait=True)
            self.second_pass_pool.shutdown(wait=True)
            logging.info("Thread pools shut down successfully")

        await asyncio.gather(*tasks)  # not reachable
  647 +
    async def handle_connection(
        self,
        socket: websockets.WebSocketServerProtocol,
    ):
        """Receive audio samples from the client, process them, and send
        the decoding results back to the client.

        Args:
          socket:
            The socket for communicating with the client.
        """
        try:
            await self.handle_connection_impl(socket)
        except websockets.exceptions.ConnectionClosed:
            logging.info(f"{socket.remote_address} disconnected")
        finally:
            # Decrement so that the server can accept new connections
            self.current_active_connections -= 1

            logging.info(
                f"Disconnected: {socket.remote_address}. "
                f"Number of connections: "
                f"{self.current_active_connections}/{self.max_active_connections}"  # noqa
            )
  671 +
    async def handle_connection_impl(
        self,
        socket: websockets.WebSocketServerProtocol,
    ):
        """Receive audio samples from the client, process them, and send
        the decoding results back to the client.

        Args:
          socket:
            The socket for communicating with the client.
        """
        stream = self.first_pass_recognizer.create_stream()
        segment = 0
        sample_buffers = []
        while True:
            samples = await self.recv_audio_samples(socket)
            if samples is None:
                break

            # TODO(fangjun): At present, we assume the sampling rate
            # of the received audio samples equals --sample-rate
            stream.accept_waveform(sample_rate=self.sample_rate, waveform=samples)
            sample_buffers.append(samples)
            while self.first_pass_recognizer.is_ready(stream):
                await self.compute_and_decode(stream)
            result = self.first_pass_recognizer.get_result(stream)

            message = {
                "text": result,
                "segment": segment,
            }
            if self.first_pass_recognizer.is_endpoint(stream):
                if result:
                    samples_for_2nd_pass = np.concatenate(sample_buffers)
                    # Carry the last 8000 samples (0.5 s at 16 kHz) over to
                    # the next segment and exclude them from the second pass;
                    # they are mostly the trailing silence that triggered the
                    # endpoint.
                    sample_buffers = [samples_for_2nd_pass[-8000:]]
                    samples_for_2nd_pass = samples_for_2nd_pass[:-8000]
                    second_pass_result = await self.run_second_pass_async(
                        samples=samples_for_2nd_pass,
                        sample_rate=self.sample_rate,
                    )

                    if second_pass_result:
                        message["text"] = second_pass_result
                        message["segment"] = segment
                else:
                    sample_buffers = []

                self.first_pass_recognizer.reset(stream)
                segment += 1
            await socket.send(json.dumps(message))

        # Flush the first-pass stream with 0.3 s of tail padding
        tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32)
        stream.accept_waveform(sample_rate=self.sample_rate, waveform=tail_padding)
        stream.input_finished()
        while self.first_pass_recognizer.is_ready(stream):
            await self.compute_and_decode(stream)

        result = self.first_pass_recognizer.get_result(stream)

        message = {
            "text": result,
            "segment": segment,
        }
        await socket.send(json.dumps(message))
  738 +
    async def recv_audio_samples(
        self,
        socket: websockets.WebSocketServerProtocol,
    ) -> Optional[np.ndarray]:
        """Receive audio samples from the WebSocket connection.

        Args:
          socket: The WebSocket connection.

        Returns:
          A numpy array containing audio samples, or None to indicate the
          end of the audio.
        """
        message = await socket.recv()
        if message == "Done":
            return None
        return np.frombuffer(message, dtype=np.float32)
  755 +
  756 +
def check_args(args):
    if args.encoder:
        assert Path(args.encoder).is_file(), f"{args.encoder} does not exist"
        assert Path(args.decoder).is_file(), f"{args.decoder} does not exist"
        # The defaults are empty strings, so check for falsy values
        # rather than None
        assert not args.paraformer_encoder, args.paraformer_encoder
        assert not args.paraformer_decoder, args.paraformer_decoder

    elif args.paraformer_encoder:
        assert Path(
            args.paraformer_encoder
        ).is_file(), f"{args.paraformer_encoder} does not exist"

        assert Path(
            args.paraformer_decoder
        ).is_file(), f"{args.paraformer_decoder} does not exist"
    else:
        raise ValueError("Please provide a model")

    if not Path(args.tokens).is_file():
        raise ValueError(f"{args.tokens} does not exist")

    if args.decoding_method not in (
        "greedy_search",
        "modified_beam_search",
    ):
        raise ValueError(f"Unsupported decoding method {args.decoding_method}")

    if args.decoding_method == "modified_beam_search":
        assert args.num_active_paths > 0, args.num_active_paths
  786 +
  787 +
def main():
    args = get_args()
    logging.info(vars(args))
    check_args(args)

    first_pass_recognizer = create_first_pass_recognizer(args)
    second_pass_recognizer = create_second_pass_recognizer(args)

    port = args.port
    nn_pool_size = args.nn_pool_size
    max_batch_size = args.max_batch_size
    max_wait_ms = args.max_wait_ms
    max_message_size = args.max_message_size
    max_queue_size = args.max_queue_size
    max_active_connections = args.max_active_connections
    second_pass_threads = args.second_pass_threads
    certificate = args.certificate

    if certificate and not Path(certificate).is_file():
        raise ValueError(f"{certificate} does not exist")

    server = StreamingServer(
        first_pass_recognizer=first_pass_recognizer,
        second_pass_recognizer=second_pass_recognizer,
        nn_pool_size=nn_pool_size,
        max_batch_size=max_batch_size,
        max_wait_ms=max_wait_ms,
        max_message_size=max_message_size,
        max_queue_size=max_queue_size,
        max_active_connections=max_active_connections,
        second_pass_threads=second_pass_threads,
        certificate=certificate,
    )
    asyncio.run(server.run(port))


if __name__ == "__main__":
    log_filename = "log/log-streaming-server"
    setup_logger(log_filename)
    main()
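
For reference, here is a minimal client sketch that matches the wire protocol implemented above: the client streams raw little-endian float32 samples, the server replies with JSON messages carrying "text" and "segment", and the literal string "Done" marks the end of the audio. The ws://localhost:6006 URL and test.wav (16 kHz mono) are placeholder assumptions for a server started without a certificate:

#!/usr/bin/env python3
import asyncio
import json

import soundfile as sf
import websockets


async def receive_results(ws):
    # Print every first-pass/second-pass result pushed by the server
    async for message in ws:
        result = json.loads(message)
        print(f"segment {result['segment']}: {result['text']}")


async def run_client():
    samples, sample_rate = sf.read("test.wav", dtype="float32")
    async with websockets.connect("ws://localhost:6006") as ws:
        receiver = asyncio.create_task(receive_results(ws))
        chunk = int(0.2 * sample_rate)  # 200 ms of audio per message
        for i in range(0, len(samples), chunk):
            await ws.send(samples[i : i + chunk].tobytes())
            await asyncio.sleep(0.2)  # simulate real-time capture
        await ws.send("Done")  # tell the server the audio has ended
        await receiver  # finishes when the server closes the connection


if __name__ == "__main__":
    asyncio.run(run_client())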