Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-02-24 22:46:30 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-02-24 22:46:30 +0800
Commit
e4b79ad34b27fc1c603b51ec9583e8780754f6bc
e4b79ad3
1 parent
40522f03
Add Python websocket client (#63)
隐藏空白字符变更
内嵌
并排对比
正在显示
3 个修改的文件
包含
337 行增加
和
0 行删除
python-api-examples/online-websocket-client-decode-file.py
python-api-examples/online-websocket-client-microphone.py
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
python-api-examples/online-websocket-client-decode-file.py
0 → 100755
查看文件 @
e4b79ad
#!/usr/bin/env python3
#
# Copyright (c) 2023 Xiaomi Corporation
"""
A websocket client for sherpa-onnx-online-websocket-server
Usage:
./online-websocket-client-decode-file.py
\
--server-addr localhost
\
--server-port 6006
\
--seconds-per-message 0.1
\
--samples-per-message 8000
\
/path/to/foo.wav
(Note: You have to first start the server before starting the client)
You can find the server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
Note: The server is implemented in C++.
There is also a C++ version of the client. Please see
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
"""
import
argparse
import
asyncio
import
logging
import
time
import
wave
try
:
import
websockets
except
ImportError
:
print
(
"please run:"
)
print
(
""
)
print
(
" pip install websockets"
)
print
(
""
)
print
(
"before you run this script"
)
print
(
""
)
import
numpy
as
np
def
read_wave
(
wave_filename
:
str
)
->
np
.
ndarray
:
"""
Args:
wave_filename:
Path to a wave file. Its sampling rate has to be 16000.
It should be single channel and each sample should be 16-bit.
Returns:
Return a 1-D float32 tensor.
"""
with
wave
.
open
(
wave_filename
)
as
f
:
assert
f
.
getframerate
()
==
16000
,
f
.
getframerate
()
assert
f
.
getnchannels
()
==
1
,
f
.
getnchannels
()
assert
f
.
getsampwidth
()
==
2
,
f
.
getsampwidth
()
# it is in bytes
num_samples
=
f
.
getnframes
()
samples
=
f
.
readframes
(
num_samples
)
samples_int16
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
samples_float32
=
samples_int16
.
astype
(
np
.
float32
)
samples_float32
=
samples_float32
/
32768
return
samples_float32
def
get_args
():
parser
=
argparse
.
ArgumentParser
(
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
parser
.
add_argument
(
"--server-addr"
,
type
=
str
,
default
=
"localhost"
,
help
=
"Address of the server"
,
)
parser
.
add_argument
(
"--server-port"
,
type
=
int
,
default
=
6006
,
help
=
"Port of the server"
,
)
parser
.
add_argument
(
"--samples-per-message"
,
type
=
int
,
default
=
8000
,
help
=
"Number of samples per message"
,
)
parser
.
add_argument
(
"--seconds-per-message"
,
type
=
float
,
default
=
0.1
,
help
=
"We will simulate that the duration of two messages is of this value"
,
)
parser
.
add_argument
(
"sound_file"
,
type
=
str
,
help
=
"The input sound file. Must be wave with a single channel, 16kHz "
"sampling rate, 16-bit of each sample."
,
)
return
parser
.
parse_args
()
async
def
receive_results
(
socket
:
websockets
.
WebSocketServerProtocol
):
last_message
=
""
async
for
message
in
socket
:
if
message
!=
"Done!"
:
last_message
=
message
logging
.
info
(
message
)
else
:
return
last_message
async
def
run
(
server_addr
:
str
,
server_port
:
int
,
wave_filename
:
str
,
samples_per_message
:
int
,
seconds_per_message
:
float
,
):
data
=
read_wave
(
wave_filename
)
async
with
websockets
.
connect
(
f
"ws://{server_addr}:{server_port}"
)
as
websocket
:
# noqa
logging
.
info
(
f
"Sending {wave_filename}"
)
receive_task
=
asyncio
.
create_task
(
receive_results
(
websocket
))
start
=
0
while
start
<
data
.
shape
[
0
]:
end
=
start
+
samples_per_message
end
=
min
(
end
,
data
.
shape
[
0
])
d
=
data
.
data
[
start
:
end
]
.
tobytes
()
await
websocket
.
send
(
d
)
await
asyncio
.
sleep
(
seconds_per_message
)
# in seconds
start
+=
samples_per_message
# to signal that the client has sent all the data
await
websocket
.
send
(
"Done"
)
decoding_results
=
await
receive_task
logging
.
info
(
f
"
\n
Final result is:
\n
{decoding_results}"
)
async
def
main
():
args
=
get_args
()
logging
.
info
(
vars
(
args
))
server_addr
=
args
.
server_addr
server_port
=
args
.
server_port
samples_per_message
=
args
.
samples_per_message
seconds_per_message
=
args
.
seconds_per_message
await
run
(
server_addr
=
server_addr
,
server_port
=
server_port
,
wave_filename
=
args
.
sound_file
,
samples_per_message
=
samples_per_message
,
seconds_per_message
=
seconds_per_message
,
)
if
__name__
==
"__main__"
:
formatter
=
(
"
%(asctime)
s
%(levelname)
s [
%(filename)
s:
%(lineno)
d]
%(message)
s"
# noqa
)
logging
.
basicConfig
(
format
=
formatter
,
level
=
logging
.
INFO
)
asyncio
.
run
(
main
())
...
...
python-api-examples/online-websocket-client-microphone.py
0 → 100755
查看文件 @
e4b79ad
#!/usr/bin/env python3
#
# Copyright (c) 2023 Xiaomi Corporation
"""
A websocket client for sherpa-onnx-online-websocket-server
Usage:
./online-websocket-client-microphone.py
\
--server-addr localhost
\
--server-port 6006
(Note: You have to first start the server before starting the client)
You can find the server at
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc
Note: The server is implemented in C++.
There is also a C++ version of the client. Please see
https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc
"""
import
argparse
import
asyncio
import
time
import
numpy
as
np
try
:
import
sounddevice
as
sd
except
ImportError
as
e
:
print
(
"Please install sounddevice first. You can use"
)
print
()
print
(
" pip install sounddevice"
)
print
()
print
(
"to install it"
)
sys
.
exit
(
-
1
)
try
:
import
websockets
except
ImportError
:
print
(
"please run:"
)
print
(
""
)
print
(
" pip install websockets"
)
print
(
""
)
print
(
"before you run this script"
)
print
(
""
)
sys
.
exit
(
-
1
)
def
get_args
():
parser
=
argparse
.
ArgumentParser
(
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
parser
.
add_argument
(
"--server-addr"
,
type
=
str
,
default
=
"localhost"
,
help
=
"Address of the server"
,
)
parser
.
add_argument
(
"--server-port"
,
type
=
int
,
default
=
6006
,
help
=
"Port of the server"
,
)
return
parser
.
parse_args
()
async
def
inputstream_generator
(
channels
=
1
):
"""Generator that yields blocks of input data as NumPy arrays.
See https://python-sounddevice.readthedocs.io/en/0.4.6/examples.html#creating-an-asyncio-generator-for-audio-blocks
"""
q_in
=
asyncio
.
Queue
()
loop
=
asyncio
.
get_event_loop
()
def
callback
(
indata
,
frame_count
,
time_info
,
status
):
loop
.
call_soon_threadsafe
(
q_in
.
put_nowait
,
(
indata
.
copy
(),
status
))
devices
=
sd
.
query_devices
()
print
(
devices
)
default_input_device_idx
=
sd
.
default
.
device
[
0
]
print
(
f
'Use default device: {devices[default_input_device_idx]["name"]}'
)
print
()
print
(
"Started! Please speak"
)
stream
=
sd
.
InputStream
(
callback
=
callback
,
channels
=
channels
,
dtype
=
"float32"
,
samplerate
=
16000
,
blocksize
=
int
(
0.05
*
16000
),
# 0.05 seconds
)
with
stream
:
while
True
:
indata
,
status
=
await
q_in
.
get
()
yield
indata
,
status
async
def
receive_results
(
socket
:
websockets
.
WebSocketServerProtocol
):
last_message
=
""
async
for
message
in
socket
:
if
message
!=
"Done!"
:
if
last_message
!=
message
:
last_message
=
message
if
last_message
:
print
(
last_message
)
else
:
return
last_message
async
def
run
(
server_addr
:
str
,
server_port
:
int
,
):
async
with
websockets
.
connect
(
f
"ws://{server_addr}:{server_port}"
)
as
websocket
:
# noqa
receive_task
=
asyncio
.
create_task
(
receive_results
(
websocket
))
print
(
"Started! Please Speak"
)
async
for
indata
,
status
in
inputstream_generator
():
if
status
:
print
(
status
)
indata
=
indata
.
reshape
(
-
1
)
indata
=
np
.
ascontiguousarray
(
indata
)
await
websocket
.
send
(
indata
.
tobytes
())
decoding_results
=
await
receive_task
print
(
"
\n
Final result is:
\n
{decoding_results}"
)
async
def
main
():
args
=
get_args
()
print
(
vars
(
args
))
server_addr
=
args
.
server_addr
server_port
=
args
.
server_port
await
run
(
server_addr
=
server_addr
,
server_port
=
server_port
,
)
if
__name__
==
"__main__"
:
try
:
asyncio
.
run
(
main
())
except
KeyboardInterrupt
:
print
(
"
\n
Caught Ctrl + C. Exiting"
)
...
...
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
查看文件 @
e4b79ad
# Copyright (c) 2023 Xiaomi Corporation
from
pathlib
import
Path
from
typing
import
List
...
...
请
注册
或
登录
后发表评论