Fangjun Kuang
Committed by GitHub

add offline websocket server/client (#98)

 [flake8]
 show-source=true
 statistics=true
-max-line-length = 80
+max-line-length = 120

 exclude =
   .git,
@@ -30,9 +30,12 @@ ls -lh

 ls -lh $repo

-python3 ./python-api-examples/decode-file.py \
+python3 ./python-api-examples/online-decode-files.py \
   --tokens=$repo/tokens.txt \
   --encoder=$repo/encoder-epoch-99-avg-1.onnx \
   --decoder=$repo/decoder-epoch-99-avg-1.onnx \
   --joiner=$repo/joiner-epoch-99-avg-1.onnx \
-  --wave-filename=$repo/test_wavs/4.wav
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/3.wav
@@ -45,3 +45,4 @@ paraformer-onnxruntime-python-example
 run-sherpa-onnx-offline-paraformer.sh
 run-sherpa-onnx-offline-transducer.sh
 sherpa-onnx-paraformer-zh-2023-03-28
+run-offline-websocket-server-paraformer.sh
  1 +#!/usr/bin/env python3
  2 +#
  3 +# Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +"""
  6 +A websocket client for sherpa-onnx-offline-websocket-server
  7 +
  8 +This file shows how to transcribe multiple
  9 +files in parallel. We create a separate connection for transcribing each file.
  10 +
  11 +Usage:
  12 + ./offline-websocket-client-decode-files-parallel.py \
  13 + --server-addr localhost \
  14 + --server-port 6006 \
  15 + /path/to/foo.wav \
  16 + /path/to/bar.wav \
  17 + /path/to/16kHz.wav \
  18 + /path/to/8kHz.wav
  19 +
  20 +(Note: You have to first start the server before starting the client)
  21 +
  22 +You can find the server at
  23 +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc
  24 +
  25 +Note: The server is implemented in C++.
  26 +"""
  27 +
  28 +import argparse
  29 +import asyncio
  30 +import logging
  31 +import wave
  32 +from typing import Tuple
  33 +
  34 +try:
  35 + import websockets
  36 +except ImportError:
  37 + print("please run:")
  38 + print("")
  39 + print(" pip install websockets")
  40 + print("")
  41 + print("before you run this script")
  42 + print("")
  43 +
  44 +import numpy as np
  45 +
  46 +
  47 +def get_args():
  48 + parser = argparse.ArgumentParser(
  49 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  50 + )
  51 +
  52 + parser.add_argument(
  53 + "--server-addr",
  54 + type=str,
  55 + default="localhost",
  56 + help="Address of the server",
  57 + )
  58 +
  59 + parser.add_argument(
  60 + "--server-port",
  61 + type=int,
  62 + default=6006,
  63 + help="Port of the server",
  64 + )
  65 +
  66 + parser.add_argument(
  67 + "sound_files",
  68 + type=str,
  69 + nargs="+",
  70 + help="The input sound file(s) to decode. Each file must be of WAVE"
  71 + "format with a single channel, and each sample has 16-bit, "
  72 + "i.e., int16_t. "
  73 + "The sample rate of the file can be arbitrary and does not need to "
  74 + "be 16 kHz",
  75 + )
  76 +
  77 + return parser.parse_args()
  78 +
  79 +
  80 +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
  81 + """
  82 + Args:
  83 + wave_filename:
  84 + Path to a wave file. It should be single channel and each sample should
  85 + be 16-bit. Its sample rate does not need to be 16kHz.
  86 + Returns:
  87 + Return a tuple containing:
  88 + - A 1-D array of dtype np.float32 containing the samples, which are
  89 + normalized to the range [-1, 1].
  90 + - sample rate of the wave file
  91 + """
  92 +
  93 + with wave.open(wave_filename) as f:
  94 + assert f.getnchannels() == 1, f.getnchannels()
  95 + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  96 + num_samples = f.getnframes()
  97 + samples = f.readframes(num_samples)
  98 + samples_int16 = np.frombuffer(samples, dtype=np.int16)
  99 + samples_float32 = samples_int16.astype(np.float32)
  100 +
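    + # int16 covers [-32768, 32767]; dividing by 32768 maps it to [-1, 1)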
  101 + samples_float32 = samples_float32 / 32768
  102 + return samples_float32, f.getframerate()
  103 +
  104 +
  105 +async def run(
  106 + server_addr: str,
  107 + server_port: int,
  108 + wave_filename: str,
  109 +):
  110 + async with websockets.connect(
  111 + f"ws://{server_addr}:{server_port}"
  112 + ) as websocket: # noqa
  113 + logging.info(f"Sending {wave_filename}")
  114 + samples, sample_rate = read_wave(wave_filename)
  115 + assert isinstance(sample_rate, int)
  116 + assert samples.dtype == np.float32, samples.dtype
  117 + assert samples.ndim == 1, samples.ndim
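    + # Wire format expected by the server: 4-byte little-endian sample rate,
    + # then 4-byte little-endian payload size in bytes, then the float32
    + # samples normalized to [-1, 1].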
  118 + buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes
  119 + buf += (samples.size * 4).to_bytes(4, byteorder="little")
  120 + buf += samples.tobytes()
  121 +
  122 + await websocket.send(buf)
  123 +
  124 + decoding_results = await websocket.recv()
  125 + logging.info(f"{wave_filename}\n{decoding_results}")
  126 +
  127 + # to signal that the client has sent all the data
  128 + await websocket.send("Done")
  129 +
  130 +
  131 +async def main():
  132 + args = get_args()
  133 + logging.info(vars(args))
  134 +
  135 + server_addr = args.server_addr
  136 + server_port = args.server_port
  137 + sound_files = args.sound_files
  138 +
  139 + all_tasks = []
  140 + for wave_filename in sound_files:
  141 + task = asyncio.create_task(
  142 + run(
  143 + server_addr=server_addr,
  144 + server_port=server_port,
  145 + wave_filename=wave_filename,
  146 + )
  147 + )
  148 + all_tasks.append(task)
  149 +
  150 + await asyncio.gather(*all_tasks)
  151 +
  152 +
  153 +if __name__ == "__main__":
  154 + formatter = (
  155 + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa
  156 + )
  157 + logging.basicConfig(format=formatter, level=logging.INFO)
  158 + asyncio.run(main())
  1 +#!/usr/bin/env python3
  2 +#
  3 +# Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +"""
  6 +A websocket client for sherpa-onnx-offline-websocket-server
  7 +
  8 +This file shows how to use a single connection to transcribe multiple
  9 +files sequentially.
  10 +
  11 +Usage:
  12 + ./offline-websocket-client-decode-files-sequential.py \
  13 + --server-addr localhost \
  14 + --server-port 6006 \
  15 + /path/to/foo.wav \
  16 + /path/to/bar.wav \
  17 + /path/to/16kHz.wav \
  18 + /path/to/8kHz.wav
  19 +
  20 +(Note: You have to first start the server before starting the client)
  21 +
  22 +You can find the server at
  23 +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc
  24 +
  25 +Note: The server is implemented in C++.
  26 +"""
  27 +
  28 +import argparse
  29 +import asyncio
  30 +import logging
  31 +import wave
  32 +from typing import List, Tuple
  33 +
  34 +try:
  35 + import websockets
  36 +except ImportError:
  37 + print("please run:")
  38 + print("")
  39 + print(" pip install websockets")
  40 + print("")
  41 + print("before you run this script")
  42 + print("")
  43 +
  44 +import numpy as np
  45 +
  46 +
  47 +def get_args():
  48 + parser = argparse.ArgumentParser(
  49 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  50 + )
  51 +
  52 + parser.add_argument(
  53 + "--server-addr",
  54 + type=str,
  55 + default="localhost",
  56 + help="Address of the server",
  57 + )
  58 +
  59 + parser.add_argument(
  60 + "--server-port",
  61 + type=int,
  62 + default=6006,
  63 + help="Port of the server",
  64 + )
  65 +
  66 + parser.add_argument(
  67 + "sound_files",
  68 + type=str,
  69 + nargs="+",
  70 + help="The input sound file(s) to decode. Each file must be of WAVE"
  71 + "format with a single channel, and each sample has 16-bit, "
  72 + "i.e., int16_t. "
  73 + "The sample rate of the file can be arbitrary and does not need to "
  74 + "be 16 kHz",
  75 + )
  76 +
  77 + return parser.parse_args()
  78 +
  79 +
  80 +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
  81 + """
  82 + Args:
  83 + wave_filename:
  84 + Path to a wave file. It should be single channel and each sample should
  85 + be 16-bit. Its sample rate does not need to be 16kHz.
  86 + Returns:
  87 + Return a tuple containing:
  88 + - A 1-D array of dtype np.float32 containing the samples, which are
  89 + normalized to the range [-1, 1].
  90 + - sample rate of the wave file
  91 + """
  92 +
  93 + with wave.open(wave_filename) as f:
  94 + assert f.getnchannels() == 1, f.getnchannels()
  95 + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  96 + num_samples = f.getnframes()
  97 + samples = f.readframes(num_samples)
  98 + samples_int16 = np.frombuffer(samples, dtype=np.int16)
  99 + samples_float32 = samples_int16.astype(np.float32)
  100 +
  101 + samples_float32 = samples_float32 / 32768
  102 + return samples_float32, f.getframerate()
  103 +
  104 +
  105 +async def run(
  106 + server_addr: str,
  107 + server_port: int,
  108 + sound_files: List[str],
  109 +):
  110 + async with websockets.connect(
  111 + f"ws://{server_addr}:{server_port}"
  112 + ) as websocket: # noqa
  113 + for wave_filename in sound_files:
  114 + logging.info(f"Sending {wave_filename}")
  115 + samples, sample_rate = read_wave(wave_filename)
  116 + assert isinstance(sample_rate, int)
  117 + assert samples.dtype == np.float32, samples.dtype
  118 + assert samples.ndim == 1, samples.ndim
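    + # Same wire format as in the parallel client: sample rate, payload
    + # size in bytes, then float32 samples (all little endian).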
  119 + buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes
  120 + buf += (samples.size * 4).to_bytes(4, byteorder="little")
  121 + buf += samples.tobytes()
  122 +
  123 + await websocket.send(buf)
  124 +
  125 + decoding_results = await websocket.recv()
  126 + print(decoding_results)
  127 +
  128 + # to signal that the client has sent all the data
  129 + await websocket.send("Done")
  130 +
  131 +
  132 +async def main():
  133 + args = get_args()
  134 + logging.info(vars(args))
  135 +
  136 + server_addr = args.server_addr
  137 + server_port = args.server_port
  138 + sound_files = args.sound_files
  139 +
  140 + await run(
  141 + server_addr=server_addr,
  142 + server_port=server_port,
  143 + sound_files=sound_files,
  144 + )
  145 +
  146 +
  147 +if __name__ == "__main__":
  148 + formatter = (
  149 + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa
  150 + )
  151 + logging.basicConfig(format=formatter, level=logging.INFO)
  152 + asyncio.run(main())
 #!/usr/bin/env python3

 """
-This file demonstrates how to use sherpa-onnx Python API to recognize
-a single file.
+This file demonstrates how to use the sherpa-onnx Python API to transcribe
+file(s) with a streaming model.
+
+Usage:
+    ./online-decode-files.py \
+      /path/to/foo.wav \
+      /path/to/bar.wav \
+      /path/to/16kHz.wav \
+      /path/to/8kHz.wav

 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/index.html
@@ -13,17 +20,12 @@ import argparse
 import time
 import wave
 from pathlib import Path
+from typing import Tuple

 import numpy as np
 import sherpa_onnx


-def assert_file_exists(filename: str):
-    assert Path(
-        filename
-    ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
-
-
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -68,26 +70,58 @@ def get_args():
     )

     parser.add_argument(
-        "--wave-filename",
+        "sound_files",
         type=str,
-        help="""Path to the wave filename.
-        Should have a single channel with 16-bit samples.
-        It does not need to be 16kHz. It can have any sampling rate.
-        """,
+        nargs="+",
+        help="The input sound file(s) to decode. Each file must be of WAVE "
+        "format with a single channel, and each sample should be 16-bit, "
+        "i.e., int16_t. "
+        "The sample rate of the file can be arbitrary and does not need to "
+        "be 16 kHz",
     )

     return parser.parse_args()


+def assert_file_exists(filename: str):
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    )
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and each sample should
+        be 16-bit. Its sample rate does not need to be 16kHz.
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples, which are
+         normalized to the range [-1, 1].
+       - sample rate of the wave file
+    """
+
+    with wave.open(wave_filename) as f:
+        assert f.getnchannels() == 1, f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_float32 = samples_int16.astype(np.float32)
+
+        samples_float32 = samples_float32 / 32768
+        return samples_float32, f.getframerate()
+
+
 def main():
     args = get_args()
     assert_file_exists(args.encoder)
     assert_file_exists(args.decoder)
     assert_file_exists(args.joiner)
     assert_file_exists(args.tokens)
-    if not Path(args.wave_filename).is_file():
-        print(f"{args.wave_filename} does not exist!")
-        return

     recognizer = sherpa_onnx.OnlineRecognizer(
         tokens=args.tokens,
@@ -99,42 +133,44 @@ def main():
         feature_dim=80,
         decoding_method=args.decoding_method,
     )
-    with wave.open(args.wave_filename) as f:
-        # If the wave file has a different sampling rate from the one
-        # expected by the model (16 kHz in our case), we will do
-        # resampling inside sherpa-onnx
-        wave_file_sample_rate = f.getframerate()
-
-        assert f.getnchannels() == 1, f.getnchannels()
-        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
-        num_samples = f.getnframes()
-        samples = f.readframes(num_samples)
-        samples_int16 = np.frombuffer(samples, dtype=np.int16)
-        samples_float32 = samples_int16.astype(np.float32)
-
-        samples_float32 = samples_float32 / 32768

-    duration = len(samples_float32) / wave_file_sample_rate
-
-    start_time = time.time()
     print("Started!")
+    start_time = time.time()

-    stream = recognizer.create_stream()
-
-    stream.accept_waveform(wave_file_sample_rate, samples_float32)
-
-    tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32)
-    stream.accept_waveform(wave_file_sample_rate, tail_paddings)
-
-    stream.input_finished()
-
-    while recognizer.is_ready(stream):
-        recognizer.decode_stream(stream)
+    streams = []
+    total_duration = 0
+    for wave_filename in args.sound_files:
+        assert_file_exists(wave_filename)
+        samples, sample_rate = read_wave(wave_filename)
+        duration = len(samples) / sample_rate
+        total_duration += duration
+
+        s = recognizer.create_stream()
+        s.accept_waveform(sample_rate, samples)
+
+        tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
+        s.accept_waveform(sample_rate, tail_paddings)
+
+        s.input_finished()
+
+        streams.append(s)
+
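+    # Decode the streams jointly: repeatedly collect the streams that are
+    # still ready (i.e., have unprocessed frames) and decode them as a batch.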
+    while True:
+        ready_list = []
+        for s in streams:
+            if recognizer.is_ready(s):
+                ready_list.append(s)
+        if len(ready_list) == 0:
+            break
+        recognizer.decode_streams(ready_list)
+    results = [recognizer.get_result(s) for s in streams]
+    end_time = time.time()
+    print("Done!")

-    print(recognizer.get_result(stream))
+    for wave_filename, result in zip(args.sound_files, results):
+        print(f"{wave_filename}\n{result}")
+        print("-" * 10)

-    print("Done!")
-    end_time = time.time()
     elapsed_seconds = end_time - start_time
-    rtf = elapsed_seconds / duration
+    rtf = elapsed_seconds / total_duration
     print(f"num_threads: {args.num_threads}")
@@ -27,7 +27,6 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
 import argparse
 import asyncio
 import logging
-import time
 import wave

 try:
@@ -24,13 +24,12 @@ https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
 import argparse
 import asyncio
 import sys
-import time

 import numpy as np

 try:
     import sounddevice as sd
-except ImportError as e:
+except ImportError:
     print("Please install sounddevice first. You can use")
     print()
     print("  pip install sounddevice")
@@ -134,7 +133,7 @@ async def run(
         await websocket.send(indata.tobytes())

     decoding_results = await receive_task
-    print("\nFinal result is:\n{decoding_results}")
+    print(f"\nFinal result is:\n{decoding_results}")


 async def main():
@@ -13,7 +13,7 @@ from pathlib import Path

 try:
     import sounddevice as sd
-except ImportError as e:
+except ImportError:
     print("Please install sounddevice first. You can use")
     print()
     print("  pip install sounddevice")
@@ -25,9 +25,11 @@ import sherpa_onnx


 def assert_file_exists(filename: str):
-    assert Path(
-        filename
-    ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    )


 def get_args():
@@ -12,7 +12,7 @@ from pathlib import Path

 try:
     import sounddevice as sd
-except ImportError as e:
+except ImportError:
     print("Please install sounddevice first. You can use")
     print()
     print("  pip install sounddevice")
@@ -24,9 +24,11 @@ import sherpa_onnx


 def assert_file_exists(filename: str):
-    assert Path(
-        filename
-    ).is_file(), f"{filename} does not exist!\nPlease refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    assert Path(filename).is_file(), (
+        f"{filename} does not exist!\n"
+        "Please refer to "
+        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
+    )


 def get_args():
@@ -128,7 +128,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
   )
   target_link_libraries(sherpa-onnx-online-websocket-server sherpa-onnx-core)

-
   add_executable(sherpa-onnx-online-websocket-client
     online-websocket-client.cc
   )
@@ -142,6 +141,17 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
     target_compile_options(sherpa-onnx-online-websocket-client PRIVATE -Wno-deprecated-declarations)
   endif()

+  # For offline websocket
+  add_executable(sherpa-onnx-offline-websocket-server
+    offline-websocket-server-impl.cc
+    offline-websocket-server.cc
+  )
+  target_link_libraries(sherpa-onnx-offline-websocket-server sherpa-onnx-core)
+
+  if(NOT WIN32)
+    target_link_libraries(sherpa-onnx-offline-websocket-server -pthread)
+    target_compile_options(sherpa-onnx-offline-websocket-server PRIVATE -Wno-deprecated-declarations)
+  endif()
 endif()

  1 +// sherpa-onnx/csrc/offline-websocket-server-impl.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-websocket-server-impl.h"
  6 +
  7 +#include <algorithm>
  8 +
  9 +#include "sherpa-onnx/csrc/macros.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void OfflineWebsocketDecoderConfig::Register(ParseOptions *po) {
  14 + recognizer_config.Register(po);
  15 +
  16 + po->Register("max-batch-size", &max_batch_size,
  17 + "Max batch size for decoding.");
  18 +
  19 + po->Register(
  20 + "max-utterance-length", &max_utterance_length,
  21 + "Max utterance length in seconds. If we receive an utterance "
  22 + "longer than this value, we will reject the connection. "
  23 + "If you have enough memory, you can select a large value for it.");
  24 +}
  25 +
  26 +void OfflineWebsocketDecoderConfig::Validate() const {
  27 + if (!recognizer_config.Validate()) {
  28 + SHERPA_ONNX_LOGE("Error in recongizer config");
  29 + exit(-1);
  30 + }
  31 +
  32 + if (max_batch_size <= 0) {
  33 + SHERPA_ONNX_LOGE("Expect --max-batch-size > 0. Given: %d", max_batch_size);
  34 + exit(-1);
  35 + }
  36 +
  37 + if (max_utterance_length <= 0) {
  38 + SHERPA_ONNX_LOGE("Expect --max-utterance-length > 0. Given: %f",
  39 + max_utterance_length);
  40 + exit(-1);
  41 + }
  42 +}
  43 +
  44 +OfflineWebsocketDecoder::OfflineWebsocketDecoder(OfflineWebsocketServer *server)
  45 + : config_(server->GetConfig().decoder_config),
  46 + server_(server),
  47 + recognizer_(config_.recognizer_config) {}
  48 +
  49 +void OfflineWebsocketDecoder::Push(connection_hdl hdl, ConnectionDataPtr d) {
  50 + std::lock_guard<std::mutex> lock(mutex_);
  51 + streams_.push_back({hdl, d});
  52 +}
  53 +
  54 +void OfflineWebsocketDecoder::Decode() {
  55 + std::unique_lock<std::mutex> lock(mutex_);
  56 + if (streams_.empty()) {
  57 + return;
  58 + }
  59 +
  60 + int32_t size =
  61 + std::min(static_cast<int32_t>(streams_.size()), config_.max_batch_size);
  62 + SHERPA_ONNX_LOGE("size: %d", size);
  63 +
  64 + // We first lock the mutex for streams_, take items from it, and then
  65 + // unlock the mutex; in doing so we don't need to lock the mutex to
  66 + // access hdl and connection_data later.
  67 + std::vector<connection_hdl> handles(size);
  68 +
  69 + // Store connection_data here to prevent the data from being freed
  70 + // while we are still using it.
  71 + std::vector<ConnectionDataPtr> connection_data(size);
  72 +
  75 + std::vector<std::unique_ptr<OfflineStream>> ss(size);
  76 + std::vector<OfflineStream *> p_ss(size);
  77 +
  78 + for (int32_t i = 0; i != size; ++i) {
  79 + auto &p = streams_.front();
  80 + handles[i] = p.first;
  81 + connection_data[i] = p.second;
  82 + streams_.pop_front();
  83 +
  84 + auto sample_rate = connection_data[i]->sample_rate;
  85 + auto samples =
  86 + reinterpret_cast<const float *>(&connection_data[i]->data[0]);
  87 + auto num_samples = connection_data[i]->expected_byte_size / sizeof(float);
  88 + auto s = recognizer_.CreateStream();
  89 + s->AcceptWaveform(sample_rate, samples, num_samples);
  90 +
  91 + ss[i] = std::move(s);
  92 + p_ss[i] = ss[i].get();
  93 + }
  94 +
  95 + lock.unlock();
  96 +
  97 + // Note: DecodeStreams is thread-safe
  98 + recognizer_.DecodeStreams(p_ss.data(), size);
  99 +
  100 + for (int32_t i = 0; i != size; ++i) {
  101 + connection_hdl hdl = handles[i];
  102 + asio::post(server_->GetConnectionContext(),
  103 + [this, hdl, text = ss[i]->GetResult().text]() {
  104 + websocketpp::lib::error_code ec;
  105 + server_->GetServer().send(
  106 + hdl, text, websocketpp::frame::opcode::text, ec);
  107 + if (ec) {
  108 + server_->GetServer().get_alog().write(
  109 + websocketpp::log::alevel::app, ec.message());
  110 + }
  111 + });
  112 + }
  113 +}
  114 +
  115 +void OfflineWebsocketServerConfig::Register(ParseOptions *po) {
  116 + decoder_config.Register(po);
  117 + po->Register("log-file", &log_file,
  118 + "Path to the log file. Logs are "
  119 + "appended to this file");
  120 +}
  121 +
  122 +void OfflineWebsocketServerConfig::Validate() const {
  123 + decoder_config.Validate();
  124 +}
  125 +
  126 +OfflineWebsocketServer::OfflineWebsocketServer(
  127 + asio::io_context &io_conn, // NOLINT
  128 + asio::io_context &io_work, // NOLINT
  129 + const OfflineWebsocketServerConfig &config)
  130 + : io_conn_(io_conn),
  131 + io_work_(io_work),
  132 + config_(config),
  133 + log_(config.log_file, std::ios::app),
  134 + tee_(std::cout, log_),
  135 + decoder_(this) {
  136 + SetupLog();
  137 +
  138 + server_.init_asio(&io_conn_);
  139 +
  140 + server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); });
  141 +
  142 + server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); });
  143 +
  144 + server_.set_message_handler(
  145 + [this](connection_hdl hdl, server::message_ptr msg) {
  146 + OnMessage(hdl, msg);
  147 + });
  148 +}
  149 +
  150 +void OfflineWebsocketServer::SetupLog() {
  151 + server_.clear_access_channels(websocketpp::log::alevel::all);
  152 + server_.set_access_channels(websocketpp::log::alevel::connect);
  153 + server_.set_access_channels(websocketpp::log::alevel::disconnect);
  154 +
  155 + // So that it also prints to std::cout and std::cerr
  156 + server_.get_alog().set_ostream(&tee_);
  157 + server_.get_elog().set_ostream(&tee_);
  158 +}
  159 +
  160 +void OfflineWebsocketServer::OnOpen(connection_hdl hdl) {
  161 + std::lock_guard<std::mutex> lock(mutex_);
  162 + connections_.emplace(hdl, std::make_shared<ConnectionData>());
  163 +
  164 + SHERPA_ONNX_LOGE("Number of active connections: %d",
  165 + static_cast<int32_t>(connections_.size()));
  166 +}
  167 +
  168 +void OfflineWebsocketServer::OnClose(connection_hdl hdl) {
  169 + std::lock_guard<std::mutex> lock(mutex_);
  170 + connections_.erase(hdl);
  171 +
  172 + SHERPA_ONNX_LOGE("Number of active connections: %d",
  173 + static_cast<int32_t>(connections_.size()));
  174 +}
  175 +
  176 +void OfflineWebsocketServer::OnMessage(connection_hdl hdl,
  177 + server::message_ptr msg) {
  178 + std::unique_lock<std::mutex> lock(mutex_);
  179 + auto connection_data = connections_.find(hdl)->second;
  180 + lock.unlock();
  181 + const std::string &payload = msg->get_payload();
  182 +
  183 + switch (msg->get_opcode()) {
  184 + case websocketpp::frame::opcode::text:
  185 + if (payload == "Done") {
  186 + // The client will not send any more data. We can close the
  187 + // connection now.
  188 + Close(hdl, websocketpp::close::status::normal, "Done");
  189 + } else {
  190 + Close(hdl, websocketpp::close::status::normal,
  191 + std::string("Invalid payload: ") + payload);
  192 + }
  193 + break;
  194 +
  195 + case websocketpp::frame::opcode::binary: {
  196 + auto p = reinterpret_cast<const int8_t *>(payload.data());
  197 +
  198 + if (connection_data->expected_byte_size == 0) {
  199 + if (payload.size() < 8) {
  200 + Close(hdl, websocketpp::close::status::normal,
  201 + "Payload is too short");
  202 + break;
  203 + }
  204 +
  205 + connection_data->sample_rate = *reinterpret_cast<const int32_t *>(p);
  206 +
  207 + connection_data->expected_byte_size =
  208 + *reinterpret_cast<const int32_t *>(p + 4);
  209 +
  210 + int32_t max_byte_size = decoder_.GetConfig().max_utterance_length *
  211 + connection_data->sample_rate * sizeof(float);
  212 + if (connection_data->expected_byte_size > max_byte_size) {
  213 + float num_samples =
  214 + connection_data->expected_byte_size / sizeof(float);
  215 +
  216 + float duration = num_samples / connection_data->sample_rate;
  217 +
  218 + std::ostringstream os;
  219 + os << "Max utterance length is configured to "
  220 + << decoder_.GetConfig().max_utterance_length
  221 + << " seconds, received length is " << duration << " seconds. "
  222 + << "Payload is too large!";
  223 + Close(hdl, websocketpp::close::status::message_too_big, os.str());
  224 + break;
  225 + }
  226 +
  227 + connection_data->data.resize(connection_data->expected_byte_size);
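    + // The first 8 bytes are the header; the remaining bytes are samples.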
  228 + std::copy(payload.begin() + 8, payload.end(),
  229 + connection_data->data.data());
  230 + connection_data->cur = payload.size() - 8;
  231 + } else {
  232 + std::copy(payload.begin(), payload.end(),
  233 + connection_data->data.data() + connection_data->cur);
  234 + connection_data->cur += payload.size();
  235 + }
  236 +
  237 + if (connection_data->expected_byte_size == connection_data->cur) {
  238 + auto d = std::make_shared<ConnectionData>(std::move(*connection_data));
  239 + // Reset the connection data so that we can handle the next audio file
  240 + // from the client. The client can send multiple audio files for
  241 + // recognition over the same connection.
  242 + connection_data->Clear();
  243 +
  244 + decoder_.Push(hdl, d);
  249 +
  250 + asio::post(io_work_, [this]() { decoder_.Decode(); });
  251 + }
  252 + break;
  253 + }
  254 +
  255 + default:
  256 + // Unexpected message, ignore it
  257 + break;
  258 + }
  259 +}
  260 +
  261 +void OfflineWebsocketServer::Close(connection_hdl hdl,
  262 + websocketpp::close::status::value code,
  263 + const std::string &reason) {
  264 + auto con = server_.get_con_from_hdl(hdl);
  265 +
  266 + std::ostringstream os;
  267 + os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason
  268 + << "\n";
  269 +
  270 + websocketpp::lib::error_code ec;
  271 + server_.close(hdl, code, reason, ec);
  272 + if (ec) {
  273 + os << "Failed to close" << con->get_remote_endpoint() << ". "
  274 + << ec.message() << "\n";
  275 + }
  276 + server_.get_alog().write(websocketpp::log::alevel::app, os.str());
  277 +}
  278 +
  279 +void OfflineWebsocketServer::Run(uint16_t port) {
  280 + server_.set_reuse_addr(true);
  281 + server_.listen(asio::ip::tcp::v4(), port);
  282 + server_.start_accept();
  283 +}
  284 +
  285 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-websocket-server-impl.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
  7 +
  8 +#include <deque>
  9 +#include <fstream>
  10 +#include <map>
  11 +#include <memory>
  12 +#include <string>
  13 +#include <utility>
  14 +#include <vector>
  15 +
  16 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  17 +#include "sherpa-onnx/csrc/parse-options.h"
  18 +#include "sherpa-onnx/csrc/tee-stream.h"
  19 +#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
  20 +#include "websocketpp/server.hpp"
  21 +
  22 +using server = websocketpp::server<websocketpp::config::asio>;
  23 +using connection_hdl = websocketpp::connection_hdl;
  24 +
  25 +namespace sherpa_onnx {
  26 +
  27 +/** Communication protocol
  28 + *
  29 + * The client sends a byte stream to the server. The first 4 bytes in
  30 + * little endian indicate the sample rate of the audio data that the client
  31 + * will send. The next 4 bytes in little endian indicate the total number
  32 + * of bytes of audio samples the client will send. The remaining bytes are
  33 + * the audio samples themselves. Each audio sample is a float occupying 4
  34 + * bytes and is normalized into the range [-1, 1].
  35 + *
  36 + * The byte stream can be broken into an arbitrary number of messages.
  37 + * We require that the first message contain at least 8 bytes so that
  38 + * we can get `sample_rate` and `expected_byte_size` from it.
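    + *
    + * Example layout for a hypothetical 16000 Hz stream carrying two samples
    + * (illustrative only; all integers are little endian):
    + *
    + *   bytes 0..3  : 0x80 0x3e 0x00 0x00  -> sample_rate = 16000
    + *   bytes 4..7  : 0x08 0x00 0x00 0x00  -> expected_byte_size = 8
    + *   bytes 8..15 : two float32 samples, each in [-1, 1]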
  39 + */
  40 +struct ConnectionData {
  41 + // Sample rate of the audio samples the client will send
  42 + int32_t sample_rate = 0;
  43 +
  44 + // Number of expected bytes sent from the client
  45 + int32_t expected_byte_size = 0;
  46 +
  47 + // Number of bytes received so far
  48 + int32_t cur = 0;
  49 +
  50 + // It saves the received samples from the client.
  51 + // We will **reinterpret_cast** it to float.
  52 + // We expect that data.size() == expected_byte_size
  53 + std::vector<int8_t> data;
  54 +
  55 + void Clear() {
  56 + sample_rate = 0;
  57 + expected_byte_size = 0;
  58 + cur = 0;
  59 + data.clear();
  60 + }
  61 +};
  62 +
  63 +using ConnectionDataPtr = std::shared_ptr<ConnectionData>;
  64 +
  65 +struct OfflineWebsocketDecoderConfig {
  66 + OfflineRecognizerConfig recognizer_config;
  67 +
  68 + int32_t max_batch_size = 5;
  69 +
  70 + float max_utterance_length = 300; // seconds
  71 +
  72 + void Register(ParseOptions *po);
  73 + void Validate() const;
  74 +};
  75 +
  76 +class OfflineWebsocketServer;
  77 +
  78 +class OfflineWebsocketDecoder {
  79 + public:
  80 + /**
  81 + * @param server **Borrowed** from outside. The decoder obtains its
  82 + * config from the server, which must outlive this object.
  83 + */
  84 + explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server);
  85 +
  86 + /** Insert received data to the queue for decoding.
  87 + *
  88 + * @param hdl A handle to the connection. We can use it to send the result
  89 + * back to the client once it finishes decoding.
  90 + * @param d The received data
  91 + */
  92 + void Push(connection_hdl hdl, ConnectionDataPtr d);
  93 +
  94 + /** It is called by one of the work threads.
  95 + */
  96 + void Decode();
  97 +
  98 + const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; }
  99 +
  100 + private:
  101 + OfflineWebsocketDecoderConfig config_;
  102 +
  103 + /** When we have received all the data from the client, we put it into
  104 + * this queue; the worker threads will get items from this queue for
  105 + * decoding.
  106 + *
  107 + * Number of items to take from this queue is determined by
  108 + * `--max-batch-size`. If there are not enough items in the queue, we won't
  109 + * wait and take whatever we have for decoding.
  110 + */
  111 + std::mutex mutex_;
  112 + std::deque<std::pair<connection_hdl, ConnectionDataPtr>> streams_;
  113 +
  114 + OfflineWebsocketServer *server_; // Not owned
  115 + OfflineRecognizer recognizer_;
  116 +};
  117 +
  118 +struct OfflineWebsocketServerConfig {
  119 + OfflineWebsocketDecoderConfig decoder_config;
  120 + std::string log_file = "./log.txt";
  121 +
  122 + void Register(ParseOptions *po);
  123 + void Validate() const;
  124 +};
  125 +
  126 +class OfflineWebsocketServer {
  127 + public:
  128 + OfflineWebsocketServer(asio::io_context &io_conn, // NOLINT
  129 + asio::io_context &io_work, // NOLINT
  130 + const OfflineWebsocketServerConfig &config);
  131 +
  132 + asio::io_context &GetConnectionContext() { return io_conn_; }
  133 + server &GetServer() { return server_; }
  134 +
  135 + void Run(uint16_t port);
  136 +
  137 + const OfflineWebsocketServerConfig &GetConfig() const { return config_; }
  138 +
  139 + private:
  140 + void SetupLog();
  141 +
  142 + // When a websocket client is connected, it will invoke this method
  143 + // (Not for HTTP)
  144 + void OnOpen(connection_hdl hdl);
  145 +
  146 + // When a websocket client is disconnected, it will invoke this method
  147 + void OnClose(connection_hdl hdl);
  148 +
  149 + // When a message is received from a websocket client, this method will
  150 + // be invoked.
  151 + //
  152 + // The protocol between the client and the server is as follows:
  153 + //
  154 + // (1) The client connects to the server
  155 + // (2) The client starts to send a binary byte stream to the server.
  156 + // The byte stream can be broken into multiple messages, or it can
  157 + // be put into a single message.
  158 + // The first message has to contain at least 8 bytes. The first
  159 + // 4 bytes in little endian contain an int32_t indicating the
  160 + // sampling rate. The next 4 bytes in little endian contain an int32_t
  161 + // indicating the total number of bytes of samples the client will send.
  162 + // We assume each sample is a float containing 4 bytes and has been
  163 + // normalized to the range [-1, 1].
  164 + // (3) When the server receives all the samples from the client, it
  165 + // starts to decode them. Once decoded, the server sends a text message
  166 + // to the client containing the decoded results.
  167 + // (4) After receiving the decoded results from the server, if the client
  168 + // has another audio file to send, it repeats (2) and (3).
  169 + // (5) If the client has no more audio files to decode, it sends a
  170 + // text message containing "Done" to the server and closes the connection.
  171 + // (6) The server receives the text message "Done" and closes the connection.
  172 + //
  173 + // Note:
  174 + // (a) All models in icefall use features extracted from audio samples
  175 + // normalized to the range [-1, 1]. Please send normalized audio samples
  176 + // if you use models from icefall.
  177 + // (b) Only sound files with a single channel are supported
  178 + // (c) Only audio samples are sent. For instance, if we want to decode
  179 + // a WAVE file, the RIFF header of the WAVE file is not sent.
  180 + void OnMessage(connection_hdl hdl, server::message_ptr msg);
  181 +
  182 + // Close a websocket connection with given code and reason
  183 + void Close(connection_hdl hdl, websocketpp::close::status::value code,
  184 + const std::string &reason);
  185 +
  186 + private:
  187 + asio::io_context &io_conn_;
  188 + asio::io_context &io_work_;
  189 + server server_;
  190 +
  191 + std::map<connection_hdl, ConnectionDataPtr, std::owner_less<connection_hdl>>
  192 + connections_;
  193 + std::mutex mutex_;
  194 +
  195 + OfflineWebsocketServerConfig config_;
  196 +
  197 + std::ofstream log_;
  198 + TeeStream tee_;
  199 +
  200 + OfflineWebsocketDecoder decoder_;
  201 +};
  202 +
  203 +} // namespace sherpa_onnx
  204 +
  205 +#endif // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-websocket-server.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "asio.hpp"
  6 +#include "sherpa-onnx/csrc/macros.h"
  7 +#include "sherpa-onnx/csrc/offline-websocket-server-impl.h"
  8 +#include "sherpa-onnx/csrc/parse-options.h"
  9 +
  10 +static constexpr const char *kUsageMessage = R"(
  11 +Automatic speech recognition with sherpa-onnx using websocket.
  12 +
  13 +Usage:
  14 +
  15 +./bin/sherpa-onnx-offline-websocket-server --help
  16 +
  17 +(1) For transducer models
  18 +
  19 +./bin/sherpa-onnx-offline-websocket-server \
  20 + --port=6006 \
  21 + --num-work-threads=5 \
  22 + --tokens=/path/to/tokens.txt \
  23 + --encoder=/path/to/encoder.onnx \
  24 + --decoder=/path/to/decoder.onnx \
  25 + --joiner=/path/to/joiner.onnx \
  26 + --log-file=./log.txt \
  27 + --max-batch-size=5
  28 +
  29 +(2) For Paraformer
  30 +
  31 +./bin/sherpa-onnx-offline-websocket-server \
  32 + --port=6006 \
  33 + --num-work-threads=5 \
  34 + --tokens=/path/to/tokens.txt \
  35 + --paraformer=/path/to/model.onnx \
  36 + --log-file=./log.txt \
  37 + --max-batch-size=5
  38 +
  39 +Please refer to
  40 +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  41 +for a list of pre-trained models to download.
  42 +)";
  43 +
  44 +int32_t main(int32_t argc, char *argv[]) {
  45 + sherpa_onnx::ParseOptions po(kUsageMessage);
  46 +
  47 + sherpa_onnx::OfflineWebsocketServerConfig config;
  48 +
  49 + // the server will listen on this port
  50 + int32_t port = 6006;
  51 +
  52 + // size of the thread pool for handling network connections
  53 + int32_t num_io_threads = 1;
  54 +
  55 + // size of the thread pool for neural network computation and decoding
  56 + int32_t num_work_threads = 3;
  57 +
  58 + po.Register("num-io-threads", &num_io_threads,
  59 + "Thread pool size for network connections.");
  60 +
  61 + po.Register("num-work-threads", &num_work_threads,
  62 + "Thread pool size for for neural network "
  63 + "computation and decoding.");
  64 +
  65 + po.Register("port", &port, "The port on which the server will listen.");
  66 +
  67 + config.Register(&po);
  68 +
  69 + if (argc == 1) {
  70 + po.PrintUsage();
  71 + exit(EXIT_FAILURE);
  72 + }
  73 +
  74 + po.Read(argc, argv);
  75 +
  76 + if (po.NumArgs() != 0) {
  77 + SHERPA_ONNX_LOGE("Unrecognized positional arguments!");
  78 + po.PrintUsage();
  79 + exit(EXIT_FAILURE);
  80 + }
  81 +
  82 + config.Validate();
  83 +
  84 + asio::io_context io_conn; // for network connections
  85 + asio::io_context io_work; // for neural network and decoding
  86 +
  87 + sherpa_onnx::OfflineWebsocketServer server(io_conn, io_work, config);
  88 + server.Run(port);
  89 +
  90 + SHERPA_ONNX_LOGE("Started!");
  91 + SHERPA_ONNX_LOGE("Listening on: %d", port);
  92 + SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
  93 +
  94 + // give some work to do for the io_work pool
  95 + auto work_guard = asio::make_work_guard(io_work);
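    + // (Without the work guard, io_work.run() would return immediately,
    + // since no handlers are queued at startup.)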
  96 +
  97 + std::vector<std::thread> io_threads;
  98 +
  99 + // decrement since the main thread is also used for network communications
  100 + for (int32_t i = 0; i < num_io_threads - 1; ++i) {
  101 + io_threads.emplace_back([&io_conn]() { io_conn.run(); });
  102 + }
  103 +
  104 + std::vector<std::thread> work_threads;
  105 + for (int32_t i = 0; i < num_work_threads; ++i) {
  106 + work_threads.emplace_back([&io_work]() { io_work.run(); });
  107 + }
  108 +
  109 + io_conn.run();
  110 +
  111 + for (auto &t : io_threads) {
  112 + t.join();
  113 + }
  114 +
  115 + for (auto &t : work_threads) {
  116 + t.join();
  117 + }
  118 +
  119 + return 0;
  120 +}
@@ -76,6 +76,7 @@ int32_t main(int32_t argc, char *argv[]) {
   sherpa_onnx::OnlineWebsocketServer server(io_conn, io_work, config);
   server.Run(port);

+  SHERPA_ONNX_LOGE("Started!");
   SHERPA_ONNX_LOGE("Listening on: %d", port);
   SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads);
