Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-04-02 23:05:30 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-04-02 23:05:30 +0800
Commit
5d3c8edbc91f7451d34f8e503c2d14c8886f4e5d
5d3c8edb
1 parent
3f7e0c23
add python tests (#111)
隐藏空白字符变更
内嵌
并排对比
正在显示
11 个修改的文件
包含
488 行增加
和
48 行删除
.github/scripts/test-python.sh
.gitignore
python-api-examples/offline-decode-files.py
sherpa-onnx/csrc/features.cc
sherpa-onnx/csrc/offline-stream.cc
sherpa-onnx/csrc/offline-websocket-server.cc
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
sherpa-onnx/python/tests/CMakeLists.txt
sherpa-onnx/python/tests/test_offline_recognizer.py
sherpa-onnx/python/tests/test_online_recognizer.py
.github/scripts/test-python.sh
查看文件 @
5d3c8ed
...
...
@@ -8,15 +8,20 @@ log() {
echo
-e
"
$(
date
'+%Y-%m-%d %H:%M:%S'
)
(
${
fname
}
:
${
BASH_LINENO
[0]
}
:
${
FUNCNAME
[1]
}
)
$*
"
}
mkdir -p /tmp/icefall-models
dir
=
/tmp/icefall-models
log
"Test streaming transducer models"
pushd
$dir
repo_url
=
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
log
"Start testing
${
repo_url
}
"
repo
=
$(
basename
$repo_url
)
repo
=
$
dir
/
$
(
basename
$repo_url
)
log
"Download pretrained model and test-data from
$repo_url
"
GIT_LFS_SKIP_SMUDGE
=
1 git clone
$repo_url
push
d
$repo
c
d
$repo
git lfs pull --include
"*.onnx"
popd
...
...
@@ -38,4 +43,88 @@ python3 ./python-api-examples/online-decode-files.py \
$repo
/test_wavs/0.wav
\
$repo
/test_wavs/1.wav
\
$repo
/test_wavs/2.wav
\
$repo
/test_wavs/3.wav
$repo
/test_wavs/3.wav
\
$repo
/test_wavs/8k.wav
python3 ./python-api-examples/online-decode-files.py
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder-epoch-99-avg-1.int8.onnx
\
--decoder
=
$repo
/decoder-epoch-99-avg-1.int8.onnx
\
--joiner
=
$repo
/joiner-epoch-99-avg-1.int8.onnx
\
$repo
/test_wavs/0.wav
\
$repo
/test_wavs/1.wav
\
$repo
/test_wavs/2.wav
\
$repo
/test_wavs/3.wav
\
$repo
/test_wavs/8k.wav
python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose
log
"Test non-streaming transducer models"
pushd
$dir
repo_url
=
https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-04-01
log
"Start testing
${
repo_url
}
"
repo
=
$dir
/
$(
basename
$repo_url
)
log
"Download pretrained model and test-data from
$repo_url
"
GIT_LFS_SKIP_SMUDGE
=
1 git clone
$repo_url
cd
$repo
git lfs pull --include
"*.onnx"
popd
ls -lh
$repo
python3 ./python-api-examples/offline-decode-files.py
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder-epoch-99-avg-1.onnx
\
--decoder
=
$repo
/decoder-epoch-99-avg-1.onnx
\
--joiner
=
$repo
/joiner-epoch-99-avg-1.onnx
\
$repo
/test_wavs/0.wav
\
$repo
/test_wavs/1.wav
\
$repo
/test_wavs/8k.wav
python3 ./python-api-examples/offline-decode-files.py
\
--tokens
=
$repo
/tokens.txt
\
--encoder
=
$repo
/encoder-epoch-99-avg-1.int8.onnx
\
--decoder
=
$repo
/decoder-epoch-99-avg-1.int8.onnx
\
--joiner
=
$repo
/joiner-epoch-99-avg-1.int8.onnx
\
$repo
/test_wavs/0.wav
\
$repo
/test_wavs/1.wav
\
$repo
/test_wavs/8k.wav
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
log
"Test non-streaming paraformer models"
pushd
$dir
repo_url
=
https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
log
"Start testing
${
repo_url
}
"
repo
=
$dir
/
$(
basename
$repo_url
)
log
"Download pretrained model and test-data from
$repo_url
"
GIT_LFS_SKIP_SMUDGE
=
1 git clone
$repo_url
cd
$repo
git lfs pull --include
"*.onnx"
popd
ls -lh
$repo
python3 ./python-api-examples/offline-decode-files.py
\
--tokens
=
$repo
/tokens.txt
\
--paraformer
=
$repo
/model.onnx
\
$repo
/test_wavs/0.wav
\
$repo
/test_wavs/1.wav
\
$repo
/test_wavs/2.wav
\
$repo
/test_wavs/8k.wav
python3 ./python-api-examples/offline-decode-files.py
\
--tokens
=
$repo
/tokens.txt
\
--paraformer
=
$repo
/model.int8.onnx
\
$repo
/test_wavs/0.wav
\
$repo
/test_wavs/1.wav
\
$repo
/test_wavs/2.wav
\
$repo
/test_wavs/8k.wav
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
...
...
.gitignore
查看文件 @
5d3c8ed
...
...
@@ -51,3 +51,4 @@ a.sh
run-offline-websocket-client-*.sh
run-sherpa-onnx-*.sh
sherpa-onnx-zipformer-en-2023-03-30
sherpa-onnx-zipformer-en-2023-04-01
...
...
python-api-examples/offline-decode-files.py
100644 → 100755
查看文件 @
5d3c8ed
...
...
@@ -46,6 +46,7 @@ from typing import Tuple
import
numpy
as
np
import
sherpa_onnx
def
get_args
():
parser
=
argparse
.
ArgumentParser
(
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
...
...
@@ -165,6 +166,7 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
samples_float32
=
samples_float32
/
32768
return
samples_float32
,
f
.
getframerate
()
def
main
():
args
=
get_args
()
assert_file_exists
(
args
.
tokens
)
...
...
@@ -183,7 +185,7 @@ def main():
sample_rate
=
args
.
sample_rate
,
feature_dim
=
args
.
feature_dim
,
decoding_method
=
args
.
decoding_method
,
debug
=
args
.
debug
debug
=
args
.
debug
,
)
else
:
assert_file_exists
(
args
.
paraformer
)
...
...
@@ -194,10 +196,9 @@ def main():
sample_rate
=
args
.
sample_rate
,
feature_dim
=
args
.
feature_dim
,
decoding_method
=
args
.
decoding_method
,
debug
=
args
.
debug
debug
=
args
.
debug
,
)
print
(
"Started!"
)
start_time
=
time
.
time
()
...
...
@@ -212,12 +213,8 @@ def main():
s
=
recognizer
.
create_stream
()
s
.
accept_waveform
(
sample_rate
,
samples
)
tail_paddings
=
np
.
zeros
(
int
(
0.2
*
sample_rate
),
dtype
=
np
.
float32
)
s
.
accept_waveform
(
sample_rate
,
tail_paddings
)
streams
.
append
(
s
)
recognizer
.
decode_streams
(
streams
)
results
=
[
s
.
result
.
text
for
s
in
streams
]
end_time
=
time
.
time
()
...
...
sherpa-onnx/csrc/features.cc
查看文件 @
5d3c8ed
...
...
@@ -18,8 +18,8 @@ namespace sherpa_onnx {
void
FeatureExtractorConfig
::
Register
(
ParseOptions
*
po
)
{
po
->
Register
(
"sample-rate"
,
&
sampling_rate
,
"Sampling rate of the input waveform. Must match the one "
"expected by the model. Note: You can have a different "
"Sampling rate of the input waveform. "
"Note: You can have a different "
"sample rate for the input waveform. We will do resampling "
"inside the feature extractor"
);
...
...
sherpa-onnx/csrc/offline-stream.cc
查看文件 @
5d3c8ed
...
...
@@ -17,8 +17,8 @@ namespace sherpa_onnx {
void
OfflineFeatureExtractorConfig
::
Register
(
ParseOptions
*
po
)
{
po
->
Register
(
"sample-rate"
,
&
sampling_rate
,
"Sampling rate of the input waveform. Must match the one "
"expected by the model. Note: You can have a different "
"Sampling rate of the input waveform. "
"Note: You can have a different "
"sample rate for the input waveform. We will do resampling "
"inside the feature extractor"
);
...
...
sherpa-onnx/csrc/offline-websocket-server.cc
查看文件 @
5d3c8ed
...
...
@@ -65,6 +65,7 @@ int32_t main(int32_t argc, char *argv[]) {
po
.
Register
(
"port"
,
&
port
,
"The port on which the server will listen."
);
config
.
Register
(
&
po
);
po
.
DisableOption
(
"sample-rate"
);
if
(
argc
==
1
)
{
po
.
PrintUsage
();
...
...
sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
查看文件 @
5d3c8ed
...
...
@@ -18,20 +18,25 @@ def _assert_file_exists(f: str):
class
OfflineRecognizer
(
object
):
"""A class for offline speech recognition."""
"""A class for offline speech recognition.
Please refer to the following files for usages
- https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_offline_recognizer.py
- https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/offline-decode-files.py
"""
@classmethod
def
from_transducer
(
cls
,
encoder
:
str
,
decoder
:
str
,
joiner
:
str
,
tokens
:
str
,
num_threads
:
int
,
sample_rate
:
int
=
16000
,
feature_dim
:
int
=
80
,
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
cls
,
encoder
:
str
,
decoder
:
str
,
joiner
:
str
,
tokens
:
str
,
num_threads
:
int
,
sample_rate
:
int
=
16000
,
feature_dim
:
int
=
80
,
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
):
"""
Please refer to
...
...
@@ -59,7 +64,7 @@ class OfflineRecognizer(object):
feature_dim:
Dimension of the feature used to train the model.
decoding_method:
Valid values are greedy_search, modified_beam_search
.
Support only greedy_search for now
.
debug:
True to show debug messages.
"""
...
...
@@ -68,14 +73,12 @@ class OfflineRecognizer(object):
transducer
=
OfflineTransducerModelConfig
(
encoder_filename
=
encoder
,
decoder_filename
=
decoder
,
joiner_filename
=
joiner
),
paraformer
=
OfflineParaformerModelConfig
(
model
=
""
joiner_filename
=
joiner
,
),
paraformer
=
OfflineParaformerModelConfig
(
model
=
""
),
tokens
=
tokens
,
num_threads
=
num_threads
,
debug
=
debug
debug
=
debug
,
)
feat_config
=
OfflineFeatureExtractorConfig
(
...
...
@@ -93,14 +96,14 @@ class OfflineRecognizer(object):
@classmethod
def
from_paraformer
(
cls
,
paraformer
:
str
,
tokens
:
str
,
num_threads
:
int
,
sample_rate
:
int
=
16000
,
feature_dim
:
int
=
80
,
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
cls
,
paraformer
:
str
,
tokens
:
str
,
num_threads
:
int
,
sample_rate
:
int
=
16000
,
feature_dim
:
int
=
80
,
decoding_method
:
str
=
"greedy_search"
,
debug
:
bool
=
False
,
):
"""
Please refer to
...
...
@@ -131,16 +134,12 @@ class OfflineRecognizer(object):
self
=
cls
.
__new__
(
cls
)
model_config
=
OfflineModelConfig
(
transducer
=
OfflineTransducerModelConfig
(
encoder_filename
=
""
,
decoder_filename
=
""
,
joiner_filename
=
""
),
paraformer
=
OfflineParaformerModelConfig
(
model
=
paraformer
encoder_filename
=
""
,
decoder_filename
=
""
,
joiner_filename
=
""
),
paraformer
=
OfflineParaformerModelConfig
(
model
=
paraformer
),
tokens
=
tokens
,
num_threads
=
num_threads
,
debug
=
debug
debug
=
debug
,
)
feat_config
=
OfflineFeatureExtractorConfig
(
...
...
@@ -164,4 +163,3 @@ class OfflineRecognizer(object):
def
decode_streams
(
self
,
ss
:
List
[
OfflineStream
]):
self
.
recognizer
.
decode_streams
(
ss
)
...
...
sherpa-onnx/python/sherpa_onnx/online_recognizer.py
查看文件 @
5d3c8ed
...
...
@@ -17,7 +17,12 @@ def _assert_file_exists(f: str):
class
OnlineRecognizer
(
object
):
"""A class for streaming speech recognition."""
"""A class for streaming speech recognition.
Please refer to the following files for usages
- https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_online_recognizer.py
- https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py
"""
def
__init__
(
self
,
...
...
sherpa-onnx/python/tests/CMakeLists.txt
查看文件 @
5d3c8ed
...
...
@@ -18,6 +18,8 @@ endfunction()
# please sort the files in alphabetic order
set
(
py_test_files
test_feature_extractor_config.py
test_offline_recognizer.py
test_online_recognizer.py
test_online_transducer_model_config.py
)
...
...
sherpa-onnx/python/tests/test_offline_recognizer.py
0 → 100755
查看文件 @
5d3c8ed
# sherpa-onnx/python/tests/test_offline_recognizer.py
#
# Copyright (c) 2023 Xiaomi Corporation
#
# To run this single test, use
#
# ctest --verbose -R test_offline_recognizer_py
import
unittest
import
wave
from
pathlib
import
Path
from
typing
import
Tuple
import
numpy
as
np
import
sherpa_onnx
d
=
"/tmp/icefall-models"
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
# and
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
# to download pre-trained models for testing
def
read_wave
(
wave_filename
:
str
)
->
Tuple
[
np
.
ndarray
,
int
]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with
wave
.
open
(
wave_filename
)
as
f
:
assert
f
.
getnchannels
()
==
1
,
f
.
getnchannels
()
assert
f
.
getsampwidth
()
==
2
,
f
.
getsampwidth
()
# it is in bytes
num_samples
=
f
.
getnframes
()
samples
=
f
.
readframes
(
num_samples
)
samples_int16
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
samples_float32
=
samples_int16
.
astype
(
np
.
float32
)
samples_float32
=
samples_float32
/
32768
return
samples_float32
,
f
.
getframerate
()
class
TestOfflineRecognizer
(
unittest
.
TestCase
):
def
test_transducer_single_file
(
self
):
for
use_int8
in
[
True
,
False
]:
if
use_int8
:
encoder
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx"
decoder
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.int8.onnx"
joiner
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx"
else
:
encoder
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx"
decoder
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx"
joiner
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx"
tokens
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/tokens.txt"
wave0
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/0.wav"
if
not
Path
(
encoder
)
.
is_file
():
print
(
"skipping test_transducer_single_file()"
)
return
recognizer
=
sherpa_onnx
.
OfflineRecognizer
.
from_transducer
(
encoder
=
encoder
,
decoder
=
decoder
,
joiner
=
joiner
,
tokens
=
tokens
,
num_threads
=
1
,
)
s
=
recognizer
.
create_stream
()
samples
,
sample_rate
=
read_wave
(
wave0
)
s
.
accept_waveform
(
sample_rate
,
samples
)
recognizer
.
decode_stream
(
s
)
print
(
s
.
result
.
text
)
def
test_transducer_multiple_files
(
self
):
for
use_int8
in
[
True
,
False
]:
if
use_int8
:
encoder
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx"
decoder
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.int8.onnx"
joiner
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx"
else
:
encoder
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx"
decoder
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx"
joiner
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx"
tokens
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/tokens.txt"
wave0
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/0.wav"
wave1
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/1.wav"
wave2
=
f
"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/8k.wav"
if
not
Path
(
encoder
)
.
is_file
():
print
(
"skipping test_transducer_multiple_files()"
)
return
recognizer
=
sherpa_onnx
.
OfflineRecognizer
.
from_transducer
(
encoder
=
encoder
,
decoder
=
decoder
,
joiner
=
joiner
,
tokens
=
tokens
,
num_threads
=
1
,
)
s0
=
recognizer
.
create_stream
()
samples0
,
sample_rate0
=
read_wave
(
wave0
)
s0
.
accept_waveform
(
sample_rate0
,
samples0
)
s1
=
recognizer
.
create_stream
()
samples1
,
sample_rate1
=
read_wave
(
wave1
)
s1
.
accept_waveform
(
sample_rate1
,
samples1
)
s2
=
recognizer
.
create_stream
()
samples2
,
sample_rate2
=
read_wave
(
wave2
)
s2
.
accept_waveform
(
sample_rate2
,
samples2
)
recognizer
.
decode_streams
([
s0
,
s1
,
s2
])
print
(
s0
.
result
.
text
)
print
(
s1
.
result
.
text
)
print
(
s2
.
result
.
text
)
def
test_paraformer_single_file
(
self
):
for
use_int8
in
[
True
,
False
]:
if
use_int8
:
model
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx"
else
:
model
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx"
tokens
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt"
wave0
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav"
if
not
Path
(
model
)
.
is_file
():
print
(
"skipping test_paraformer_single_file()"
)
return
recognizer
=
sherpa_onnx
.
OfflineRecognizer
.
from_paraformer
(
paraformer
=
model
,
tokens
=
tokens
,
num_threads
=
1
,
)
s
=
recognizer
.
create_stream
()
samples
,
sample_rate
=
read_wave
(
wave0
)
s
.
accept_waveform
(
sample_rate
,
samples
)
recognizer
.
decode_stream
(
s
)
print
(
s
.
result
.
text
)
def
test_paraformer_multiple_files
(
self
):
for
use_int8
in
[
True
,
False
]:
if
use_int8
:
model
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx"
else
:
model
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx"
tokens
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt"
wave0
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav"
wave1
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav"
wave2
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav"
wave3
=
f
"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav"
if
not
Path
(
model
)
.
is_file
():
print
(
"skipping test_paraformer_multiple_files()"
)
return
recognizer
=
sherpa_onnx
.
OfflineRecognizer
.
from_paraformer
(
paraformer
=
model
,
tokens
=
tokens
,
num_threads
=
1
,
)
s0
=
recognizer
.
create_stream
()
samples0
,
sample_rate0
=
read_wave
(
wave0
)
s0
.
accept_waveform
(
sample_rate0
,
samples0
)
s1
=
recognizer
.
create_stream
()
samples1
,
sample_rate1
=
read_wave
(
wave1
)
s1
.
accept_waveform
(
sample_rate1
,
samples1
)
s2
=
recognizer
.
create_stream
()
samples2
,
sample_rate2
=
read_wave
(
wave2
)
s2
.
accept_waveform
(
sample_rate2
,
samples2
)
s3
=
recognizer
.
create_stream
()
samples3
,
sample_rate3
=
read_wave
(
wave3
)
s3
.
accept_waveform
(
sample_rate3
,
samples3
)
recognizer
.
decode_streams
([
s0
,
s1
,
s2
,
s3
])
print
(
s0
.
result
.
text
)
print
(
s1
.
result
.
text
)
print
(
s2
.
result
.
text
)
print
(
s3
.
result
.
text
)
if
__name__
==
"__main__"
:
unittest
.
main
()
...
...
sherpa-onnx/python/tests/test_online_recognizer.py
0 → 100755
查看文件 @
5d3c8ed
# sherpa-onnx/python/tests/test_online_recognizer.py
#
# Copyright (c) 2023 Xiaomi Corporation
#
# To run this single test, use
#
# ctest --verbose -R test_online_recognizer_py
import
unittest
import
wave
from
pathlib
import
Path
from
typing
import
Tuple
import
numpy
as
np
import
sherpa_onnx
d
=
"/tmp/icefall-models"
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
# to download pre-trained models for testing
def
read_wave
(
wave_filename
:
str
)
->
Tuple
[
np
.
ndarray
,
int
]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with
wave
.
open
(
wave_filename
)
as
f
:
assert
f
.
getnchannels
()
==
1
,
f
.
getnchannels
()
assert
f
.
getsampwidth
()
==
2
,
f
.
getsampwidth
()
# it is in bytes
num_samples
=
f
.
getnframes
()
samples
=
f
.
readframes
(
num_samples
)
samples_int16
=
np
.
frombuffer
(
samples
,
dtype
=
np
.
int16
)
samples_float32
=
samples_int16
.
astype
(
np
.
float32
)
samples_float32
=
samples_float32
/
32768
return
samples_float32
,
f
.
getframerate
()
class
TestOnlineRecognizer
(
unittest
.
TestCase
):
def
test_transducer_single_file
(
self
):
for
use_int8
in
[
True
,
False
]:
if
use_int8
:
encoder
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx"
decoder
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.int8.onnx"
joiner
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx"
else
:
encoder
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx"
decoder
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"
joiner
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx"
tokens
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"
wave0
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav"
if
not
Path
(
encoder
)
.
is_file
():
print
(
"skipping test_transducer_single_file()"
)
return
for
decoding_method
in
[
"greedy_search"
,
"modified_beam_search"
]:
recognizer
=
sherpa_onnx
.
OnlineRecognizer
(
encoder
=
encoder
,
decoder
=
decoder
,
joiner
=
joiner
,
tokens
=
tokens
,
num_threads
=
1
,
decoding_method
=
decoding_method
,
)
s
=
recognizer
.
create_stream
()
samples
,
sample_rate
=
read_wave
(
wave0
)
s
.
accept_waveform
(
sample_rate
,
samples
)
tail_paddings
=
np
.
zeros
(
int
(
0.2
*
sample_rate
),
dtype
=
np
.
float32
)
s
.
accept_waveform
(
sample_rate
,
tail_paddings
)
s
.
input_finished
()
while
recognizer
.
is_ready
(
s
):
recognizer
.
decode_stream
(
s
)
print
(
recognizer
.
get_result
(
s
))
def
test_transducer_multiple_files
(
self
):
for
use_int8
in
[
True
,
False
]:
if
use_int8
:
encoder
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx"
decoder
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.int8.onnx"
joiner
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx"
else
:
encoder
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx"
decoder
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"
joiner
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx"
tokens
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"
wave0
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav"
wave1
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav"
wave2
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/2.wav"
wave3
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/3.wav"
wave4
=
f
"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/8k.wav"
if
not
Path
(
encoder
)
.
is_file
():
print
(
"skipping test_transducer_multiple_files()"
)
return
for
decoding_method
in
[
"greedy_search"
,
"modified_beam_search"
]:
recognizer
=
sherpa_onnx
.
OnlineRecognizer
(
encoder
=
encoder
,
decoder
=
decoder
,
joiner
=
joiner
,
tokens
=
tokens
,
num_threads
=
1
,
decoding_method
=
decoding_method
,
)
streams
=
[]
waves
=
[
wave0
,
wave1
,
wave2
,
wave3
,
wave4
]
for
wave
in
waves
:
s
=
recognizer
.
create_stream
()
samples
,
sample_rate
=
read_wave
(
wave
)
s
.
accept_waveform
(
sample_rate
,
samples
)
tail_paddings
=
np
.
zeros
(
int
(
0.2
*
sample_rate
),
dtype
=
np
.
float32
)
s
.
accept_waveform
(
sample_rate
,
tail_paddings
)
s
.
input_finished
()
streams
.
append
(
s
)
while
True
:
ready_list
=
[]
for
s
in
streams
:
if
recognizer
.
is_ready
(
s
):
ready_list
.
append
(
s
)
if
len
(
ready_list
)
==
0
:
break
recognizer
.
decode_streams
(
ready_list
)
results
=
[
recognizer
.
get_result
(
s
)
for
s
in
streams
]
for
wave_filename
,
result
in
zip
(
waves
,
results
):
print
(
f
"{wave_filename}
\n
{result}"
)
print
(
"-"
*
10
)
if
__name__
==
"__main__"
:
unittest
.
main
()
...
...
请
注册
或
登录
后发表评论