Fangjun Kuang
Committed by GitHub

Add Python websocket client (#63)

  1 +#!/usr/bin/env python3
  2 +#
  3 +# Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +"""
  6 +A websocket client for sherpa-onnx-online-websocket-server
  7 +
  8 +Usage:
  9 + ./online-websocket-client-decode-file.py \
  10 + --server-addr localhost \
  11 + --server-port 6006 \
  12 + --seconds-per-message 0.1 \
  13 + --samples-per-message 8000 \
  14 + /path/to/foo.wav
  15 +
  16 +(Note: You have to first start the server before starting the client)
  17 +
  18 +You can find the server at
  19 +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
  20 +
  21 +Note: The server is implemented in C++.
  22 +
  23 +There is also a C++ version of the client. Please see
  24 +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
  25 +"""
  26 +
  27 +import argparse
  28 +import asyncio
  29 +import logging
  30 +import time
  31 +import wave
  32 +
  33 +try:
  34 + import websockets
  35 +except ImportError:
  36 + print("please run:")
  37 + print("")
  38 + print(" pip install websockets")
  39 + print("")
  40 + print("before you run this script")
  41 + print("")
  42 +
  43 +import numpy as np
  44 +
  45 +
  46 +def read_wave(wave_filename: str) -> np.ndarray:
  47 + """
  48 + Args:
  49 + wave_filename:
  50 + Path to a wave file. Its sampling rate has to be 16000.
  51 + It should be single channel and each sample should be 16-bit.
  52 + Returns:
  53 + Return a 1-D float32 tensor.
  54 + """
  55 +
  56 + with wave.open(wave_filename) as f:
  57 + assert f.getframerate() == 16000, f.getframerate()
  58 + assert f.getnchannels() == 1, f.getnchannels()
  59 + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  60 + num_samples = f.getnframes()
  61 + samples = f.readframes(num_samples)
  62 + samples_int16 = np.frombuffer(samples, dtype=np.int16)
  63 + samples_float32 = samples_int16.astype(np.float32)
  64 +
  65 + samples_float32 = samples_float32 / 32768
  66 + return samples_float32
  67 +
  68 +
  69 +def get_args():
  70 + parser = argparse.ArgumentParser(
  71 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  72 + )
  73 +
  74 + parser.add_argument(
  75 + "--server-addr",
  76 + type=str,
  77 + default="localhost",
  78 + help="Address of the server",
  79 + )
  80 +
  81 + parser.add_argument(
  82 + "--server-port",
  83 + type=int,
  84 + default=6006,
  85 + help="Port of the server",
  86 + )
  87 +
  88 + parser.add_argument(
  89 + "--samples-per-message",
  90 + type=int,
  91 + default=8000,
  92 + help="Number of samples per message",
  93 + )
  94 +
  95 + parser.add_argument(
  96 + "--seconds-per-message",
  97 + type=float,
  98 + default=0.1,
  99 + help="We will simulate that the duration of two messages is of this value",
  100 + )
  101 +
  102 + parser.add_argument(
  103 + "sound_file",
  104 + type=str,
  105 + help="The input sound file. Must be wave with a single channel, 16kHz "
  106 + "sampling rate, 16-bit of each sample.",
  107 + )
  108 +
  109 + return parser.parse_args()
  110 +
  111 +
  112 +async def receive_results(socket: websockets.WebSocketServerProtocol):
  113 + last_message = ""
  114 + async for message in socket:
  115 + if message != "Done!":
  116 + last_message = message
  117 + logging.info(message)
  118 + else:
  119 + return last_message
  120 +
  121 +
  122 +async def run(
  123 + server_addr: str,
  124 + server_port: int,
  125 + wave_filename: str,
  126 + samples_per_message: int,
  127 + seconds_per_message: float,
  128 +):
  129 + data = read_wave(wave_filename)
  130 +
  131 + async with websockets.connect(
  132 + f"ws://{server_addr}:{server_port}"
  133 + ) as websocket: # noqa
  134 + logging.info(f"Sending {wave_filename}")
  135 +
  136 + receive_task = asyncio.create_task(receive_results(websocket))
  137 +
  138 + start = 0
  139 + while start < data.shape[0]:
  140 + end = start + samples_per_message
  141 + end = min(end, data.shape[0])
  142 + d = data.data[start:end].tobytes()
  143 +
  144 + await websocket.send(d)
  145 +
  146 + await asyncio.sleep(seconds_per_message) # in seconds
  147 +
  148 + start += samples_per_message
  149 +
  150 + # to signal that the client has sent all the data
  151 + await websocket.send("Done")
  152 +
  153 + decoding_results = await receive_task
  154 + logging.info(f"\nFinal result is:\n{decoding_results}")
  155 +
  156 +
  157 +async def main():
  158 + args = get_args()
  159 + logging.info(vars(args))
  160 +
  161 + server_addr = args.server_addr
  162 + server_port = args.server_port
  163 + samples_per_message = args.samples_per_message
  164 + seconds_per_message = args.seconds_per_message
  165 +
  166 + await run(
  167 + server_addr=server_addr,
  168 + server_port=server_port,
  169 + wave_filename=args.sound_file,
  170 + samples_per_message=samples_per_message,
  171 + seconds_per_message=seconds_per_message,
  172 + )
  173 +
  174 +
  175 +if __name__ == "__main__":
  176 + formatter = (
  177 + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa
  178 + )
  179 + logging.basicConfig(format=formatter, level=logging.INFO)
  180 + asyncio.run(main())
  1 +#!/usr/bin/env python3
  2 +#
  3 +# Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +"""
  6 +A websocket client for sherpa-onnx-online-websocket-server
  7 +
  8 +Usage:
  9 + ./online-websocket-client-microphone.py \
  10 + --server-addr localhost \
  11 + --server-port 6006
  12 +
  13 +(Note: You have to first start the server before starting the client)
  14 +
  15 +You can find the server at
  16 +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
  17 +
  18 +Note: The server is implemented in C++.
  19 +
  20 +There is also a C++ version of the client. Please see
  21 +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
  22 +"""
  23 +
  24 +import argparse
  25 +import asyncio
  26 +import time
  27 +
  28 +import numpy as np
  29 +
  30 +try:
  31 + import sounddevice as sd
  32 +except ImportError as e:
  33 + print("Please install sounddevice first. You can use")
  34 + print()
  35 + print(" pip install sounddevice")
  36 + print()
  37 + print("to install it")
  38 + sys.exit(-1)
  39 +
  40 +try:
  41 + import websockets
  42 +except ImportError:
  43 + print("please run:")
  44 + print("")
  45 + print(" pip install websockets")
  46 + print("")
  47 + print("before you run this script")
  48 + print("")
  49 + sys.exit(-1)
  50 +
  51 +
  52 +def get_args():
  53 + parser = argparse.ArgumentParser(
  54 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  55 + )
  56 +
  57 + parser.add_argument(
  58 + "--server-addr",
  59 + type=str,
  60 + default="localhost",
  61 + help="Address of the server",
  62 + )
  63 +
  64 + parser.add_argument(
  65 + "--server-port",
  66 + type=int,
  67 + default=6006,
  68 + help="Port of the server",
  69 + )
  70 +
  71 + return parser.parse_args()
  72 +
  73 +
  74 +async def inputstream_generator(channels=1):
  75 + """Generator that yields blocks of input data as NumPy arrays.
  76 +
  77 + See https://python-sounddevice.readthedocs.io/en/0.4.6/examples.html#creating-an-asyncio-generator-for-audio-blocks
  78 + """
  79 + q_in = asyncio.Queue()
  80 + loop = asyncio.get_event_loop()
  81 +
  82 + def callback(indata, frame_count, time_info, status):
  83 + loop.call_soon_threadsafe(q_in.put_nowait, (indata.copy(), status))
  84 +
  85 + devices = sd.query_devices()
  86 + print(devices)
  87 + default_input_device_idx = sd.default.device[0]
  88 + print(f'Use default device: {devices[default_input_device_idx]["name"]}')
  89 + print()
  90 + print("Started! Please speak")
  91 +
  92 + stream = sd.InputStream(
  93 + callback=callback,
  94 + channels=channels,
  95 + dtype="float32",
  96 + samplerate=16000,
  97 + blocksize=int(0.05 * 16000), # 0.05 seconds
  98 + )
  99 + with stream:
  100 + while True:
  101 + indata, status = await q_in.get()
  102 + yield indata, status
  103 +
  104 +
  105 +async def receive_results(socket: websockets.WebSocketServerProtocol):
  106 + last_message = ""
  107 + async for message in socket:
  108 + if message != "Done!":
  109 + if last_message != message:
  110 + last_message = message
  111 +
  112 + if last_message:
  113 + print(last_message)
  114 + else:
  115 + return last_message
  116 +
  117 +
  118 +async def run(
  119 + server_addr: str,
  120 + server_port: int,
  121 +):
  122 + async with websockets.connect(
  123 + f"ws://{server_addr}:{server_port}"
  124 + ) as websocket: # noqa
  125 + receive_task = asyncio.create_task(receive_results(websocket))
  126 + print("Started! Please Speak")
  127 +
  128 + async for indata, status in inputstream_generator():
  129 + if status:
  130 + print(status)
  131 + indata = indata.reshape(-1)
  132 + indata = np.ascontiguousarray(indata)
  133 + await websocket.send(indata.tobytes())
  134 +
  135 + decoding_results = await receive_task
  136 + print("\nFinal result is:\n{decoding_results}")
  137 +
  138 +
  139 +async def main():
  140 + args = get_args()
  141 + print(vars(args))
  142 +
  143 + server_addr = args.server_addr
  144 + server_port = args.server_port
  145 +
  146 + await run(
  147 + server_addr=server_addr,
  148 + server_port=server_port,
  149 + )
  150 +
  151 +
  152 +if __name__ == "__main__":
  153 + try:
  154 + asyncio.run(main())
  155 + except KeyboardInterrupt:
  156 + print("\nCaught Ctrl + C. Exiting")
  1 +# Copyright (c) 2023 Xiaomi Corporation
1 from pathlib import Path 2 from pathlib import Path
2 from typing import List 3 from typing import List
3 4