add offline websocket server/client (#98)
Committed by GitHub

Showing 15 changed files with 1,036 additions and 63 deletions.
| @@ -30,9 +30,12 @@ ls -lh | @@ -30,9 +30,12 @@ ls -lh | ||
| 30 | 30 | ||
| 31 | ls -lh $repo | 31 | ls -lh $repo |
| 32 | 32 | ||
| 33 | -python3 ./python-api-examples/decode-file.py \ | 33 | +python3 ./python-api-examples/online-decode-files.py \ |
| 34 | --tokens=$repo/tokens.txt \ | 34 | --tokens=$repo/tokens.txt \ |
| 35 | --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | 35 | --encoder=$repo/encoder-epoch-99-avg-1.onnx \ |
| 36 | --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | 36 | --decoder=$repo/decoder-epoch-99-avg-1.onnx \ |
| 37 | --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | 37 | --joiner=$repo/joiner-epoch-99-avg-1.onnx \ |
| 38 | - --wave-filename=$repo/test_wavs/4.wav | 38 | + $repo/test_wavs/0.wav \ |
| 39 | + $repo/test_wavs/1.wav \ | ||
| 40 | + $repo/test_wavs/2.wav \ | ||
| 41 | + $repo/test_wavs/3.wav |
| @@ -45,3 +45,4 @@ paraformer-onnxruntime-python-example | @@ -45,3 +45,4 @@ paraformer-onnxruntime-python-example | ||
| 45 | run-sherpa-onnx-offline-paraformer.sh | 45 | run-sherpa-onnx-offline-paraformer.sh |
| 46 | run-sherpa-onnx-offline-transducer.sh | 46 | run-sherpa-onnx-offline-transducer.sh |
| 47 | sherpa-onnx-paraformer-zh-2023-03-28 | 47 | sherpa-onnx-paraformer-zh-2023-03-28 |
| 48 | +run-offline-websocket-server-paraformer.sh |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +""" | ||
| 6 | +A websocket client for sherpa-onnx-offline-websocket-server | ||
| 7 | + | ||
| 8 | +This file shows how to transcribe multiple | ||
| 9 | +files in parallel. We create a separate connection for transcribing each file. | ||
| 10 | + | ||
| 11 | +Usage: | ||
| 12 | + ./offline-websocket-client-decode-files-parallel.py \ | ||
| 13 | + --server-addr localhost \ | ||
| 14 | + --server-port 6006 \ | ||
| 15 | + /path/to/foo.wav \ | ||
| 16 | + /path/to/bar.wav \ | ||
| 17 | + /path/to/16kHz.wav \ | ||
| 18 | + /path/to/8kHz.wav | ||
| 19 | + | ||
| 20 | +(Note: You have to first start the server before starting the client) | ||
| 21 | + | ||
| 22 | +You can find the server at | ||
| 23 | +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc | ||
| 24 | + | ||
| 25 | +Note: The server is implemented in C++. | ||
| 26 | +""" | ||
| 27 | + | ||
| 28 | +import argparse | ||
| 29 | +import asyncio | ||
| 30 | +import logging | ||
| 31 | +import wave | ||
| 32 | +from typing import Tuple | ||
| 33 | + | ||
| 34 | +try: | ||
| 35 | + import websockets | ||
| 36 | +except ImportError: | ||
| 37 | + print("please run:") | ||
| 38 | + print("") | ||
| 39 | + print(" pip install websockets") | ||
| 40 | + print("") | ||
| 41 | + print("before you run this script") | ||
| 42 | + print("") | ||
| 43 | + | ||
| 44 | +import numpy as np | ||
| 45 | + | ||
| 46 | + | ||
| 47 | +def get_args(): | ||
| 48 | + parser = argparse.ArgumentParser( | ||
| 49 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 50 | + ) | ||
| 51 | + | ||
| 52 | + parser.add_argument( | ||
| 53 | + "--server-addr", | ||
| 54 | + type=str, | ||
| 55 | + default="localhost", | ||
| 56 | + help="Address of the server", | ||
| 57 | + ) | ||
| 58 | + | ||
| 59 | + parser.add_argument( | ||
| 60 | + "--server-port", | ||
| 61 | + type=int, | ||
| 62 | + default=6006, | ||
| 63 | + help="Port of the server", | ||
| 64 | + ) | ||
| 65 | + | ||
| 66 | + parser.add_argument( | ||
| 67 | + "sound_files", | ||
| 68 | + type=str, | ||
| 69 | + nargs="+", | ||
| 70 | + help="The input sound file(s) to decode. Each file must be of WAVE" | ||
| 71 | + "format with a single channel, and each sample has 16-bit, " | ||
| 72 | + "i.e., int16_t. " | ||
| 73 | + "The sample rate of the file can be arbitrary and does not need to " | ||
| 74 | + "be 16 kHz", | ||
| 75 | + ) | ||
| 76 | + | ||
| 77 | + return parser.parse_args() | ||
| 78 | + | ||
| 79 | + | ||
| 80 | +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 81 | + """ | ||
| 82 | + Args: | ||
| 83 | + wave_filename: | ||
| 84 | + Path to a wave file. It should be single channel and each sample should | ||
| 85 | + be 16-bit. Its sample rate does not need to be 16kHz. | ||
| 86 | + Returns: | ||
| 87 | + Return a tuple containing: | ||
| 88 | + - A 1-D array of dtype np.float32 containing the samples, which are | ||
| 89 | + normalized to the range [-1, 1]. | ||
| 90 | + - sample rate of the wave file | ||
| 91 | + """ | ||
| 92 | + | ||
| 93 | + with wave.open(wave_filename) as f: | ||
| 94 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 95 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 96 | + num_samples = f.getnframes() | ||
| 97 | + samples = f.readframes(num_samples) | ||
| 98 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 99 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 100 | + | ||
| 101 | + samples_float32 = samples_float32 / 32768 | ||
| 102 | + return samples_float32, f.getframerate() | ||
| 103 | + | ||
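The read_wave() helper above converts 16-bit PCM to float32 by dividing by 32768, which maps the int16 range [-32768, 32767] onto [-1, 1). A quick sanity check of that normalization, with made-up sample values:

    import numpy as np

    pcm = np.array([-32768, 0, 16384, 32767], dtype=np.int16)
    normalized = pcm.astype(np.float32) / 32768
    print(normalized)  # -> [-1.0, 0.0, 0.5, ~0.99997]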
| 104 | + | ||
| 105 | +async def run( | ||
| 106 | + server_addr: str, | ||
| 107 | + server_port: int, | ||
| 108 | + wave_filename: str, | ||
| 109 | +): | ||
| 110 | + async with websockets.connect( | ||
| 111 | + f"ws://{server_addr}:{server_port}" | ||
| 112 | + ) as websocket: # noqa | ||
| 113 | + logging.info(f"Sending {wave_filename}") | ||
| 114 | + samples, sample_rate = read_wave(wave_filename) | ||
| 115 | + assert isinstance(sample_rate, int) | ||
| 116 | + assert samples.dtype == np.float32, samples.dtype | ||
| 117 | + assert samples.ndim == 1, samples.ndim | ||
| 118 | + buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes | ||
| 119 | + buf += (samples.size * 4).to_bytes(4, byteorder="little") | ||
| 120 | + buf += samples.tobytes() | ||
| 121 | + | ||
| 122 | + await websocket.send(buf) | ||
| 123 | + | ||
| 124 | + decoding_results = await websocket.recv() | ||
| 125 | + logging.info(f"{wave_filename}\n{decoding_results}") | ||
| 126 | + | ||
| 127 | + # to signal that the client has sent all the data | ||
| 128 | + await websocket.send("Done") | ||
| 129 | + | ||
| 130 | + | ||
| 131 | +async def main(): | ||
| 132 | + args = get_args() | ||
| 133 | + logging.info(vars(args)) | ||
| 134 | + | ||
| 135 | + server_addr = args.server_addr | ||
| 136 | + server_port = args.server_port | ||
| 137 | + sound_files = args.sound_files | ||
| 138 | + | ||
| 139 | + all_tasks = [] | ||
| 140 | + for wave_filename in sound_files: | ||
| 141 | + task = asyncio.create_task( | ||
| 142 | + run( | ||
| 143 | + server_addr=server_addr, | ||
| 144 | + server_port=server_port, | ||
| 145 | + wave_filename=wave_filename, | ||
| 146 | + ) | ||
| 147 | + ) | ||
| 148 | + all_tasks.append(task) | ||
| 149 | + | ||
| 150 | + await asyncio.gather(*all_tasks) | ||
| 151 | + | ||
| 152 | + | ||
| 153 | +if __name__ == "__main__": | ||
| 154 | + formatter = ( | ||
| 155 | + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa | ||
| 156 | + ) | ||
| 157 | + logging.basicConfig(format=formatter, level=logging.INFO) | ||
| 158 | + asyncio.run(main()) |
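The binary message assembled in run() follows the server's wire format: a 4-byte little-endian sample rate, a 4-byte little-endian payload size, then the float32 samples. As a minimal sketch (not part of this pull request; the function name is illustrative), this is how a receiver could unpack the first message:

    import struct

    import numpy as np


    def parse_first_message(message: bytes):
        # The protocol requires the first message to carry at least 8 bytes
        assert len(message) >= 8, len(message)
        sample_rate, expected_byte_size = struct.unpack("<ii", message[:8])
        # The remaining bytes (possibly only part of the utterance) are
        # float32 samples normalized to [-1, 1]
        samples = np.frombuffer(message[8:], dtype=np.float32)
        return sample_rate, expected_byte_size, samples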
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +""" | ||
| 6 | +A websocket client for sherpa-onnx-offline-websocket-server | ||
| 7 | + | ||
| 8 | +This file shows how to use a single connection to transcribe multiple | ||
| 9 | +files sequentially. | ||
| 10 | + | ||
| 11 | +Usage: | ||
| 12 | + ./offline-websocket-client-decode-files-sequential.py \ | ||
| 13 | + --server-addr localhost \ | ||
| 14 | + --server-port 6006 \ | ||
| 15 | + /path/to/foo.wav \ | ||
| 16 | + /path/to/bar.wav \ | ||
| 17 | + /path/to/16kHz.wav \ | ||
| 18 | + /path/to/8kHz.wav | ||
| 19 | + | ||
| 20 | +(Note: You have to first start the server before starting the client) | ||
| 21 | + | ||
| 22 | +You can find the server at | ||
| 23 | +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc | ||
| 24 | + | ||
| 25 | +Note: The server is implemented in C++. | ||
| 26 | +""" | ||
| 27 | + | ||
| 28 | +import argparse | ||
| 29 | +import asyncio | ||
| 30 | +import logging | ||
| 31 | +import wave | ||
| 32 | +from typing import List, Tuple | ||
| 33 | + | ||
| 34 | +try: | ||
| 35 | + import websockets | ||
| 36 | +except ImportError: | ||
| 37 | + print("please run:") | ||
| 38 | + print("") | ||
| 39 | + print(" pip install websockets") | ||
| 40 | + print("") | ||
| 41 | + print("before you run this script") | ||
| 42 | + print("") | ||
| 43 | + | ||
| 44 | +import numpy as np | ||
| 45 | + | ||
| 46 | + | ||
| 47 | +def get_args(): | ||
| 48 | + parser = argparse.ArgumentParser( | ||
| 49 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 50 | + ) | ||
| 51 | + | ||
| 52 | + parser.add_argument( | ||
| 53 | + "--server-addr", | ||
| 54 | + type=str, | ||
| 55 | + default="localhost", | ||
| 56 | + help="Address of the server", | ||
| 57 | + ) | ||
| 58 | + | ||
| 59 | + parser.add_argument( | ||
| 60 | + "--server-port", | ||
| 61 | + type=int, | ||
| 62 | + default=6006, | ||
| 63 | + help="Port of the server", | ||
| 64 | + ) | ||
| 65 | + | ||
| 66 | + parser.add_argument( | ||
| 67 | + "sound_files", | ||
| 68 | + type=str, | ||
| 69 | + nargs="+", | ||
| 70 | + help="The input sound file(s) to decode. Each file must be of WAVE" | ||
| 71 | + "format with a single channel, and each sample has 16-bit, " | ||
| 72 | + "i.e., int16_t. " | ||
| 73 | + "The sample rate of the file can be arbitrary and does not need to " | ||
| 74 | + "be 16 kHz", | ||
| 75 | + ) | ||
| 76 | + | ||
| 77 | + return parser.parse_args() | ||
| 78 | + | ||
| 79 | + | ||
| 80 | +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 81 | + """ | ||
| 82 | + Args: | ||
| 83 | + wave_filename: | ||
| 84 | + Path to a wave file. It should be single channel and each sample should | ||
| 85 | + be 16-bit. Its sample rate does not need to be 16kHz. | ||
| 86 | + Returns: | ||
| 87 | + Return a tuple containing: | ||
| 88 | + - A 1-D array of dtype np.float32 containing the samples, which are | ||
| 89 | + normalized to the range [-1, 1]. | ||
| 90 | + - sample rate of the wave file | ||
| 91 | + """ | ||
| 92 | + | ||
| 93 | + with wave.open(wave_filename) as f: | ||
| 94 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 95 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 96 | + num_samples = f.getnframes() | ||
| 97 | + samples = f.readframes(num_samples) | ||
| 98 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 99 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 100 | + | ||
| 101 | + samples_float32 = samples_float32 / 32768 | ||
| 102 | + return samples_float32, f.getframerate() | ||
| 103 | + | ||
| 104 | + | ||
| 105 | +async def run( | ||
| 106 | + server_addr: str, | ||
| 107 | + server_port: int, | ||
| 108 | + sound_files: List[str], | ||
| 109 | +): | ||
| 110 | + async with websockets.connect( | ||
| 111 | + f"ws://{server_addr}:{server_port}" | ||
| 112 | + ) as websocket: # noqa | ||
| 113 | + for wave_filename in sound_files: | ||
| 114 | + logging.info(f"Sending {wave_filename}") | ||
| 115 | + samples, sample_rate = read_wave(wave_filename) | ||
| 116 | + assert isinstance(sample_rate, int) | ||
| 117 | + assert samples.dtype == np.float32, samples.dtype | ||
| 118 | + assert samples.ndim == 1, samples.ndim | ||
| 119 | + buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes | ||
| 120 | + buf += (samples.size * 4).to_bytes(4, byteorder="little") | ||
| 121 | + buf += samples.tobytes() | ||
| 122 | + | ||
| 123 | + await websocket.send(buf) | ||
| 124 | + | ||
| 125 | + decoding_results = await websocket.recv() | ||
| 126 | + print(decoding_results) | ||
| 127 | + | ||
| 128 | + # to signal that the client has sent all the data | ||
| 129 | + await websocket.send("Done") | ||
| 130 | + | ||
| 131 | + | ||
| 132 | +async def main(): | ||
| 133 | + args = get_args() | ||
| 134 | + logging.info(vars(args)) | ||
| 135 | + | ||
| 136 | + server_addr = args.server_addr | ||
| 137 | + server_port = args.server_port | ||
| 138 | + sound_files = args.sound_files | ||
| 139 | + | ||
| 140 | + await run( | ||
| 141 | + server_addr=server_addr, | ||
| 142 | + server_port=server_port, | ||
| 143 | + sound_files=sound_files, | ||
| 144 | + ) | ||
| 145 | + | ||
| 146 | + | ||
| 147 | +if __name__ == "__main__": | ||
| 148 | + formatter = ( | ||
| 149 | + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa | ||
| 150 | + ) | ||
| 151 | + logging.basicConfig(format=formatter, level=logging.INFO) | ||
| 152 | + asyncio.run(main()) |
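Compared with the parallel client, which opens one connection per file, this sequential client reuses a single connection and decodes one file at a time. A middle ground, sketched below under the assumption that run() is the per-file coroutine from the parallel client and with max_inflight as a hypothetical parameter, is to bound the number of simultaneous connections:

    import asyncio


    async def run_bounded(server_addr, server_port, sound_files, max_inflight=5):
        sem = asyncio.Semaphore(max_inflight)

        async def guarded(wave_filename):
            async with sem:
                await run(
                    server_addr=server_addr,
                    server_port=server_port,
                    wave_filename=wave_filename,
                )

        await asyncio.gather(*(guarded(f) for f in sound_files))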
| 1 | #!/usr/bin/env python3 | 1 | #!/usr/bin/env python3 |
| 2 | 2 | ||
| 3 | """ | 3 | """ |
| 4 | -This file demonstrates how to use sherpa-onnx Python API to recognize | ||
| 5 | -a single file. | 4 | +This file demonstrates how to use sherpa-onnx Python API to transcribe |
| 5 | +file(s) with a streaming model. | ||
| 6 | + | ||
| 7 | +Usage: | ||
| 8 | + ./online-decode-files.py \ | ||
| 9 | + /path/to/foo.wav \ | ||
| 10 | + /path/to/bar.wav \ | ||
| 11 | + /path/to/16kHz.wav \ | ||
| 12 | + /path/to/8kHz.wav | ||
| 6 | 13 | ||
| 7 | Please refer to | 14 | Please refer to |
| 8 | https://k2-fsa.github.io/sherpa/onnx/index.html | 15 | https://k2-fsa.github.io/sherpa/onnx/index.html |
| @@ -13,17 +20,12 @@ import argparse | @@ -13,17 +20,12 @@ import argparse | ||
| 13 | import time | 20 | import time |
| 14 | import wave | 21 | import wave |
| 15 | from pathlib import Path | 22 | from pathlib import Path |
| 23 | +from typing import Tuple | ||
| 16 | 24 | ||
| 17 | import numpy as np | 25 | import numpy as np |
| 18 | import sherpa_onnx | 26 | import sherpa_onnx |
| 19 | 27 | ||
| 20 | 28 | ||
| 21 | -def assert_file_exists(filename: str): | ||
| 22 | - assert Path( | ||
| 23 | - filename | ||
| 24 | - ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" | ||
| 25 | - | ||
| 26 | - | ||
| 27 | def get_args(): | 29 | def get_args(): |
| 28 | parser = argparse.ArgumentParser( | 30 | parser = argparse.ArgumentParser( |
| 29 | formatter_class=argparse.ArgumentDefaultsHelpFormatter | 31 | formatter_class=argparse.ArgumentDefaultsHelpFormatter |
| @@ -68,26 +70,58 @@ def get_args(): | @@ -68,26 +70,58 @@ def get_args(): | ||
| 68 | ) | 70 | ) |
| 69 | 71 | ||
| 70 | parser.add_argument( | 72 | parser.add_argument( |
| 71 | - "--wave-filename", | 73 | + "sound_files", |
| 72 | type=str, | 74 | type=str, |
| 73 | - help="""Path to the wave filename. | ||
| 74 | - Should have a single channel with 16-bit samples. | ||
| 75 | - It does not need to be 16kHz. It can have any sampling rate. | ||
| 76 | - """, | 75 | + nargs="+", |
| 76 | + help="The input sound file(s) to decode. Each file must be of WAVE" | ||
| 77 | + "format with a single channel, and each sample has 16-bit, " | ||
| 78 | + "i.e., int16_t. " | ||
| 79 | + "The sample rate of the file can be arbitrary and does not need to " | ||
| 80 | + "be 16 kHz", | ||
| 77 | ) | 81 | ) |
| 78 | 82 | ||
| 79 | return parser.parse_args() | 83 | return parser.parse_args() |
| 80 | 84 | ||
| 81 | 85 | ||
| 86 | +def assert_file_exists(filename: str): | ||
| 87 | + assert Path(filename).is_file(), ( | ||
| 88 | + f"{filename} does not exist!\n" | ||
| 89 | + "Please refer to " | ||
| 90 | + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" | ||
| 91 | + ) | ||
| 92 | + | ||
| 93 | + | ||
| 94 | +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 95 | + """ | ||
| 96 | + Args: | ||
| 97 | + wave_filename: | ||
| 98 | + Path to a wave file. It should be single channel and each sample should | ||
| 99 | + be 16-bit. Its sample rate does not need to be 16kHz. | ||
| 100 | + Returns: | ||
| 101 | + Return a tuple containing: | ||
| 102 | + - A 1-D array of dtype np.float32 containing the samples, which are | ||
| 103 | + normalized to the range [-1, 1]. | ||
| 104 | + - sample rate of the wave file | ||
| 105 | + """ | ||
| 106 | + | ||
| 107 | + with wave.open(wave_filename) as f: | ||
| 108 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 109 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 110 | + num_samples = f.getnframes() | ||
| 111 | + samples = f.readframes(num_samples) | ||
| 112 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 113 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 114 | + | ||
| 115 | + samples_float32 = samples_float32 / 32768 | ||
| 116 | + return samples_float32, f.getframerate() | ||
| 117 | + | ||
| 118 | + | ||
| 82 | def main(): | 119 | def main(): |
| 83 | args = get_args() | 120 | args = get_args() |
| 84 | assert_file_exists(args.encoder) | 121 | assert_file_exists(args.encoder) |
| 85 | assert_file_exists(args.decoder) | 122 | assert_file_exists(args.decoder) |
| 86 | assert_file_exists(args.joiner) | 123 | assert_file_exists(args.joiner) |
| 87 | assert_file_exists(args.tokens) | 124 | assert_file_exists(args.tokens) |
| 88 | - if not Path(args.wave_filename).is_file(): | ||
| 89 | - print(f"{args.wave_filename} does not exist!") | ||
| 90 | - return | ||
| 91 | 125 | ||
| 92 | recognizer = sherpa_onnx.OnlineRecognizer( | 126 | recognizer = sherpa_onnx.OnlineRecognizer( |
| 93 | tokens=args.tokens, | 127 | tokens=args.tokens, |
| @@ -99,42 +133,44 @@ def main(): | @@ -99,42 +133,44 @@ def main(): | ||
| 99 | feature_dim=80, | 133 | feature_dim=80, |
| 100 | decoding_method=args.decoding_method, | 134 | decoding_method=args.decoding_method, |
| 101 | ) | 135 | ) |
| 102 | - with wave.open(args.wave_filename) as f: | ||
| 103 | - # If the wave file has a different sampling rate from the one | ||
| 104 | - # expected by the model (16 kHz in our case), we will do | ||
| 105 | - # resampling inside sherpa-onnx | ||
| 106 | - wave_file_sample_rate = f.getframerate() | ||
| 107 | - | ||
| 108 | - assert f.getnchannels() == 1, f.getnchannels() | ||
| 109 | - assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 110 | - num_samples = f.getnframes() | ||
| 111 | - samples = f.readframes(num_samples) | ||
| 112 | - samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 113 | - samples_float32 = samples_int16.astype(np.float32) | ||
| 114 | - | ||
| 115 | - samples_float32 = samples_float32 / 32768 | ||
| 116 | 136 | ||
| 117 | - duration = len(samples_float32) / wave_file_sample_rate | ||
| 118 | - | ||
| 119 | - start_time = time.time() | ||
| 120 | print("Started!") | 137 | print("Started!") |
| 138 | + start_time = time.time() | ||
| 121 | 139 | ||
| 122 | - stream = recognizer.create_stream() | ||
| 123 | - | ||
| 124 | - stream.accept_waveform(wave_file_sample_rate, samples_float32) | ||
| 125 | - | ||
| 126 | - tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32) | ||
| 127 | - stream.accept_waveform(wave_file_sample_rate, tail_paddings) | ||
| 128 | - | ||
| 129 | - stream.input_finished() | ||
| 130 | - | ||
| 131 | - while recognizer.is_ready(stream): | ||
| 132 | - recognizer.decode_stream(stream) | 140 | + streams = [] |
| 141 | + total_duration = 0 | ||
| 142 | + for wave_filename in args.sound_files: | ||
| 143 | + assert_file_exists(wave_filename) | ||
| 144 | + samples, sample_rate = read_wave(wave_filename) | ||
| 145 | + duration = len(samples) / sample_rate | ||
| 146 | + total_duration += duration | ||
| 147 | + | ||
| 148 | + s = recognizer.create_stream() | ||
| 149 | + s.accept_waveform(sample_rate, samples) | ||
| 150 | + | ||
| 151 | + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) | ||
| 152 | + s.accept_waveform(sample_rate, tail_paddings) | ||
| 153 | + | ||
| 154 | + s.input_finished() | ||
| 155 | + | ||
| 156 | + streams.append(s) | ||
| 157 | + | ||
| 158 | + while True: | ||
| 159 | + ready_list = [] | ||
| 160 | + for s in streams: | ||
| 161 | + if recognizer.is_ready(s): | ||
| 162 | + ready_list.append(s) | ||
| 163 | + if len(ready_list) == 0: | ||
| 164 | + break | ||
| 165 | + recognizer.decode_streams(ready_list) | ||
| 166 | + results = [recognizer.get_result(s) for s in streams] | ||
| 167 | + end_time = time.time() | ||
| 168 | + print("Done!") | ||
| 133 | 169 | ||
| 134 | - print(recognizer.get_result(stream)) | 170 | + for wave_filename, result in zip(args.sound_files, results): |
| 171 | + print(f"{wave_filename}\n{result}") | ||
| 172 | + print("-" * 10) | ||
| 135 | 173 | ||
| 136 | - print("Done!") | ||
| 137 | - end_time = time.time() | ||
| 138 | elapsed_seconds = end_time - start_time | 174 | elapsed_seconds = end_time - start_time |
| 139 | - rtf = elapsed_seconds / duration | 175 | + rtf = elapsed_seconds / total_duration |
| 140 | print(f"num_threads: {args.num_threads}") | 176 | print(f"num_threads: {args.num_threads}") |
| @@ -27,7 +27,6 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc | @@ -27,7 +27,6 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc | ||
| 27 | import argparse | 27 | import argparse |
| 28 | import asyncio | 28 | import asyncio |
| 29 | import logging | 29 | import logging |
| 30 | -import time | ||
| 31 | import wave | 30 | import wave |
| 32 | 31 | ||
| 33 | try: | 32 | try: |
python-api-examples/online-websocket-client-microphone.py
100644 → 100755
| @@ -24,13 +24,12 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc | @@ -24,13 +24,12 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websoc | ||
| 24 | import argparse | 24 | import argparse |
| 25 | import asyncio | 25 | import asyncio |
| 26 | import sys | 26 | import sys |
| 27 | -import time | ||
| 28 | 27 | ||
| 29 | import numpy as np | 28 | import numpy as np |
| 30 | 29 | ||
| 31 | try: | 30 | try: |
| 32 | import sounddevice as sd | 31 | import sounddevice as sd |
| 33 | -except ImportError as e: | 32 | +except ImportError: |
| 34 | print("Please install sounddevice first. You can use") | 33 | print("Please install sounddevice first. You can use") |
| 35 | print() | 34 | print() |
| 36 | print(" pip install sounddevice") | 35 | print(" pip install sounddevice") |
| @@ -134,7 +133,7 @@ async def run( | @@ -134,7 +133,7 @@ async def run( | ||
| 134 | await websocket.send(indata.tobytes()) | 133 | await websocket.send(indata.tobytes()) |
| 135 | 134 | ||
| 136 | decoding_results = await receive_task | 135 | decoding_results = await receive_task |
| 137 | - print("\nFinal result is:\n{decoding_results}") | 136 | + print(f"\nFinal result is:\n{decoding_results}") |
| 138 | 137 | ||
| 139 | 138 | ||
| 140 | async def main(): | 139 | async def main(): |
| @@ -13,7 +13,7 @@ from pathlib import Path | @@ -13,7 +13,7 @@ from pathlib import Path | ||
| 13 | 13 | ||
| 14 | try: | 14 | try: |
| 15 | import sounddevice as sd | 15 | import sounddevice as sd |
| 16 | -except ImportError as e: | 16 | +except ImportError: |
| 17 | print("Please install sounddevice first. You can use") | 17 | print("Please install sounddevice first. You can use") |
| 18 | print() | 18 | print() |
| 19 | print(" pip install sounddevice") | 19 | print(" pip install sounddevice") |
| @@ -25,9 +25,11 @@ import sherpa_onnx | @@ -25,9 +25,11 @@ import sherpa_onnx | ||
| 25 | 25 | ||
| 26 | 26 | ||
| 27 | def assert_file_exists(filename: str): | 27 | def assert_file_exists(filename: str): |
| 28 | - assert Path( | ||
| 29 | - filename | ||
| 30 | - ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" | 28 | + assert Path(filename).is_file(), ( |
| 29 | + f"{filename} does not exist!\n" | ||
| 30 | + "Please refer to " | ||
| 31 | + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" | ||
| 32 | + ) | ||
| 31 | 33 | ||
| 32 | 34 | ||
| 33 | def get_args(): | 35 | def get_args(): |
| @@ -12,7 +12,7 @@ from pathlib import Path | @@ -12,7 +12,7 @@ from pathlib import Path | ||
| 12 | 12 | ||
| 13 | try: | 13 | try: |
| 14 | import sounddevice as sd | 14 | import sounddevice as sd |
| 15 | -except ImportError as e: | 15 | +except ImportError: |
| 16 | print("Please install sounddevice first. You can use") | 16 | print("Please install sounddevice first. You can use") |
| 17 | print() | 17 | print() |
| 18 | print(" pip install sounddevice") | 18 | print(" pip install sounddevice") |
| @@ -24,9 +24,11 @@ import sherpa_onnx | @@ -24,9 +24,11 @@ import sherpa_onnx | ||
| 24 | 24 | ||
| 25 | 25 | ||
| 26 | def assert_file_exists(filename: str): | 26 | def assert_file_exists(filename: str): |
| 27 | - assert Path( | ||
| 28 | - filename | ||
| 29 | - ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" | 27 | + assert Path(filename).is_file(), ( |
| 28 | + f"{filename} does not exist!\n" | ||
| 29 | + "Please refer to " | ||
| 30 | + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" | ||
| 31 | + ) | ||
| 30 | 32 | ||
| 31 | 33 | ||
| 32 | def get_args(): | 34 | def get_args(): |
| @@ -128,7 +128,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) | @@ -128,7 +128,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) | ||
| 128 | ) | 128 | ) |
| 129 | target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core) | 129 | target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core) |
| 130 | 130 | ||
| 131 | - | ||
| 132 | add_executable(sherpa-onnx-online-websocket-client | 131 | add_executable(sherpa-onnx-online-websocket-client |
| 133 | online-websocket-client.cc | 132 | online-websocket-client.cc |
| 134 | ) | 133 | ) |
| @@ -142,6 +141,17 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) | @@ -142,6 +141,17 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) | ||
| 142 | target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations) | 141 | target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations) |
| 143 | endif() | 142 | endif() |
| 144 | 143 | ||
| 144 | + # For offline websocket | ||
| 145 | + add_executable(sherpa-onnx-offline-websocket-server | ||
| 146 | + offline-websocket-server-impl.cc | ||
| 147 | + offline-websocket-server.cc | ||
| 148 | + ) | ||
| 149 | + target_link_libraries(sherpa-onnx-offline-websocket-server sherpa-onnx-core) | ||
| 150 | + | ||
| 151 | + if(NOT WIN32) | ||
| 152 | + target_link_libraries(sherpa-onnx-offline-websocket-server -pthread) | ||
| 153 | + target_compile_options(sherpa-onnx-offline-websocket-server PRIVATE -Wno-deprecated-declarations) | ||
| 154 | + endif() | ||
| 145 | endif() | 155 | endif() |
| 146 | 156 | ||
| 147 | 157 |
| 1 | +// sherpa-onnx/csrc/offline-websocket-server-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-websocket-server-impl.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void OfflineWebsocketDecoderConfig::Register(ParseOptions *po) { | ||
| 14 | + recognizer_config.Register(po); | ||
| 15 | + | ||
| 16 | + po->Register("max-batch-size", &max_batch_size, | ||
| 17 | + "Max batch size for decoding."); | ||
| 18 | + | ||
| 19 | + po->Register( | ||
| 20 | + "max-utterance-length", &max_utterance_length, | ||
| 21 | + "Max utterance length in seconds. If we receive an utterance " | ||
| 22 | + "longer than this value, we will reject the connection. " | ||
| 23 | + "If you have enough memory, you can select a large value for it."); | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +void OfflineWebsocketDecoderConfig::Validate() const { | ||
| 27 | + if (!recognizer_config.Validate()) { | ||
| 28 | + SHERPA_ONNX_LOGE("Error in recongizer config"); | ||
| 29 | + exit(-1); | ||
| 30 | + } | ||
| 31 | + | ||
| 32 | + if (max_batch_size <= 0) { | ||
| 33 | + SHERPA_ONNX_LOGE("Expect --max-batch-size > 0. Given: %d", max_batch_size); | ||
| 34 | + exit(-1); | ||
| 35 | + } | ||
| 36 | + | ||
| 37 | + if (max_utterance_length <= 0) { | ||
| 38 | + SHERPA_ONNX_LOGE("Expect --max-utterance-length > 0. Given: %f", | ||
| 39 | + max_utterance_length); | ||
| 40 | + exit(-1); | ||
| 41 | + } | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +OfflineWebsocketDecoder::OfflineWebsocketDecoder(OfflineWebsocketServer *server) | ||
| 45 | + : config_(server->GetConfig().decoder_config), | ||
| 46 | + server_(server), | ||
| 47 | + recognizer_(config_.recognizer_config) {} | ||
| 48 | + | ||
| 49 | +void OfflineWebsocketDecoder::Push(connection_hdl hdl, ConnectionDataPtr d) { | ||
| 50 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 51 | + streams_.push_back({hdl, d}); | ||
| 52 | +} | ||
| 53 | + | ||
| 54 | +void OfflineWebsocketDecoder::Decode() { | ||
| 55 | + std::unique_lock<std::mutex> lock(mutex_); | ||
| 56 | + if (streams_.empty()) { | ||
| 57 | + return; | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + int32_t size = | ||
| 61 | + std::min(static_cast<int32_t>(streams_.size()), config_.max_batch_size); | ||
| 62 | + SHERPA_ONNX_LOGE("size: %d", size); | ||
| 63 | + | ||
| 64 | + // We first lock the mutex for streams_, take items from it, and then | ||
| 65 | + // unlock the mutex; in doing so we don't need to lock the mutex to | ||
| 66 | + // access hdl and connection_data later. | ||
| 67 | + std::vector<connection_hdl> handles(size); | ||
| 68 | + | ||
| 69 | + // Store connection_data here to prevent the data from being freed | ||
| 70 | + // while we are still using it. | ||
| 71 | + std::vector<ConnectionDataPtr> connection_data(size); | ||
| 72 | + | ||
| 73 | + std::vector<const float *> samples(size); | ||
| 74 | + std::vector<int32_t> samples_length(size); | ||
| 75 | + std::vector<std::unique_ptr<OfflineStream>> ss(size); | ||
| 76 | + std::vector<OfflineStream *> p_ss(size); | ||
| 77 | + | ||
| 78 | + for (int32_t i = 0; i != size; ++i) { | ||
| 79 | + auto &p = streams_.front(); | ||
| 80 | + handles[i] = p.first; | ||
| 81 | + connection_data[i] = p.second; | ||
| 82 | + streams_.pop_front(); | ||
| 83 | + | ||
| 84 | + auto sample_rate = connection_data[i]->sample_rate; | ||
| 85 | + auto samples = | ||
| 86 | + reinterpret_cast<const float *>(&connection_data[i]->data[0]); | ||
| 87 | + auto num_samples = connection_data[i]->expected_byte_size / sizeof(float); | ||
| 88 | + auto s = recognizer_.CreateStream(); | ||
| 89 | + s->AcceptWaveform(sample_rate, samples, num_samples); | ||
| 90 | + | ||
| 91 | + ss[i] = std::move(s); | ||
| 92 | + p_ss[i] = ss[i].get(); | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + lock.unlock(); | ||
| 96 | + | ||
| 97 | + // Note: DecodeStreams is thread-safe | ||
| 98 | + recognizer_.DecodeStreams(p_ss.data(), size); | ||
| 99 | + | ||
| 100 | + for (int32_t i = 0; i != size; ++i) { | ||
| 101 | + connection_hdl hdl = handles[i]; | ||
| 102 | + asio::post(server_->GetConnectionContext(), | ||
| 103 | + [this, hdl, text = ss[i]->GetResult().text]() { | ||
| 104 | + websocketpp::lib::error_code ec; | ||
| 105 | + server_->GetServer().send( | ||
| 106 | + hdl, text, websocketpp::frame::opcode::text, ec); | ||
| 107 | + if (ec) { | ||
| 108 | + server_->GetServer().get_alog().write( | ||
| 109 | + websocketpp::log::alevel::app, ec.message()); | ||
| 110 | + } | ||
| 111 | + }); | ||
| 112 | + } | ||
| 113 | +} | ||
| 114 | + | ||
| 115 | +void OfflineWebsocketServerConfig::Register(ParseOptions *po) { | ||
| 116 | + decoder_config.Register(po); | ||
| 117 | + po->Register("log-file", &log_file, | ||
| 118 | + "Path to the log file. Logs are " | ||
| 119 | + "appended to this file"); | ||
| 120 | +} | ||
| 121 | + | ||
| 122 | +void OfflineWebsocketServerConfig::Validate() const { | ||
| 123 | + decoder_config.Validate(); | ||
| 124 | +} | ||
| 125 | + | ||
| 126 | +OfflineWebsocketServer::OfflineWebsocketServer( | ||
| 127 | + asio::io_context &io_conn, // NOLINT | ||
| 128 | + asio::io_context &io_work, // NOLINT | ||
| 129 | + const OfflineWebsocketServerConfig &config) | ||
| 130 | + : io_conn_(io_conn), | ||
| 131 | + io_work_(io_work), | ||
| 132 | + config_(config), | ||
| 133 | + log_(config.log_file, std::ios::app), | ||
| 134 | + tee_(std::cout, log_), | ||
| 135 | + decoder_(this) { | ||
| 136 | + SetupLog(); | ||
| 137 | + | ||
| 138 | + server_.init_asio(&io_conn_); | ||
| 139 | + | ||
| 140 | + server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); | ||
| 141 | + | ||
| 142 | + server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); }); | ||
| 143 | + | ||
| 144 | + server_.set_message_handler( | ||
| 145 | + [this](connection_hdl hdl, server::message_ptr msg) { | ||
| 146 | + OnMessage(hdl, msg); | ||
| 147 | + }); | ||
| 148 | +} | ||
| 149 | + | ||
| 150 | +void OfflineWebsocketServer::SetupLog() { | ||
| 151 | + server_.clear_access_channels(websocketpp::log::alevel::all); | ||
| 152 | + server_.set_access_channels(websocketpp::log::alevel::connect); | ||
| 153 | + server_.set_access_channels(websocketpp::log::alevel::disconnect); | ||
| 154 | + | ||
| 155 | + // So that it also prints to std::cout and std::cerr | ||
| 156 | + server_.get_alog().set_ostream(&tee_); | ||
| 157 | + server_.get_elog().set_ostream(&tee_); | ||
| 158 | +} | ||
| 159 | + | ||
| 160 | +void OfflineWebsocketServer::OnOpen(connection_hdl hdl) { | ||
| 161 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 162 | + connections_.emplace(hdl, std::make_shared<ConnectionData>()); | ||
| 163 | + | ||
| 164 | + SHERPA_ONNX_LOGE("Number of active connections: %d", | ||
| 165 | + static_cast<int32_t>(connections_.size())); | ||
| 166 | +} | ||
| 167 | + | ||
| 168 | +void OfflineWebsocketServer::OnClose(connection_hdl hdl) { | ||
| 169 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 170 | + connections_.erase(hdl); | ||
| 171 | + | ||
| 172 | + SHERPA_ONNX_LOGE("Number of active connections: %d", | ||
| 173 | + static_cast<int32_t>(connections_.size())); | ||
| 174 | +} | ||
| 175 | + | ||
| 176 | +void OfflineWebsocketServer::OnMessage(connection_hdl hdl, | ||
| 177 | + server::message_ptr msg) { | ||
| 178 | + std::unique_lock<std::mutex> lock(mutex_); | ||
| 179 | + auto connection_data = connections_.find(hdl)->second; | ||
| 180 | + lock.unlock(); | ||
| 181 | + const std::string &payload = msg->get_payload(); | ||
| 182 | + | ||
| 183 | + switch (msg->get_opcode()) { | ||
| 184 | + case websocketpp::frame::opcode::text: | ||
| 185 | + if (payload == "Done") { | ||
| 186 | + // The client will not send any more data. We can close the | ||
| 187 | + // connection now. | ||
| 188 | + Close(hdl, websocketpp::close::status::normal, "Done"); | ||
| 189 | + } else { | ||
| 190 | + Close(hdl, websocketpp::close::status::normal, | ||
| 191 | + std::string("Invalid payload: ") + payload); | ||
| 192 | + } | ||
| 193 | + break; | ||
| 194 | + | ||
| 195 | + case websocketpp::frame::opcode::binary: { | ||
| 196 | + auto p = reinterpret_cast<const int8_t *>(payload.data()); | ||
| 197 | + | ||
| 198 | + if (connection_data->expected_byte_size == 0) { | ||
| 199 | + if (payload.size() < 8) { | ||
| 200 | + Close(hdl, websocketpp::close::status::normal, | ||
| 201 | + "Payload is too short"); | ||
| 202 | + break; | ||
| 203 | + } | ||
| 204 | + | ||
| 205 | + connection_data->sample_rate = *reinterpret_cast<const int32_t *>(p); | ||
| 206 | + | ||
| 207 | + connection_data->expected_byte_size = | ||
| 208 | + *reinterpret_cast<const int32_t *>(p + 4); | ||
| 209 | + | ||
| 210 | + int32_t max_byte_size_ = decoder_.GetConfig().max_utterance_length * | ||
| 211 | + connection_data->sample_rate * sizeof(float); | ||
| 212 | + if (connection_data->expected_byte_size > max_byte_size_) { | ||
| 213 | + float num_samples = | ||
| 214 | + connection_data->expected_byte_size / sizeof(float); | ||
| 215 | + | ||
| 216 | + float duration = num_samples / connection_data->sample_rate; | ||
| 217 | + | ||
| 218 | + std::ostringstream os; | ||
| 219 | + os << "Max utterance length is configured to " | ||
| 220 | + << decoder_.GetConfig().max_utterance_length | ||
| 221 | + << " seconds, received length is " << duration << " seconds. " | ||
| 222 | + << "Payload is too large!"; | ||
| 223 | + Close(hdl, websocketpp::close::status::message_too_big, os.str()); | ||
| 224 | + break; | ||
| 225 | + } | ||
| 226 | + | ||
| 227 | + connection_data->data.resize(connection_data->expected_byte_size); | ||
| 228 | + std::copy(payload.begin() + 8, payload.end(), | ||
| 229 | + connection_data->data.data()); | ||
| 230 | + connection_data->cur = payload.size() - 8; | ||
| 231 | + } else { | ||
| 232 | + std::copy(payload.begin(), payload.end(), | ||
| 233 | + connection_data->data.data() + connection_data->cur); | ||
| 234 | + connection_data->cur += payload.size(); | ||
| 235 | + } | ||
| 236 | + | ||
| 237 | + if (connection_data->expected_byte_size == connection_data->cur) { | ||
| 238 | + auto d = std::make_shared<ConnectionData>(std::move(*connection_data)); | ||
| 239 | + // Clear it so that we can handle the next audio file from the client. | ||
| 240 | + // The client can send multiple audio files for recognition without | ||
| 241 | + // the need to create another connection. | ||
| 242 | + connection_data->sample_rate = 0; | ||
| 243 | + connection_data->expected_byte_size = 0; | ||
| 244 | + connection_data->cur = 0; | ||
| 245 | + | ||
| 246 | + decoder_.Push(hdl, d); | ||
| 247 | + | ||
| 248 | + connection_data->Clear(); | ||
| 249 | + | ||
| 250 | + asio::post(io_work_, [this]() { decoder_.Decode(); }); | ||
| 251 | + } | ||
| 252 | + break; | ||
| 253 | + } | ||
| 254 | + | ||
| 255 | + default: | ||
| 256 | + // Unexpected message, ignore it | ||
| 257 | + break; | ||
| 258 | + } | ||
| 259 | +} | ||
| 260 | + | ||
| 261 | +void OfflineWebsocketServer::Close(connection_hdl hdl, | ||
| 262 | + websocketpp::close::status::value code, | ||
| 263 | + const std::string &reason) { | ||
| 264 | + auto con = server_.get_con_from_hdl(hdl); | ||
| 265 | + | ||
| 266 | + std::ostringstream os; | ||
| 267 | + os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason | ||
| 268 | + << "\n"; | ||
| 269 | + | ||
| 270 | + websocketpp::lib::error_code ec; | ||
| 271 | + server_.close(hdl, code, reason, ec); | ||
| 272 | + if (ec) { | ||
| 273 | + os << "Failed to close" << con->get_remote_endpoint() << ". " | ||
| 274 | + << ec.message() << "\n"; | ||
| 275 | + } | ||
| 276 | + server_.get_alog().write(websocketpp::log::alevel::app, os.str()); | ||
| 277 | +} | ||
| 278 | + | ||
| 279 | +void OfflineWebsocketServer::Run(uint16_t port) { | ||
| 280 | + server_.set_reuse_addr(true); | ||
| 281 | + server_.listen(asio::ip::tcp::v4(), port); | ||
| 282 | + server_.start_accept(); | ||
| 283 | +} | ||
| 284 | + | ||
| 285 | +} // namespace sherpa_onnx |
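Decode() drains at most --max-batch-size entries from the shared queue while holding the lock, releases the lock, and only then runs the (thread-safe) batched decoding, so other threads can keep pushing while a batch is in flight. Below is a rough Python sketch of the same queue-and-batch pattern; all names are illustrative, and send_result stands in for posting the text back on the connection:

    import threading
    from collections import deque


    class BatchedDecoder:
        def __init__(self, decode_batch, send_result, max_batch_size=5):
            self.decode_batch = decode_batch  # thread-safe batched decode function
            self.send_result = send_result    # posts a result back to a connection
            self.max_batch_size = max_batch_size
            self.lock = threading.Lock()
            self.pending = deque()  # (handle, data) pairs, like streams_ above

        def push(self, handle, data):
            with self.lock:
                self.pending.append((handle, data))

        def decode(self):
            # Take up to max_batch_size items under the lock, then release it
            # before the (potentially slow) batched decoding.
            with self.lock:
                n = min(len(self.pending), self.max_batch_size)
                batch = [self.pending.popleft() for _ in range(n)]
            if not batch:
                return
            handles, items = zip(*batch)
            for handle, result in zip(handles, self.decode_batch(items)):
                self.send_result(handle, result)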
| 1 | +// sherpa-onnx/csrc/offline-websocket-server-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <deque> | ||
| 9 | +#include <fstream> | ||
| 10 | +#include <map> | ||
| 11 | +#include <memory> | ||
| 12 | +#include <string> | ||
| 13 | +#include <utility> | ||
| 14 | +#include <vector> | ||
| 15 | + | ||
| 16 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 17 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 18 | +#include "sherpa-onnx/csrc/tee-stream.h" | ||
| 19 | +#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS | ||
| 20 | +#include "websocketpp/server.hpp" | ||
| 21 | + | ||
| 22 | +using server = websocketpp::server<websocketpp::config::asio>; | ||
| 23 | +using connection_hdl = websocketpp::connection_hdl; | ||
| 24 | + | ||
| 25 | +namespace sherpa_onnx { | ||
| 26 | + | ||
| 27 | +/** Communication protocol | ||
| 28 | + * | ||
| 29 | + * The client sends a byte stream to the server. The first 4 bytes, in little | ||
| 30 | + * endian, contain the sample rate of the audio the client will send. | ||
| 31 | + * The next 4 bytes, in little endian, contain the total number of bytes of | ||
| 32 | + * samples the client will send. The remaining bytes are the audio samples. | ||
| 33 | + * Each sample is a float occupying 4 bytes, normalized to the range | ||
| 34 | + * [-1, 1]. | ||
| 35 | + * | ||
| 36 | + * The byte stream can be broken into arbitrary number of messages. | ||
| 37 | + * We require that the first message has to be at least 8 bytes so that | ||
| 38 | + * we can get `sample_rate` and `expected_byte_size` from the first message. | ||
| 39 | + */ | ||
| 40 | +struct ConnectionData { | ||
| 41 | + // Sample rate of the audio samples sent by the client | ||
| 42 | + int32_t sample_rate; | ||
| 43 | + | ||
| 44 | + // Number of expected bytes sent from the client | ||
| 45 | + int32_t expected_byte_size = 0; | ||
| 46 | + | ||
| 47 | + // Number of bytes received so far | ||
| 48 | + int32_t cur = 0; | ||
| 49 | + | ||
| 50 | + // It saves the received samples from the client. | ||
| 51 | + // We will **reinterpret_cast** it to float. | ||
| 52 | + // We expect that data.size() == expected_byte_size | ||
| 53 | + std::vector<int8_t> data; | ||
| 54 | + | ||
| 55 | + void Clear() { | ||
| 56 | + sample_rate = 0; | ||
| 57 | + expected_byte_size = 0; | ||
| 58 | + cur = 0; | ||
| 59 | + data.clear(); | ||
| 60 | + } | ||
| 61 | +}; | ||
| 62 | + | ||
| 63 | +using ConnectionDataPtr = std::shared_ptr<ConnectionData>; | ||
| 64 | + | ||
| 65 | +struct OfflineWebsocketDecoderConfig { | ||
| 66 | + OfflineRecognizerConfig recognizer_config; | ||
| 67 | + | ||
| 68 | + int32_t max_batch_size = 5; | ||
| 69 | + | ||
| 70 | + float max_utterance_length = 300; // seconds | ||
| 71 | + | ||
| 72 | + void Register(ParseOptions *po); | ||
| 73 | + void Validate() const; | ||
| 74 | +}; | ||
| 75 | + | ||
| 76 | +class OfflineWebsocketServer; | ||
| 77 | + | ||
| 78 | +class OfflineWebsocketDecoder { | ||
| 79 | + public: | ||
| 80 | + /** | ||
| 82 | + * @param server **Borrowed** from outside. | ||
| 83 | + */ | ||
| 84 | + explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server); | ||
| 85 | + | ||
| 86 | + /** Insert received data to the queue for decoding. | ||
| 87 | + * | ||
| 88 | + * @param hdl A handle to the connection. We can use it to send the result | ||
| 89 | + * back to the client once it finishes decoding. | ||
| 90 | + * @param d The received data | ||
| 91 | + */ | ||
| 92 | + void Push(connection_hdl hdl, ConnectionDataPtr d); | ||
| 93 | + | ||
| 94 | + /** It is called by one of the work threads. | ||
| 95 | + */ | ||
| 96 | + void Decode(); | ||
| 97 | + | ||
| 98 | + const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; } | ||
| 99 | + | ||
| 100 | + private: | ||
| 101 | + OfflineWebsocketDecoderConfig config_; | ||
| 102 | + | ||
| 103 | + /** When we have received all the data from the client, we put it into | ||
| 104 | + * this queue; the worker threads will get items from this queue for | ||
| 105 | + * decoding. | ||
| 106 | + * | ||
| 107 | + * Number of items to take from this queue is determined by | ||
| 108 | + * `--max-batch-size`. If there are not enough items in the queue, we won't | ||
| 109 | + * wait and take whatever we have for decoding. | ||
| 110 | + */ | ||
| 111 | + std::mutex mutex_; | ||
| 112 | + std::deque<std::pair<connection_hdl, ConnectionDataPtr>> streams_; | ||
| 113 | + | ||
| 114 | + OfflineWebsocketServer *server_; // Not owned | ||
| 115 | + OfflineRecognizer recognizer_; | ||
| 116 | +}; | ||
| 117 | + | ||
| 118 | +struct OfflineWebsocketServerConfig { | ||
| 119 | + OfflineWebsocketDecoderConfig decoder_config; | ||
| 120 | + std::string log_file = "./log.txt"; | ||
| 121 | + | ||
| 122 | + void Register(ParseOptions *po); | ||
| 123 | + void Validate() const; | ||
| 124 | +}; | ||
| 125 | + | ||
| 126 | +class OfflineWebsocketServer { | ||
| 127 | + public: | ||
| 128 | + OfflineWebsocketServer(asio::io_context &io_conn, // NOLINT | ||
| 129 | + asio::io_context &io_work, // NOLINT | ||
| 130 | + const OfflineWebsocketServerConfig &config); | ||
| 131 | + | ||
| 132 | + asio::io_context &GetConnectionContext() { return io_conn_; } | ||
| 133 | + server &GetServer() { return server_; } | ||
| 134 | + | ||
| 135 | + void Run(uint16_t port); | ||
| 136 | + | ||
| 137 | + const OfflineWebsocketServerConfig &GetConfig() const { return config_; } | ||
| 138 | + | ||
| 139 | + private: | ||
| 140 | + void SetupLog(); | ||
| 141 | + | ||
| 142 | + // When a websocket client is connected, it will invoke this method | ||
| 143 | + // (Not for HTTP) | ||
| 144 | + void OnOpen(connection_hdl hdl); | ||
| 145 | + | ||
| 146 | + // When a websocket client is disconnected, it will invoke this method | ||
| 147 | + void OnClose(connection_hdl hdl); | ||
| 148 | + | ||
| 149 | + // When a message is received from a websocket client, this method will | ||
| 150 | + // be invoked. | ||
| 151 | + // | ||
| 152 | + // The protocol between the client and the server is as follows: | ||
| 153 | + // | ||
| 154 | + // (1) The client connects to the server | ||
| 155 | + // (2) The client starts to send a binary byte stream to the server. | ||
| 156 | + // The byte stream can be broken into multiple messages or it can | ||
| 157 | + // be put into a single message. | ||
| 158 | + // The first message has to contain at least 8 bytes. The first | ||
| 159 | + // 4 bytes in little endian contain an int32_t indicating the | ||
| 160 | + // sampling rate. The next 4 bytes in little endian contain an int32_t | ||
| 161 | + // indicating the total number of bytes of samples the client will send. | ||
| 162 | + // We assume each sample is a float containing 4 bytes and has been | ||
| 163 | + // normalized to the range [-1, 1]. | ||
| 164 | + // (3) When the server receives all the samples from the client, it will | ||
| 165 | + // start to decode them. Once decoded, the server sends a text message | ||
| 166 | + // to the client containing the decoded results | ||
| 167 | + // (4) After receiving the decoded results from the server, if the client has | ||
| 168 | + // another audio file to send, it repeats (2) and (3) | ||
| 169 | + // (5) If the client has no more audio files to decode, the client sends a | ||
| 170 | + // text message containing "Done" to the server and closes the connection | ||
| 171 | + // (6) The server receives the text message "Done" and closes the connection | ||
| 172 | + // | ||
| 173 | + // Note: | ||
| 174 | + // (a) All models in icefall use features extracted from audio samples | ||
| 175 | + // normalized to the range [-1, 1]. Please send normalized audio samples | ||
| 176 | + // if you use models from icefall. | ||
| 177 | + // (b) Only sound files with a single channel are supported | ||
| 178 | + // (c) Only audio samples are sent. For instance, if we want to decode | ||
| 179 | + // a WAVE file, the RIFF header of the WAVE is not sent. | ||
| 180 | + void OnMessage(connection_hdl hdl, server::message_ptr msg); | ||
| 181 | + | ||
| 182 | + // Close a websocket connection with given code and reason | ||
| 183 | + void Close(connection_hdl hdl, websocketpp::close::status::value code, | ||
| 184 | + const std::string &reason); | ||
| 185 | + | ||
| 186 | + private: | ||
| 187 | + asio::io_context &io_conn_; | ||
| 188 | + asio::io_context &io_work_; | ||
| 189 | + server server_; | ||
| 190 | + | ||
| 191 | + std::map<connection_hdl, ConnectionDataPtr, std::owner_less<connection_hdl>> | ||
| 192 | + connections_; | ||
| 193 | + std::mutex mutex_; | ||
| 194 | + | ||
| 195 | + OfflineWebsocketServerConfig config_; | ||
| 196 | + | ||
| 197 | + std::ofstream log_; | ||
| 198 | + TeeStream tee_; | ||
| 199 | + | ||
| 200 | + OfflineWebsocketDecoder decoder_; | ||
| 201 | +}; | ||
| 202 | + | ||
| 203 | +} // namespace sherpa_onnx | ||
| 204 | + | ||
| 205 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_ |
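OnMessage() implements a small per-connection state machine: the first binary message of an utterance must carry the 8-byte header, later messages append raw bytes, and once expected_byte_size bytes have arrived the utterance is handed to the decoder and the state is reset so the same connection can send the next file. A minimal Python sketch of that accumulation logic, mirroring ConnectionData (class and method names are illustrative):

    import struct


    class ConnectionState:
        def __init__(self):
            self.sample_rate = 0
            self.expected_byte_size = 0
            self.buf = bytearray()

        def on_binary_message(self, payload: bytes):
            """Return (sample_rate, sample_bytes) once an utterance is
            complete; return None while more data is expected."""
            if self.expected_byte_size == 0:
                # First message of an utterance: parse the 8-byte header
                assert len(payload) >= 8, len(payload)
                self.sample_rate, self.expected_byte_size = struct.unpack(
                    "<ii", payload[:8]
                )
                payload = payload[8:]
            self.buf.extend(payload)
            if len(self.buf) == self.expected_byte_size:
                result = (self.sample_rate, bytes(self.buf))
                self.__init__()  # reset for the next utterance on this connection
                return result
            return None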
sherpa-onnx/csrc/offline-websocket-server.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-websocket-server.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "asio.hpp" | ||
| 6 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 7 | +#include "sherpa-onnx/csrc/offline-websocket-server-impl.h" | ||
| 8 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 9 | + | ||
| 10 | +static constexpr const char *kUsageMessage = R"( | ||
| 11 | +Automatic speech recognition with sherpa-onnx using websocket. | ||
| 12 | + | ||
| 13 | +Usage: | ||
| 14 | + | ||
| 15 | +./bin/sherpa-onnx-offline-websocket-server --help | ||
| 16 | + | ||
| 17 | +(1) For transducer models | ||
| 18 | + | ||
| 19 | +./bin/sherpa-onnx-offline-websocket-server \ | ||
| 20 | + --port=6006 \ | ||
| 21 | + --num-work-threads=5 \ | ||
| 22 | + --tokens=/path/to/tokens.txt \ | ||
| 23 | + --encoder=/path/to/encoder.onnx \ | ||
| 24 | + --decoder=/path/to/decoder.onnx \ | ||
| 25 | + --joiner=/path/to/joiner.onnx \ | ||
| 26 | + --log-file=./log.txt \ | ||
| 27 | + --max-batch-size=5 | ||
| 28 | + | ||
| 29 | +(2) For Paraformer | ||
| 30 | + | ||
| 31 | +./bin/sherpa-onnx-offline-websocket-server \ | ||
| 32 | + --port=6006 \ | ||
| 33 | + --num-work-threads=5 \ | ||
| 34 | + --tokens=/path/to/tokens.txt \ | ||
| 35 | + --paraformer=/path/to/model.onnx \ | ||
| 36 | + --log-file=./log.txt \ | ||
| 37 | + --max-batch-size=5 | ||
| 38 | + | ||
| 39 | +Please refer to | ||
| 40 | +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 41 | +for a list of pre-trained models to download. | ||
| 42 | +)"; | ||
| 43 | + | ||
| 44 | +int32_t main(int32_t argc, char *argv[]) { | ||
| 45 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 46 | + | ||
| 47 | + sherpa_onnx::OfflineWebsocketServerConfig config; | ||
| 48 | + | ||
| 49 | + // the server will listen on this port | ||
| 50 | + int32_t port = 6006; | ||
| 51 | + | ||
| 52 | + // size of the thread pool for handling network connections | ||
| 53 | + int32_t num_io_threads = 1; | ||
| 54 | + | ||
| 55 | + // size of the thread pool for neural network computation and decoding | ||
| 56 | + int32_t num_work_threads = 3; | ||
| 57 | + | ||
| 58 | + po.Register("num-io-threads", &num_io_threads, | ||
| 59 | + "Thread pool size for network connections."); | ||
| 60 | + | ||
| 61 | + po.Register("num-work-threads", &num_work_threads, | ||
| 62 | + "Thread pool size for for neural network " | ||
| 63 | + "computation and decoding."); | ||
| 64 | + | ||
| 65 | + po.Register("port", &port, "The port on which the server will listen."); | ||
| 66 | + | ||
| 67 | + config.Register(&po); | ||
| 68 | + | ||
| 69 | + if (argc == 1) { | ||
| 70 | + po.PrintUsage(); | ||
| 71 | + exit(EXIT_FAILURE); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + po.Read(argc, argv); | ||
| 75 | + | ||
| 76 | + if (po.NumArgs() != 0) { | ||
| 77 | + SHERPA_ONNX_LOGE("Unrecognized positional arguments!"); | ||
| 78 | + po.PrintUsage(); | ||
| 79 | + exit(EXIT_FAILURE); | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + config.Validate(); | ||
| 83 | + | ||
| 84 | + asio::io_context io_conn; // for network connections | ||
| 85 | + asio::io_context io_work; // for neural network and decoding | ||
| 86 | + | ||
| 87 | + sherpa_onnx::OfflineWebsocketServer server(io_conn, io_work, config); | ||
| 88 | + server.Run(port); | ||
| 89 | + | ||
| 90 | + SHERPA_ONNX_LOGE("Started!"); | ||
| 91 | + SHERPA_ONNX_LOGE("Listening on: %d", port); | ||
| 92 | + SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads); | ||
| 93 | + | ||
| 94 | + // Keep io_work.run() from returning when there is no pending work | ||
| 95 | + auto work_guard = asio::make_work_guard(io_work); | ||
| 96 | + | ||
| 97 | + std::vector<std::thread> io_threads; | ||
| 98 | + | ||
| 99 | + // decrement since the main thread is also used for network communications | ||
| 100 | + for (int32_t i = 0; i < num_io_threads - 1; ++i) { | ||
| 101 | + io_threads.emplace_back([&io_conn]() { io_conn.run(); }); | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + std::vector<std::thread> work_threads; | ||
| 105 | + for (int32_t i = 0; i < num_work_threads; ++i) { | ||
| 106 | + work_threads.emplace_back([&io_work]() { io_work.run(); }); | ||
| 107 | + } | ||
| 108 | + | ||
| 109 | + io_conn.run(); | ||
| 110 | + | ||
| 111 | + for (auto &t : io_threads) { | ||
| 112 | + t.join(); | ||
| 113 | + } | ||
| 114 | + | ||
| 115 | + for (auto &t : work_threads) { | ||
| 116 | + t.join(); | ||
| 117 | + } | ||
| 118 | + | ||
| 119 | + return 0; | ||
| 120 | +} |
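main() runs two io_contexts: io_conn for network traffic and io_work for neural-network computation and decoding, with a work guard keeping io_work.run() alive while its queue is empty. Only num_io_threads - 1 extra network threads are spawned because the main thread also calls io_conn.run(). A rough Python sketch of the same thread layout; run_network_loop and run_decode_loop are hypothetical stand-ins for io_conn.run() and io_work.run():

    import threading


    def serve(run_network_loop, run_decode_loop, num_io_threads=1, num_work_threads=3):
        # The main thread also serves the network, hence num_io_threads - 1
        io_threads = [
            threading.Thread(target=run_network_loop)
            for _ in range(num_io_threads - 1)
        ]
        work_threads = [
            threading.Thread(target=run_decode_loop)
            for _ in range(num_work_threads)
        ]
        for t in io_threads + work_threads:
            t.start()
        run_network_loop()  # blocks until the server shuts down
        for t in io_threads + work_threads:
            t.join()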
| @@ -76,6 +76,7 @@ int32_t main(int32_t argc, char *argv[]) { | @@ -76,6 +76,7 @@ int32_t main(int32_t argc, char *argv[]) { | ||
| 76 | sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config); | 76 | sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config); |
| 77 | server.Run(port); | 77 | server.Run(port); |
| 78 | 78 | ||
| 79 | + SHERPA_ONNX_LOGE("Started!"); | ||
| 79 | SHERPA_ONNX_LOGE("Listening on: %d", port); | 80 | SHERPA_ONNX_LOGE("Listening on: %d", port); |
| 80 | SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads); | 81 | SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads); |
| 81 | 82 |