Committed by
GitHub
Alex/feat add python example (#2490)
Co-authored-by: wangmh <minghu.wang@shenyiai.net>
正在显示
2 个修改的文件
包含
847 行增加
和
0 行删除
| @@ -39,6 +39,19 @@ from pathlib import Path | @@ -39,6 +39,19 @@ from pathlib import Path | ||
| 39 | 39 | ||
| 40 | import sherpa_onnx | 40 | import sherpa_onnx |
| 41 | import soundfile as sf | 41 | import soundfile as sf |
| 42 | +import librosa | ||
| 43 | + | ||
| 44 | + | ||
| 45 | +def resample_audio(audio, sample_rate, target_sample_rate): | ||
| 46 | + """ | ||
| 47 | + Resample audio to target sample rate using librosa | ||
| 48 | + """ | ||
| 49 | + if sample_rate != target_sample_rate: | ||
| 50 | + print(f"Resampling audio from {sample_rate}Hz to {target_sample_rate}Hz...") | ||
| 51 | + audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=target_sample_rate) | ||
| 52 | + print(f"Resampling completed. New audio shape: {audio.shape}") | ||
| 53 | + return audio, target_sample_rate | ||
| 54 | + return audio, sample_rate | ||
| 42 | 55 | ||
| 43 | 56 | ||
| 44 | def init_speaker_diarization(num_speakers: int = -1, cluster_threshold: float = 0.5): | 57 | def init_speaker_diarization(num_speakers: int = -1, cluster_threshold: float = 0.5): |
| @@ -97,6 +110,11 @@ def main(): | @@ -97,6 +110,11 @@ def main(): | ||
| 97 | # Since we know there are 4 speakers in the above test wave file, we use | 110 | # Since we know there are 4 speakers in the above test wave file, we use |
| 98 | # num_speakers 4 here | 111 | # num_speakers 4 here |
| 99 | sd = init_speaker_diarization(num_speakers=4) | 112 | sd = init_speaker_diarization(num_speakers=4) |
| 113 | + | ||
| 114 | + # Resample audio to match the expected sample rate | ||
| 115 | + target_sample_rate = sd.sample_rate | ||
| 116 | + audio, sample_rate = resample_audio(audio, sample_rate, target_sample_rate) | ||
| 117 | + | ||
| 100 | if sample_rate != sd.sample_rate: | 118 | if sample_rate != sd.sample_rate: |
| 101 | raise RuntimeError( | 119 | raise RuntimeError( |
| 102 | f"Expected samples rate: {sd.sample_rate}, given: {sample_rate}" | 120 | f"Expected samples rate: {sd.sample_rate}, given: {sample_rate}" |
python-api-examples/two-pass-wss.py
0 → 100644
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright (c) 2025 Minghu Wang | ||
| 3 | +""" | ||
| 4 | + | ||
| 5 | +A two-pass streaming ASR server with WebSocket support. This server implements | ||
| 6 | +a two-pass recognition strategy where the first pass uses a fast streaming model | ||
| 7 | +for real-time recognition, and the second pass uses a more accurate offline model | ||
| 8 | +to refine the results. | ||
| 9 | + | ||
| 10 | +The first pass provides immediate feedback to users, while the second pass | ||
| 11 | +improves accuracy by re-processing the complete utterance with a more powerful | ||
| 12 | +model. | ||
| 13 | + | ||
| 14 | +It supports multiple clients sending audio simultaneously and provides | ||
| 15 | +real-time transcription results. | ||
| 16 | + | ||
| 17 | +Usage: | ||
| 18 | + ./two-pass-wss.py --help | ||
| 19 | + | ||
| 20 | +Example: | ||
| 21 | + | ||
| 22 | +(1) Without a certificate | ||
| 23 | + | ||
| 24 | +python3 ./python-api-examples/two-pass-wss.py \ | ||
| 25 | + --paraformer-encoder ./sherpa-onnx-paraformer-zh-2023-09-18/encoder.onnx \ | ||
| 26 | + --paraformer-decoder ./sherpa-onnx-paraformer-zh-2023-09-18/decoder.onnx \ | ||
| 27 | + --tokens ./sherpa-onnx-paraformer-zh-2023-09-18/tokens.txt \ | ||
| 28 | + --second-sense-voice ./sherpa-onnx-sense-voice-zh-2023-09-18/model.onnx \ | ||
| 29 | + --second-tokens ./sherpa-onnx-sense-voice-zh-2023-09-18/tokens.txt | ||
| 30 | + | ||
| 31 | +(2) With a certificate | ||
| 32 | + | ||
| 33 | +(a) Generate a certificate first: | ||
| 34 | + | ||
| 35 | + cd python-api-examples/web | ||
| 36 | + ./generate-certificate.py | ||
| 37 | + cd ../.. | ||
| 38 | + | ||
| 39 | +(b) Start the server | ||
| 40 | + | ||
| 41 | +python3 ./python-api-examples/two-pass-wss.py \ | ||
| 42 | + --paraformer-encoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.onnx \ | ||
| 43 | + --paraformer-decoder ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.onnx \ | ||
| 44 | + --tokens ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ | ||
| 45 | + --second-sense-voice ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx \ | ||
| 46 | + --second-tokens ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt \ | ||
| 47 | + --certificate ./python-api-examples/web/cert.pem | ||
| 48 | + | ||
| 49 | +Please refer to | ||
| 50 | +https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | ||
| 51 | +https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 52 | +to download pre-trained models. | ||
| 53 | +""" | ||
| 54 | + | ||
| 55 | +import argparse | ||
| 56 | +import asyncio | ||
| 57 | +import http | ||
| 58 | +import json | ||
| 59 | +import logging | ||
| 60 | +import socket | ||
| 61 | +import ssl | ||
| 62 | +from concurrent.futures import ThreadPoolExecutor | ||
| 63 | +from datetime import datetime | ||
| 64 | +from pathlib import Path | ||
| 65 | +from typing import List, Optional, Tuple | ||
| 66 | + | ||
| 67 | +import numpy as np | ||
| 68 | +import sherpa_onnx | ||
| 69 | +import websockets | ||
| 70 | + | ||
| 71 | +def setup_logger( | ||
| 72 | + log_filename: str, | ||
| 73 | + log_level: str = "info", | ||
| 74 | + use_console: bool = True, | ||
| 75 | +) -> None: | ||
| 76 | + """Setup log level. | ||
| 77 | + | ||
| 78 | + Args: | ||
| 79 | + log_filename: | ||
| 80 | + The filename to save the log. | ||
| 81 | + log_level: | ||
| 82 | + The log level to use, e.g., "debug", "info", "warning", "error", | ||
| 83 | + "critical" | ||
| 84 | + use_console: | ||
| 85 | + True to also print logs to console. | ||
| 86 | + """ | ||
| 87 | + now = datetime.now() | ||
| 88 | + date_time = now.strftime("%Y-%m-%d-%H-%M-%S") | ||
| 89 | + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
| 90 | + log_filename = f"{log_filename}-{date_time}.txt" | ||
| 91 | + | ||
| 92 | + Path(log_filename).parent.mkdir(parents=True, exist_ok=True) | ||
| 93 | + | ||
| 94 | + level = logging.ERROR | ||
| 95 | + if log_level == "debug": | ||
| 96 | + level = logging.DEBUG | ||
| 97 | + elif log_level == "info": | ||
| 98 | + level = logging.INFO | ||
| 99 | + elif log_level == "warning": | ||
| 100 | + level = logging.WARNING | ||
| 101 | + elif log_level == "critical": | ||
| 102 | + level = logging.CRITICAL | ||
| 103 | + | ||
| 104 | + logging.basicConfig( | ||
| 105 | + filename=log_filename, | ||
| 106 | + format=formatter, | ||
| 107 | + level=level, | ||
| 108 | + filemode="w", | ||
| 109 | + ) | ||
| 110 | + if use_console: | ||
| 111 | + console = logging.StreamHandler() | ||
| 112 | + console.setLevel(level) | ||
| 113 | + console.setFormatter(logging.Formatter(formatter)) | ||
| 114 | + logging.getLogger("").addHandler(console) | ||
| 115 | + | ||
| 116 | + | ||
| 117 | +def add_model_args(parser: argparse.ArgumentParser): | ||
| 118 | + parser.add_argument( | ||
| 119 | + "--encoder", | ||
| 120 | + type=str, | ||
| 121 | + default="", | ||
| 122 | + help="Path to the transducer encoder model", | ||
| 123 | + ) | ||
| 124 | + | ||
| 125 | + parser.add_argument( | ||
| 126 | + "--decoder", | ||
| 127 | + type=str, | ||
| 128 | + default="", | ||
| 129 | + help="Path to the transducer decoder model.", | ||
| 130 | + ) | ||
| 131 | + | ||
| 132 | + | ||
| 133 | + parser.add_argument( | ||
| 134 | + "--second-tokens", | ||
| 135 | + type=str, | ||
| 136 | + default="", | ||
| 137 | + help="Path to the second pass tokens.txt", | ||
| 138 | + ) | ||
| 139 | + | ||
| 140 | + parser.add_argument( | ||
| 141 | + "--second-sense-voice", | ||
| 142 | + type=str, | ||
| 143 | + default="", | ||
| 144 | + help="Path to the second pass sense voice model.", | ||
| 145 | + ) | ||
| 146 | + | ||
| 147 | + parser.add_argument( | ||
| 148 | + "--paraformer-encoder", | ||
| 149 | + type=str, | ||
| 150 | + default="", | ||
| 151 | + help="Path to the paraformer encoder model", | ||
| 152 | + ) | ||
| 153 | + | ||
| 154 | + parser.add_argument( | ||
| 155 | + "--paraformer-decoder", | ||
| 156 | + type=str, | ||
| 157 | + default="", | ||
| 158 | + help="Path to the paraformer decoder model.", | ||
| 159 | + ) | ||
| 160 | + | ||
| 161 | + parser.add_argument( | ||
| 162 | + "--tokens", | ||
| 163 | + type=str, | ||
| 164 | + default="", | ||
| 165 | + help="Path to tokens.txt", | ||
| 166 | + ) | ||
| 167 | + | ||
| 168 | + parser.add_argument( | ||
| 169 | + "--sample-rate", | ||
| 170 | + type=int, | ||
| 171 | + default=16000, | ||
| 172 | + help="Sample rate of the data used to train the model. " | ||
| 173 | + "Caution: If your input sound files have a different sampling rate, " | ||
| 174 | + "we will do resampling inside", | ||
| 175 | + ) | ||
| 176 | + | ||
| 177 | + parser.add_argument( | ||
| 178 | + "--feat-dim", | ||
| 179 | + type=int, | ||
| 180 | + default=80, | ||
| 181 | + help="Feature dimension of the model", | ||
| 182 | + ) | ||
| 183 | + | ||
| 184 | + parser.add_argument( | ||
| 185 | + "--provider", | ||
| 186 | + type=str, | ||
| 187 | + default="cpu", | ||
| 188 | + help="Valid values: cpu, cuda, coreml", | ||
| 189 | + ) | ||
| 190 | + | ||
| 191 | + | ||
| 192 | +def add_decoding_args(parser: argparse.ArgumentParser): | ||
| 193 | + parser.add_argument( | ||
| 194 | + "--decoding-method", | ||
| 195 | + type=str, | ||
| 196 | + default="greedy_search", | ||
| 197 | + help="""Decoding method to use. Current supported methods are: | ||
| 198 | + - greedy_search | ||
| 199 | + - modified_beam_search | ||
| 200 | + """, | ||
| 201 | + ) | ||
| 202 | + | ||
| 203 | + add_modified_beam_search_args(parser) | ||
| 204 | + | ||
| 205 | + | ||
| 206 | +def add_hotwords_args(parser: argparse.ArgumentParser): | ||
| 207 | + parser.add_argument( | ||
| 208 | + "--hotwords-file", | ||
| 209 | + type=str, | ||
| 210 | + default="", | ||
| 211 | + help=""" | ||
| 212 | + The file containing hotwords, one words/phrases per line, and for each | ||
| 213 | + phrase the bpe/cjkchar are separated by a space. For example: | ||
| 214 | + | ||
| 215 | + ▁HE LL O ▁WORLD | ||
| 216 | + 你 好 世 界 | ||
| 217 | + """, | ||
| 218 | + ) | ||
| 219 | + | ||
| 220 | + parser.add_argument( | ||
| 221 | + "--hotwords-score", | ||
| 222 | + type=float, | ||
| 223 | + default=1.5, | ||
| 224 | + help=""" | ||
| 225 | + The hotword score of each token for biasing word/phrase. Used only if | ||
| 226 | + --hotwords-file is given. | ||
| 227 | + """, | ||
| 228 | + ) | ||
| 229 | + parser.add_argument( | ||
| 230 | + "--modeling-unit", | ||
| 231 | + type=str, | ||
| 232 | + default='cjkchar', | ||
| 233 | + help=""" | ||
| 234 | + The modeling unit of the used model. Current supported units are: | ||
| 235 | + - cjkchar(for Chinese) | ||
| 236 | + - bpe(for English like languages) | ||
| 237 | + - cjkchar+bpe(for multilingual models) | ||
| 238 | + """, | ||
| 239 | + ) | ||
| 240 | + parser.add_argument( | ||
| 241 | + "--bpe-vocab", | ||
| 242 | + type=str, | ||
| 243 | + default='', | ||
| 244 | + help=""" | ||
| 245 | + The bpe vocabulary generated by sentencepiece toolkit. | ||
| 246 | + It is only used when modeling-unit is bpe or cjkchar+bpe. | ||
| 247 | + if you can't find bpe.vocab in the model directory, please run: | ||
| 248 | + python script/export_bpe_vocab.py --bpe-model exp/bpe.model | ||
| 249 | + """, | ||
| 250 | + ) | ||
| 251 | + | ||
| 252 | + | ||
| 253 | +def add_modified_beam_search_args(parser: argparse.ArgumentParser): | ||
| 254 | + parser.add_argument( | ||
| 255 | + "--num-active-paths", | ||
| 256 | + type=int, | ||
| 257 | + default=4, | ||
| 258 | + help="""Used only when --decoding-method is modified_beam_search. | ||
| 259 | + It specifies number of active paths to keep during decoding. | ||
| 260 | + """, | ||
| 261 | + ) | ||
| 262 | + | ||
| 263 | +def add_blank_penalty_args(parser: argparse.ArgumentParser): | ||
| 264 | + parser.add_argument( | ||
| 265 | + "--blank-penalty", | ||
| 266 | + type=float, | ||
| 267 | + default=0.0, | ||
| 268 | + help=""" | ||
| 269 | + The penalty applied on blank symbol during decoding. | ||
| 270 | + Note: It is a positive value that would be applied to logits like | ||
| 271 | + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is | ||
| 272 | + [batch_size, vocab] and blank id is 0). | ||
| 273 | + """, | ||
| 274 | + ) | ||
| 275 | + | ||
| 276 | +def add_endpointing_args(parser: argparse.ArgumentParser): | ||
| 277 | + parser.add_argument( | ||
| 278 | + "--rule1-min-trailing-silence", | ||
| 279 | + type=float, | ||
| 280 | + default=2.4, | ||
| 281 | + help="""This endpointing rule1 requires duration of trailing silence | ||
| 282 | + in seconds) to be >= this value""", | ||
| 283 | + ) | ||
| 284 | + | ||
| 285 | + parser.add_argument( | ||
| 286 | + "--rule2-min-trailing-silence", | ||
| 287 | + type=float, | ||
| 288 | + default=1.2, | ||
| 289 | + help="""This endpointing rule2 requires duration of trailing silence in | ||
| 290 | + seconds) to be >= this value.""", | ||
| 291 | + ) | ||
| 292 | + | ||
| 293 | + parser.add_argument( | ||
| 294 | + "--rule3-min-utterance-length", | ||
| 295 | + type=float, | ||
| 296 | + default=20, | ||
| 297 | + help="""This endpointing rule3 requires utterance-length (in seconds) | ||
| 298 | + to be >= this value.""", | ||
| 299 | + ) | ||
| 300 | + | ||
| 301 | + | ||
| 302 | +def get_args(): | ||
| 303 | + parser = argparse.ArgumentParser( | ||
| 304 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||
| 305 | + ) | ||
| 306 | + | ||
| 307 | + add_model_args(parser) | ||
| 308 | + add_decoding_args(parser) | ||
| 309 | + add_endpointing_args(parser) | ||
| 310 | + add_hotwords_args(parser) | ||
| 311 | + add_blank_penalty_args(parser) | ||
| 312 | + | ||
| 313 | + parser.add_argument( | ||
| 314 | + "--port", | ||
| 315 | + type=int, | ||
| 316 | + default=6006, | ||
| 317 | + help="The server will listen on this port", | ||
| 318 | + ) | ||
| 319 | + | ||
| 320 | + parser.add_argument( | ||
| 321 | + "--nn-pool-size", | ||
| 322 | + type=int, | ||
| 323 | + default=1, | ||
| 324 | + help="Number of threads for NN computation and decoding.", | ||
| 325 | + ) | ||
| 326 | + | ||
| 327 | + parser.add_argument( | ||
| 328 | + "--max-batch-size", | ||
| 329 | + type=int, | ||
| 330 | + default=3, | ||
| 331 | + help="""Max batch size for computation. Note if there are not enough | ||
| 332 | + requests in the queue, it will wait for max_wait_ms time. After that, | ||
| 333 | + even if there are not enough requests, it still sends the | ||
| 334 | + available requests in the queue for computation. | ||
| 335 | + """, | ||
| 336 | + ) | ||
| 337 | + | ||
| 338 | + parser.add_argument( | ||
| 339 | + "--max-wait-ms", | ||
| 340 | + type=float, | ||
| 341 | + default=10, | ||
| 342 | + help="""Max time in millisecond to wait to build batches for inference. | ||
| 343 | + If there are not enough requests in the stream queue to build a batch | ||
| 344 | + of max_batch_size, it waits up to this time before fetching available | ||
| 345 | + requests for computation. | ||
| 346 | + """, | ||
| 347 | + ) | ||
| 348 | + | ||
| 349 | + parser.add_argument( | ||
| 350 | + "--max-message-size", | ||
| 351 | + type=int, | ||
| 352 | + default=(1 << 20), | ||
| 353 | + help="""Max message size in bytes. | ||
| 354 | + The max size per message cannot exceed this limit. | ||
| 355 | + """, | ||
| 356 | + ) | ||
| 357 | + | ||
| 358 | + parser.add_argument( | ||
| 359 | + "--max-queue-size", | ||
| 360 | + type=int, | ||
| 361 | + default=32, | ||
| 362 | + help="Max number of messages in the queue for each connection.", | ||
| 363 | + ) | ||
| 364 | + | ||
| 365 | + parser.add_argument( | ||
| 366 | + "--max-active-connections", | ||
| 367 | + type=int, | ||
| 368 | + default=200, | ||
| 369 | + help="""Maximum number of active connections. The server will refuse | ||
| 370 | + to accept new connections once the current number of active connections | ||
| 371 | + equals to this limit. | ||
| 372 | + """, | ||
| 373 | + ) | ||
| 374 | + | ||
| 375 | + parser.add_argument( | ||
| 376 | + "--num-threads", | ||
| 377 | + type=int, | ||
| 378 | + default=2, | ||
| 379 | + help="Number of threads to run the neural network model", | ||
| 380 | + ) | ||
| 381 | + | ||
| 382 | + parser.add_argument( | ||
| 383 | + "--second-pass-threads", | ||
| 384 | + type=int, | ||
| 385 | + default=2, | ||
| 386 | + help="Number of threads for second pass processing", | ||
| 387 | + ) | ||
| 388 | + | ||
| 389 | + parser.add_argument( | ||
| 390 | + "--certificate", | ||
| 391 | + type=str, | ||
| 392 | + help="""Path to the X.509 certificate. You need it only if you want to | ||
| 393 | + use a secure websocket connection, i.e., use wss:// instead of ws://. | ||
| 394 | + You can use ./web/generate-certificate.py | ||
| 395 | + to generate the certificate `cert.pem`. | ||
| 396 | + Note ./web/generate-certificate.py will generate three files but you | ||
| 397 | + only need to pass the generated cert.pem to this option. | ||
| 398 | + """, | ||
| 399 | + ) | ||
| 400 | + | ||
| 401 | + return parser.parse_args() | ||
| 402 | + | ||
| 403 | +def run_second_pass( | ||
| 404 | + recognizer: sherpa_onnx.OfflineRecognizer, | ||
| 405 | + samples: np.ndarray, | ||
| 406 | + sample_rate: int, | ||
| 407 | +): | ||
| 408 | + stream = recognizer.create_stream() | ||
| 409 | + stream.accept_waveform(sample_rate, samples) | ||
| 410 | + | ||
| 411 | + recognizer.decode_stream(stream) | ||
| 412 | + | ||
| 413 | + return stream.result.text | ||
| 414 | + | ||
| 415 | +def create_first_pass_recognizer(args) -> sherpa_onnx.OnlineRecognizer: | ||
| 416 | + recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( | ||
| 417 | + tokens=args.tokens, | ||
| 418 | + encoder=args.paraformer_encoder, | ||
| 419 | + decoder=args.paraformer_decoder, | ||
| 420 | + num_threads=args.num_threads, | ||
| 421 | + sample_rate=args.sample_rate, | ||
| 422 | + feature_dim=args.feat_dim, | ||
| 423 | + decoding_method=args.decoding_method, | ||
| 424 | + enable_endpoint_detection=True, | ||
| 425 | + rule1_min_trailing_silence=args.rule1_min_trailing_silence, | ||
| 426 | + rule2_min_trailing_silence=args.rule2_min_trailing_silence, | ||
| 427 | + rule3_min_utterance_length=args.rule3_min_utterance_length, | ||
| 428 | + provider=args.provider, | ||
| 429 | + ) | ||
| 430 | + return recognizer | ||
| 431 | + | ||
| 432 | + | ||
| 433 | +def create_second_pass_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 434 | + recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice( | ||
| 435 | + model=args.second_sense_voice, | ||
| 436 | + tokens=args.second_tokens, | ||
| 437 | + num_threads=1, | ||
| 438 | + sample_rate=16000, | ||
| 439 | + feature_dim=80, | ||
| 440 | + use_itn=True, | ||
| 441 | + decoding_method="greedy_search", | ||
| 442 | + ) | ||
| 443 | + return recognizer | ||
| 444 | + | ||
| 445 | + | ||
| 446 | +def format_timestamps(timestamps: List[float]) -> List[str]: | ||
| 447 | + return ["{:.3f}".format(t) for t in timestamps] | ||
| 448 | + | ||
| 449 | + | ||
| 450 | +class StreamingServer(object): | ||
| 451 | + def __init__( | ||
| 452 | + self, | ||
| 453 | + first_pass_recognizer: sherpa_onnx.OnlineRecognizer, | ||
| 454 | + second_pass_recognizer: sherpa_onnx.OfflineRecognizer, | ||
| 455 | + nn_pool_size: int, | ||
| 456 | + max_wait_ms: float, | ||
| 457 | + max_batch_size: int, | ||
| 458 | + max_message_size: int, | ||
| 459 | + max_queue_size: int, | ||
| 460 | + max_active_connections: int, | ||
| 461 | + second_pass_threads: int = 2, | ||
| 462 | + certificate: Optional[str] = None, | ||
| 463 | + ): | ||
| 464 | + """ | ||
| 465 | + Args: | ||
| 466 | + first_pass_recognizer: | ||
| 467 | + An instance of online recognizer for first pass. | ||
| 468 | + second_pass_recognizer: | ||
| 469 | + An instance of offline recognizer for second pass. | ||
| 470 | + nn_pool_size: | ||
| 471 | + Number of threads for the thread pool that is responsible for | ||
| 472 | + neural network computation and decoding. | ||
| 473 | + max_wait_ms: | ||
| 474 | + Max wait time in milliseconds in order to build a batch of | ||
| 475 | + `batch_size`. | ||
| 476 | + max_batch_size: | ||
| 477 | + Max batch size for inference. | ||
| 478 | + max_message_size: | ||
| 479 | + Max size in bytes per message. | ||
| 480 | + max_queue_size: | ||
| 481 | + Max number of messages in the queue for each connection. | ||
| 482 | + max_active_connections: | ||
| 483 | + Max number of active connections. Once number of active client | ||
| 484 | + equals to this limit, the server refuses to accept new connections. | ||
| 485 | + certificate: | ||
| 486 | + Optional. If not None, it will use secure websocket. | ||
| 487 | + You can use ./web/generate-certificate.py to generate | ||
| 488 | + it (the default generated filename is `cert.pem`). | ||
| 489 | + """ | ||
| 490 | + self.first_pass_recognizer = first_pass_recognizer | ||
| 491 | + self.second_pass_recognizer = second_pass_recognizer | ||
| 492 | + | ||
| 493 | + self.certificate = certificate | ||
| 494 | + | ||
| 495 | + self.nn_pool_size = nn_pool_size | ||
| 496 | + self.nn_pool = ThreadPoolExecutor( | ||
| 497 | + max_workers=nn_pool_size, | ||
| 498 | + thread_name_prefix="nn", | ||
| 499 | + ) | ||
| 500 | + | ||
| 501 | + self.second_pass_pool = ThreadPoolExecutor( | ||
| 502 | + max_workers=second_pass_threads, | ||
| 503 | + thread_name_prefix="second_pass", | ||
| 504 | + ) | ||
| 505 | + | ||
| 506 | + self.stream_queue = asyncio.Queue() | ||
| 507 | + | ||
| 508 | + self.max_wait_ms = max_wait_ms | ||
| 509 | + self.max_batch_size = max_batch_size | ||
| 510 | + self.max_message_size = max_message_size | ||
| 511 | + self.max_queue_size = max_queue_size | ||
| 512 | + self.max_active_connections = max_active_connections | ||
| 513 | + | ||
| 514 | + self.current_active_connections = 0 | ||
| 515 | + | ||
| 516 | + self.sample_rate = int(self.first_pass_recognizer.config.feat_config.sampling_rate) | ||
| 517 | + | ||
| 518 | + async def stream_consumer_task(self): | ||
| 519 | + """This function extracts streams from the queue, batches them up, sends | ||
| 520 | + them to the neural network model for computation and decoding. | ||
| 521 | + """ | ||
| 522 | + while True: | ||
| 523 | + if self.stream_queue.empty(): | ||
| 524 | + await asyncio.sleep(self.max_wait_ms / 1000) | ||
| 525 | + continue | ||
| 526 | + | ||
| 527 | + batch = [] | ||
| 528 | + try: | ||
| 529 | + while len(batch) < self.max_batch_size: | ||
| 530 | + item = self.stream_queue.get_nowait() | ||
| 531 | + | ||
| 532 | + assert self.first_pass_recognizer.is_ready(item[0]) | ||
| 533 | + | ||
| 534 | + batch.append(item) | ||
| 535 | + except asyncio.QueueEmpty: | ||
| 536 | + pass | ||
| 537 | + stream_list = [b[0] for b in batch] | ||
| 538 | + future_list = [b[1] for b in batch] | ||
| 539 | + | ||
| 540 | + loop = asyncio.get_running_loop() | ||
| 541 | + await loop.run_in_executor( | ||
| 542 | + self.nn_pool, | ||
| 543 | + self.first_pass_recognizer.decode_streams, | ||
| 544 | + stream_list, | ||
| 545 | + ) | ||
| 546 | + | ||
| 547 | + for f in future_list: | ||
| 548 | + self.stream_queue.task_done() | ||
| 549 | + f.set_result(None) | ||
| 550 | + | ||
| 551 | + async def compute_and_decode( | ||
| 552 | + self, | ||
| 553 | + stream: sherpa_onnx.OnlineStream, | ||
| 554 | + ) -> None: | ||
| 555 | + """Put the stream into the queue and wait it to be processed by the | ||
| 556 | + consumer task. | ||
| 557 | + | ||
| 558 | + Args: | ||
| 559 | + stream: | ||
| 560 | + The stream to be processed. Note: It is changed in-place. | ||
| 561 | + """ | ||
| 562 | + loop = asyncio.get_running_loop() | ||
| 563 | + future = loop.create_future() | ||
| 564 | + await self.stream_queue.put((stream, future)) | ||
| 565 | + await future | ||
| 566 | + | ||
| 567 | + async def run_second_pass_async( | ||
| 568 | + self, | ||
| 569 | + samples: np.ndarray, | ||
| 570 | + sample_rate: int, | ||
| 571 | + ) -> str: | ||
| 572 | + """Run second-pass recognition asynchronously to avoid blocking. | ||
| 573 | + | ||
| 574 | + Args: | ||
| 575 | + samples: Audio samples. | ||
| 576 | + sample_rate: Sampling rate. | ||
| 577 | + | ||
| 578 | + Returns: | ||
| 579 | + Text result from the second-pass recognition. | ||
| 580 | + """ | ||
| 581 | + import time | ||
| 582 | + start_time = time.time() | ||
| 583 | + | ||
| 584 | + loop = asyncio.get_running_loop() | ||
| 585 | + result = await loop.run_in_executor( | ||
| 586 | + self.second_pass_pool, | ||
| 587 | + run_second_pass, | ||
| 588 | + self.second_pass_recognizer, | ||
| 589 | + samples, | ||
| 590 | + sample_rate, | ||
| 591 | + ) | ||
| 592 | + | ||
| 593 | + end_time = time.time() | ||
| 594 | + duration = end_time - start_time | ||
| 595 | + logging.info(f"Second pass processing completed in {duration:.3f}s for {len(samples)/sample_rate:.2f}s audio") | ||
| 596 | + | ||
| 597 | + return result.lower().strip() | ||
| 598 | + | ||
| 599 | + async def process_request( | ||
| 600 | + self, | ||
| 601 | + path: str, | ||
| 602 | + request_headers: websockets.Headers, | ||
| 603 | + ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]: | ||
| 604 | + if self.current_active_connections < self.max_active_connections: | ||
| 605 | + self.current_active_connections += 1 | ||
| 606 | + return None | ||
| 607 | + | ||
| 608 | + # Refuse new connections | ||
| 609 | + status = http.HTTPStatus.SERVICE_UNAVAILABLE # 503 | ||
| 610 | + header = {"Hint": "The server is overloaded. Please retry later."} | ||
| 611 | + response = b"The server is busy. Please retry later." | ||
| 612 | + | ||
| 613 | + return status, header, response | ||
| 614 | + | ||
| 615 | + async def run(self, port: int): | ||
| 616 | + tasks = [] | ||
| 617 | + for i in range(self.nn_pool_size): | ||
| 618 | + tasks.append(asyncio.create_task(self.stream_consumer_task())) | ||
| 619 | + | ||
| 620 | + if self.certificate: | ||
| 621 | + logging.info(f"Using certificate: {self.certificate}") | ||
| 622 | + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) | ||
| 623 | + ssl_context.load_cert_chain(self.certificate) | ||
| 624 | + else: | ||
| 625 | + ssl_context = None | ||
| 626 | + logging.info("No certificate provided") | ||
| 627 | + | ||
| 628 | + try: | ||
| 629 | + async with websockets.serve( | ||
| 630 | + self.handle_connection, | ||
| 631 | + host="", | ||
| 632 | + port=port, | ||
| 633 | + max_size=self.max_message_size, | ||
| 634 | + max_queue=self.max_queue_size, | ||
| 635 | + process_request=self.process_request, | ||
| 636 | + ssl=ssl_context, | ||
| 637 | + ): | ||
| 638 | + logging.info(f"Started server on port {port}") | ||
| 639 | + await asyncio.Future() # run forever | ||
| 640 | + finally: | ||
| 641 | + logging.info("Shutting down thread pools...") | ||
| 642 | + self.nn_pool.shutdown(wait=True) | ||
| 643 | + self.second_pass_pool.shutdown(wait=True) | ||
| 644 | + logging.info("Thread pools shut down successfully") | ||
| 645 | + | ||
| 646 | + await asyncio.gather(*tasks) # not reachable | ||
| 647 | + | ||
| 648 | + async def handle_connection( | ||
| 649 | + self, | ||
| 650 | + socket: websockets.WebSocketServerProtocol, | ||
| 651 | + ): | ||
| 652 | + """Receive audio samples from the client, process it, and send | ||
| 653 | + decoding result back to the client. | ||
| 654 | + | ||
| 655 | + Args: | ||
| 656 | + socket: | ||
| 657 | + The socket for communicating with the client. | ||
| 658 | + """ | ||
| 659 | + try: | ||
| 660 | + await self.handle_connection_impl(socket) | ||
| 661 | + except websockets.exceptions.ConnectionClosed: | ||
| 662 | + logging.info(f"{socket.remote_address} disconnected") | ||
| 663 | + finally: | ||
| 664 | + # Decrement so that it can accept new connections | ||
| 665 | + self.current_active_connections -= 1 | ||
| 666 | + | ||
| 667 | + logging.info( | ||
| 668 | + f"Disconnected: {socket.remote_address}. " | ||
| 669 | + f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa | ||
| 670 | + ) | ||
| 671 | + | ||
| 672 | + async def handle_connection_impl( | ||
| 673 | + self, | ||
| 674 | + socket: websockets.WebSocketServerProtocol, | ||
| 675 | + ): | ||
| 676 | + """Receive audio samples from the client, process it, and send | ||
| 677 | + decoding result back to the client. | ||
| 678 | + | ||
| 679 | + Args: | ||
| 680 | + socket: | ||
| 681 | + The socket for communicating with the client. | ||
| 682 | + """ | ||
| 683 | + stream = self.first_pass_recognizer.create_stream() | ||
| 684 | + segment = 0 | ||
| 685 | + sample_buffers = [] | ||
| 686 | + while True: | ||
| 687 | + samples = await self.recv_audio_samples(socket) | ||
| 688 | + if samples is None: | ||
| 689 | + break | ||
| 690 | + | ||
| 691 | + # TODO(fangjun): At present, we assume the sampling rate | ||
| 692 | + # of the received audio samples equal to --sample-rate | ||
| 693 | + stream.accept_waveform(sample_rate=self.sample_rate, waveform=samples) | ||
| 694 | + sample_buffers.append(samples) | ||
| 695 | + while self.first_pass_recognizer.is_ready(stream): | ||
| 696 | + await self.compute_and_decode(stream) | ||
| 697 | + result = self.first_pass_recognizer.get_result(stream) | ||
| 698 | + | ||
| 699 | + message = { | ||
| 700 | + "text": result, | ||
| 701 | + "segment": segment, | ||
| 702 | + } | ||
| 703 | + if self.first_pass_recognizer.is_endpoint(stream): | ||
| 704 | + if result: | ||
| 705 | + samples_for_2nd_pass = np.concatenate(sample_buffers) | ||
| 706 | + sample_buffers = [samples_for_2nd_pass[-8000:]] | ||
| 707 | + samples_for_2nd_pass = samples_for_2nd_pass[:-8000] | ||
| 708 | + second_pass_result = ( | ||
| 709 | + await self.run_second_pass_async( | ||
| 710 | + samples=samples_for_2nd_pass, | ||
| 711 | + sample_rate=self.sample_rate, | ||
| 712 | + ) | ||
| 713 | + ) | ||
| 714 | + | ||
| 715 | + if second_pass_result: | ||
| 716 | + message["text"] = second_pass_result | ||
| 717 | + message["segment"] = segment | ||
| 718 | + else: | ||
| 719 | + sample_buffers=[] | ||
| 720 | + | ||
| 721 | + self.first_pass_recognizer.reset(stream) | ||
| 722 | + segment += 1 | ||
| 723 | + await socket.send(json.dumps(message)) | ||
| 724 | + | ||
| 725 | + tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32) | ||
| 726 | + stream.accept_waveform(sample_rate=self.sample_rate, waveform=tail_padding) | ||
| 727 | + stream.input_finished() | ||
| 728 | + while self.first_pass_recognizer.is_ready(stream): | ||
| 729 | + await self.compute_and_decode(stream) | ||
| 730 | + | ||
| 731 | + result = self.first_pass_recognizer.get_result(stream) | ||
| 732 | + | ||
| 733 | + message = { | ||
| 734 | + "text": result, | ||
| 735 | + "segment": segment, | ||
| 736 | + } | ||
| 737 | + await socket.send(json.dumps(message)) | ||
| 738 | + | ||
| 739 | + async def recv_audio_samples( | ||
| 740 | + self, | ||
| 741 | + socket: websockets.WebSocketServerProtocol, | ||
| 742 | + ) -> Optional[np.ndarray]: | ||
| 743 | + """Receive audio samples from WebSocket connection | ||
| 744 | + | ||
| 745 | + Args: | ||
| 746 | + socket: WebSocket connection | ||
| 747 | + | ||
| 748 | + Returns: | ||
| 749 | + Numpy array containing audio samples, or None indicating end of audio | ||
| 750 | + """ | ||
| 751 | + message = await socket.recv() | ||
| 752 | + if message == "Done": | ||
| 753 | + return None | ||
| 754 | + return np.frombuffer(message, dtype=np.float32) | ||
| 755 | + | ||
| 756 | + | ||
| 757 | +def check_args(args): | ||
| 758 | + if args.encoder: | ||
| 759 | + assert Path(args.encoder).is_file(), f"{args.encoder} does not exist" | ||
| 760 | + assert Path(args.decoder).is_file(), f"{args.decoder} does not exist" | ||
| 761 | + assert args.paraformer_encoder is None, args.paraformer_encoder | ||
| 762 | + assert args.paraformer_decoder is None, args.paraformer_decoder | ||
| 763 | + | ||
| 764 | + elif args.paraformer_encoder: | ||
| 765 | + assert Path( | ||
| 766 | + args.paraformer_encoder | ||
| 767 | + ).is_file(), f"{args.paraformer_encoder} does not exist" | ||
| 768 | + | ||
| 769 | + assert Path( | ||
| 770 | + args.paraformer_decoder | ||
| 771 | + ).is_file(), f"{args.paraformer_decoder} does not exist" | ||
| 772 | + else: | ||
| 773 | + raise ValueError("Please provide a model") | ||
| 774 | + | ||
| 775 | + if not Path(args.tokens).is_file(): | ||
| 776 | + raise ValueError(f"{args.tokens} does not exist") | ||
| 777 | + | ||
| 778 | + if args.decoding_method not in ( | ||
| 779 | + "greedy_search", | ||
| 780 | + "modified_beam_search", | ||
| 781 | + ): | ||
| 782 | + raise ValueError(f"Unsupported decoding method {args.decoding_method}") | ||
| 783 | + | ||
| 784 | + if args.decoding_method == "modified_beam_search": | ||
| 785 | + assert args.num_active_paths > 0, args.num_active_paths | ||
| 786 | + | ||
| 787 | + | ||
| 788 | +def main(): | ||
| 789 | + args = get_args() | ||
| 790 | + logging.info(vars(args)) | ||
| 791 | + check_args(args) | ||
| 792 | + | ||
| 793 | + first_pass_recognizer = create_first_pass_recognizer(args) | ||
| 794 | + second_pass_recognizer = create_second_pass_recognizer(args) | ||
| 795 | + | ||
| 796 | + port = args.port | ||
| 797 | + nn_pool_size = args.nn_pool_size | ||
| 798 | + max_batch_size = args.max_batch_size | ||
| 799 | + max_wait_ms = args.max_wait_ms | ||
| 800 | + max_message_size = args.max_message_size | ||
| 801 | + max_queue_size = args.max_queue_size | ||
| 802 | + max_active_connections = args.max_active_connections | ||
| 803 | + second_pass_threads = args.second_pass_threads | ||
| 804 | + certificate = args.certificate | ||
| 805 | + # doc_root = args.doc_root | ||
| 806 | + | ||
| 807 | + if certificate and not Path(certificate).is_file(): | ||
| 808 | + raise ValueError(f"{certificate} does not exist") | ||
| 809 | + | ||
| 810 | + server = StreamingServer( | ||
| 811 | + first_pass_recognizer=first_pass_recognizer, | ||
| 812 | + second_pass_recognizer=second_pass_recognizer, | ||
| 813 | + nn_pool_size=nn_pool_size, | ||
| 814 | + max_batch_size=max_batch_size, | ||
| 815 | + max_wait_ms=max_wait_ms, | ||
| 816 | + max_message_size=max_message_size, | ||
| 817 | + max_queue_size=max_queue_size, | ||
| 818 | + max_active_connections=max_active_connections, | ||
| 819 | + second_pass_threads=second_pass_threads, | ||
| 820 | + certificate=certificate, | ||
| 821 | + # doc_root=doc_root, | ||
| 822 | + ) | ||
| 823 | + asyncio.run(server.run(port)) | ||
| 824 | + | ||
| 825 | + | ||
| 826 | +if __name__ == "__main__": | ||
| 827 | + log_filename = "log/log-streaming-server" | ||
| 828 | + setup_logger(log_filename) | ||
| 829 | + main() |
-
请 注册 或 登录 后发表评论