正在显示
3 个修改的文件
包含
337 行增加
和
0 行删除
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +""" | ||
| 6 | +A websocket client for sherpa-onnx-online-websocket-server | ||
| 7 | + | ||
| 8 | +Usage: | ||
| 9 | + ./online-websocket-client-decode-file.py \ | ||
| 10 | + --server-addr localhost \ | ||
| 11 | + --server-port 6006 \ | ||
| 12 | + --seconds-per-message 0.1 \ | ||
| 13 | + --samples-per-message 8000 \ | ||
| 14 | + /path/to/foo.wav | ||
| 15 | + | ||
| 16 | +(Note: You have to first start the server before starting the client) | ||
| 17 | + | ||
| 18 | +You can find the server at | ||
| 19 | +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc | ||
| 20 | + | ||
| 21 | +Note: The server is implemented in C++. | ||
| 22 | + | ||
| 23 | +There is also a C++ version of the client. Please see | ||
| 24 | +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc | ||
| 25 | +""" | ||
| 26 | + | ||
| 27 | +import argparse | ||
| 28 | +import asyncio | ||
| 29 | +import logging | ||
| 30 | +import time | ||
| 31 | +import wave | ||
| 32 | + | ||
| 33 | +try: | ||
| 34 | + import websockets | ||
| 35 | +except ImportError: | ||
| 36 | + print("please run:") | ||
| 37 | + print("") | ||
| 38 | + print(" pip install websockets") | ||
| 39 | + print("") | ||
| 40 | + print("before you run this script") | ||
| 41 | + print("") | ||
| 42 | + | ||
| 43 | +import numpy as np | ||
| 44 | + | ||
| 45 | + | ||
| 46 | +def read_wave(wave_filename: str) -> np.ndarray: | ||
| 47 | + """ | ||
| 48 | + Args: | ||
| 49 | + wave_filename: | ||
| 50 | + Path to a wave file. Its sampling rate has to be 16000. | ||
| 51 | + It should be single channel and each sample should be 16-bit. | ||
| 52 | + Returns: | ||
| 53 | + Return a 1-D float32 tensor. | ||
| 54 | + """ | ||
| 55 | + | ||
| 56 | + with wave.open(wave_filename) as f: | ||
| 57 | + assert f.getframerate() == 16000, f.getframerate() | ||
| 58 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 59 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 60 | + num_samples = f.getnframes() | ||
| 61 | + samples = f.readframes(num_samples) | ||
| 62 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 63 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 64 | + | ||
| 65 | + samples_float32 = samples_float32 / 32768 | ||
| 66 | + return samples_float32 | ||
| 67 | + | ||
| 68 | + | ||
| 69 | +def get_args(): | ||
| 70 | + parser = argparse.ArgumentParser( | ||
| 71 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 72 | + ) | ||
| 73 | + | ||
| 74 | + parser.add_argument( | ||
| 75 | + "--server-addr", | ||
| 76 | + type=str, | ||
| 77 | + default="localhost", | ||
| 78 | + help="Address of the server", | ||
| 79 | + ) | ||
| 80 | + | ||
| 81 | + parser.add_argument( | ||
| 82 | + "--server-port", | ||
| 83 | + type=int, | ||
| 84 | + default=6006, | ||
| 85 | + help="Port of the server", | ||
| 86 | + ) | ||
| 87 | + | ||
| 88 | + parser.add_argument( | ||
| 89 | + "--samples-per-message", | ||
| 90 | + type=int, | ||
| 91 | + default=8000, | ||
| 92 | + help="Number of samples per message", | ||
| 93 | + ) | ||
| 94 | + | ||
| 95 | + parser.add_argument( | ||
| 96 | + "--seconds-per-message", | ||
| 97 | + type=float, | ||
| 98 | + default=0.1, | ||
| 99 | + help="We will simulate that the duration of two messages is of this value", | ||
| 100 | + ) | ||
| 101 | + | ||
| 102 | + parser.add_argument( | ||
| 103 | + "sound_file", | ||
| 104 | + type=str, | ||
| 105 | + help="The input sound file. Must be wave with a single channel, 16kHz " | ||
| 106 | + "sampling rate, 16-bit of each sample.", | ||
| 107 | + ) | ||
| 108 | + | ||
| 109 | + return parser.parse_args() | ||
| 110 | + | ||
| 111 | + | ||
| 112 | +async def receive_results(socket: websockets.WebSocketServerProtocol): | ||
| 113 | + last_message = "" | ||
| 114 | + async for message in socket: | ||
| 115 | + if message != "Done!": | ||
| 116 | + last_message = message | ||
| 117 | + logging.info(message) | ||
| 118 | + else: | ||
| 119 | + return last_message | ||
| 120 | + | ||
| 121 | + | ||
| 122 | +async def run( | ||
| 123 | + server_addr: str, | ||
| 124 | + server_port: int, | ||
| 125 | + wave_filename: str, | ||
| 126 | + samples_per_message: int, | ||
| 127 | + seconds_per_message: float, | ||
| 128 | +): | ||
| 129 | + data = read_wave(wave_filename) | ||
| 130 | + | ||
| 131 | + async with websockets.connect( | ||
| 132 | + f"ws://{server_addr}:{server_port}" | ||
| 133 | + ) as websocket: # noqa | ||
| 134 | + logging.info(f"Sending {wave_filename}") | ||
| 135 | + | ||
| 136 | + receive_task = asyncio.create_task(receive_results(websocket)) | ||
| 137 | + | ||
| 138 | + start = 0 | ||
| 139 | + while start < data.shape[0]: | ||
| 140 | + end = start + samples_per_message | ||
| 141 | + end = min(end, data.shape[0]) | ||
| 142 | + d = data.data[start:end].tobytes() | ||
| 143 | + | ||
| 144 | + await websocket.send(d) | ||
| 145 | + | ||
| 146 | + await asyncio.sleep(seconds_per_message) # in seconds | ||
| 147 | + | ||
| 148 | + start += samples_per_message | ||
| 149 | + | ||
| 150 | + # to signal that the client has sent all the data | ||
| 151 | + await websocket.send("Done") | ||
| 152 | + | ||
| 153 | + decoding_results = await receive_task | ||
| 154 | + logging.info(f"\nFinal result is:\n{decoding_results}") | ||
| 155 | + | ||
| 156 | + | ||
| 157 | +async def main(): | ||
| 158 | + args = get_args() | ||
| 159 | + logging.info(vars(args)) | ||
| 160 | + | ||
| 161 | + server_addr = args.server_addr | ||
| 162 | + server_port = args.server_port | ||
| 163 | + samples_per_message = args.samples_per_message | ||
| 164 | + seconds_per_message = args.seconds_per_message | ||
| 165 | + | ||
| 166 | + await run( | ||
| 167 | + server_addr=server_addr, | ||
| 168 | + server_port=server_port, | ||
| 169 | + wave_filename=args.sound_file, | ||
| 170 | + samples_per_message=samples_per_message, | ||
| 171 | + seconds_per_message=seconds_per_message, | ||
| 172 | + ) | ||
| 173 | + | ||
| 174 | + | ||
| 175 | +if __name__ == "__main__": | ||
| 176 | + formatter = ( | ||
| 177 | + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa | ||
| 178 | + ) | ||
| 179 | + logging.basicConfig(format=formatter, level=logging.INFO) | ||
| 180 | + asyncio.run(main()) |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +""" | ||
| 6 | +A websocket client for sherpa-onnx-online-websocket-server | ||
| 7 | + | ||
| 8 | +Usage: | ||
| 9 | + ./online-websocket-client-microphone.py \ | ||
| 10 | + --server-addr localhost \ | ||
| 11 | + --server-port 6006 | ||
| 12 | + | ||
| 13 | +(Note: You have to first start the server before starting the client) | ||
| 14 | + | ||
| 15 | +You can find the server at | ||
| 16 | +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc | ||
| 17 | + | ||
| 18 | +Note: The server is implemented in C++. | ||
| 19 | + | ||
| 20 | +There is also a C++ version of the client. Please see | ||
| 21 | +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc | ||
| 22 | +""" | ||
| 23 | + | ||
| 24 | +import argparse | ||
| 25 | +import asyncio | ||
| 26 | +import time | ||
| 27 | + | ||
| 28 | +import numpy as np | ||
| 29 | + | ||
| 30 | +try: | ||
| 31 | + import sounddevice as sd | ||
| 32 | +except ImportError as e: | ||
| 33 | + print("Please install sounddevice first. You can use") | ||
| 34 | + print() | ||
| 35 | + print(" pip install sounddevice") | ||
| 36 | + print() | ||
| 37 | + print("to install it") | ||
| 38 | + sys.exit(-1) | ||
| 39 | + | ||
| 40 | +try: | ||
| 41 | + import websockets | ||
| 42 | +except ImportError: | ||
| 43 | + print("please run:") | ||
| 44 | + print("") | ||
| 45 | + print(" pip install websockets") | ||
| 46 | + print("") | ||
| 47 | + print("before you run this script") | ||
| 48 | + print("") | ||
| 49 | + sys.exit(-1) | ||
| 50 | + | ||
| 51 | + | ||
| 52 | +def get_args(): | ||
| 53 | + parser = argparse.ArgumentParser( | ||
| 54 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 55 | + ) | ||
| 56 | + | ||
| 57 | + parser.add_argument( | ||
| 58 | + "--server-addr", | ||
| 59 | + type=str, | ||
| 60 | + default="localhost", | ||
| 61 | + help="Address of the server", | ||
| 62 | + ) | ||
| 63 | + | ||
| 64 | + parser.add_argument( | ||
| 65 | + "--server-port", | ||
| 66 | + type=int, | ||
| 67 | + default=6006, | ||
| 68 | + help="Port of the server", | ||
| 69 | + ) | ||
| 70 | + | ||
| 71 | + return parser.parse_args() | ||
| 72 | + | ||
| 73 | + | ||
| 74 | +async def inputstream_generator(channels=1): | ||
| 75 | + """Generator that yields blocks of input data as NumPy arrays. | ||
| 76 | + | ||
| 77 | + See https://python-sounddevice.readthedocs.io/en/0.4.6/examples.html#creating-an-asyncio-generator-for-audio-blocks | ||
| 78 | + """ | ||
| 79 | + q_in = asyncio.Queue() | ||
| 80 | + loop = asyncio.get_event_loop() | ||
| 81 | + | ||
| 82 | + def callback(indata, frame_count, time_info, status): | ||
| 83 | + loop.call_soon_threadsafe(q_in.put_nowait, (indata.copy(), status)) | ||
| 84 | + | ||
| 85 | + devices = sd.query_devices() | ||
| 86 | + print(devices) | ||
| 87 | + default_input_device_idx = sd.default.device[0] | ||
| 88 | + print(f'Use default device: {devices[default_input_device_idx]["name"]}') | ||
| 89 | + print() | ||
| 90 | + print("Started! Please speak") | ||
| 91 | + | ||
| 92 | + stream = sd.InputStream( | ||
| 93 | + callback=callback, | ||
| 94 | + channels=channels, | ||
| 95 | + dtype="float32", | ||
| 96 | + samplerate=16000, | ||
| 97 | + blocksize=int(0.05 * 16000), # 0.05 seconds | ||
| 98 | + ) | ||
| 99 | + with stream: | ||
| 100 | + while True: | ||
| 101 | + indata, status = await q_in.get() | ||
| 102 | + yield indata, status | ||
| 103 | + | ||
| 104 | + | ||
| 105 | +async def receive_results(socket: websockets.WebSocketServerProtocol): | ||
| 106 | + last_message = "" | ||
| 107 | + async for message in socket: | ||
| 108 | + if message != "Done!": | ||
| 109 | + if last_message != message: | ||
| 110 | + last_message = message | ||
| 111 | + | ||
| 112 | + if last_message: | ||
| 113 | + print(last_message) | ||
| 114 | + else: | ||
| 115 | + return last_message | ||
| 116 | + | ||
| 117 | + | ||
| 118 | +async def run( | ||
| 119 | + server_addr: str, | ||
| 120 | + server_port: int, | ||
| 121 | +): | ||
| 122 | + async with websockets.connect( | ||
| 123 | + f"ws://{server_addr}:{server_port}" | ||
| 124 | + ) as websocket: # noqa | ||
| 125 | + receive_task = asyncio.create_task(receive_results(websocket)) | ||
| 126 | + print("Started! Please Speak") | ||
| 127 | + | ||
| 128 | + async for indata, status in inputstream_generator(): | ||
| 129 | + if status: | ||
| 130 | + print(status) | ||
| 131 | + indata = indata.reshape(-1) | ||
| 132 | + indata = np.ascontiguousarray(indata) | ||
| 133 | + await websocket.send(indata.tobytes()) | ||
| 134 | + | ||
| 135 | + decoding_results = await receive_task | ||
| 136 | + print("\nFinal result is:\n{decoding_results}") | ||
| 137 | + | ||
| 138 | + | ||
| 139 | +async def main(): | ||
| 140 | + args = get_args() | ||
| 141 | + print(vars(args)) | ||
| 142 | + | ||
| 143 | + server_addr = args.server_addr | ||
| 144 | + server_port = args.server_port | ||
| 145 | + | ||
| 146 | + await run( | ||
| 147 | + server_addr=server_addr, | ||
| 148 | + server_port=server_port, | ||
| 149 | + ) | ||
| 150 | + | ||
| 151 | + | ||
| 152 | +if __name__ == "__main__": | ||
| 153 | + try: | ||
| 154 | + asyncio.run(main()) | ||
| 155 | + except KeyboardInterrupt: | ||
| 156 | + print("\nCaught Ctrl + C. Exiting") |
-
请 注册 或 登录 后发表评论