Fangjun Kuang
Committed by GitHub

add python tests (#111)

@@ -8,15 +8,20 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+mkdir -p /tmp/icefall-models
+dir=/tmp/icefall-models
 
+log "Test streaming transducer models"
+
+pushd $dir
 repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
 
 log "Start testing ${repo_url}"
-repo=$(basename $repo_url)
+repo=$dir/$(basename $repo_url)
 log "Download pretrained model and test-data from $repo_url"
 
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-pushd $repo
+cd $repo
 git lfs pull --include "*.onnx"
 popd
 
@@ -38,4 +43,88 @@ python3 ./python-api-examples/online-decode-files.py \
   $repo/test_wavs/0.wav \
   $repo/test_wavs/1.wav \
   $repo/test_wavs/2.wav \
-  $repo/test_wavs/3.wav
+  $repo/test_wavs/3.wav \
+  $repo/test_wavs/8k.wav
+
+python3 ./python-api-examples/online-decode-files.py \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
+  --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
+  --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/3.wav \
+  $repo/test_wavs/8k.wav
+
+python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose
+
+log "Test non-streaming transducer models"
+
+pushd $dir
+repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-04-01
+
+log "Start testing ${repo_url}"
+repo=$dir/$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+cd $repo
+git lfs pull --include "*.onnx"
+popd
+
+ls -lh $repo
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder-epoch-99-avg-1.onnx \
+  --decoder=$repo/decoder-epoch-99-avg-1.onnx \
+  --joiner=$repo/joiner-epoch-99-avg-1.onnx \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/8k.wav
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --tokens=$repo/tokens.txt \
+  --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
+  --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
+  --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/8k.wav
+
+python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
+
+log "Test non-streaming paraformer models"
+
+pushd $dir
+repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
+
+log "Start testing ${repo_url}"
+repo=$dir/$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+cd $repo
+git lfs pull --include "*.onnx"
+popd
+
+ls -lh $repo
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --tokens=$repo/tokens.txt \
+  --paraformer=$repo/model.onnx \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/8k.wav
+
+python3 ./python-api-examples/offline-decode-files.py \
+  --tokens=$repo/tokens.txt \
+  --paraformer=$repo/model.int8.onnx \
+  $repo/test_wavs/0.wav \
+  $repo/test_wavs/1.wav \
+  $repo/test_wavs/2.wav \
+  $repo/test_wavs/8k.wav
+
+python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
@@ -51,3 +51,4 @@ a.sh
 run-offline-websocket-client-*.sh
 run-sherpa-onnx-*.sh
 sherpa-onnx-zipformer-en-2023-03-30
+sherpa-onnx-zipformer-en-2023-04-01
@@ -46,6 +46,7 @@ from typing import Tuple
 import numpy as np
 import sherpa_onnx
 
+
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -165,6 +166,7 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
         samples_float32 = samples_float32 / 32768
         return samples_float32, f.getframerate()
 
+
 def main():
     args = get_args()
     assert_file_exists(args.tokens)
@@ -183,7 +185,7 @@ def main():
             sample_rate=args.sample_rate,
             feature_dim=args.feature_dim,
             decoding_method=args.decoding_method,
-            debug=args.debug
+            debug=args.debug,
         )
     else:
         assert_file_exists(args.paraformer)
@@ -194,10 +196,9 @@ def main():
             sample_rate=args.sample_rate,
             feature_dim=args.feature_dim,
             decoding_method=args.decoding_method,
-            debug=args.debug
+            debug=args.debug,
         )
 
-
     print("Started!")
     start_time = time.time()
 
@@ -212,12 +213,8 @@ def main():
         s = recognizer.create_stream()
         s.accept_waveform(sample_rate, samples)
 
-        tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
-        s.accept_waveform(sample_rate, tail_paddings)
-
         streams.append(s)
 
-
     recognizer.decode_streams(streams)
     results = [s.result.text for s in streams]
     end_time = time.time()
@@ -18,8 +18,8 @@ namespace sherpa_onnx {
 
 void FeatureExtractorConfig::Register(ParseOptions *po) {
   po->Register("sample-rate", &sampling_rate,
-               "Sampling rate of the input waveform. Must match the one "
-               "expected by the model. Note: You can have a different "
+               "Sampling rate of the input waveform. "
+               "Note: You can have a different "
                "sample rate for the input waveform. We will do resampling "
                "inside the feature extractor");
 
@@ -17,8 +17,8 @@ namespace sherpa_onnx {
 
 void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
   po->Register("sample-rate", &sampling_rate,
-               "Sampling rate of the input waveform. Must match the one "
-               "expected by the model. Note: You can have a different "
+               "Sampling rate of the input waveform. "
+               "Note: You can have a different "
                "sample rate for the input waveform. We will do resampling "
                "inside the feature extractor");
 
@@ -65,6 +65,7 @@ int32_t main(int32_t argc, char *argv[]) {
   po.Register("port", &port, "The port on which the server will listen.");
 
   config.Register(&po);
+  po.DisableOption("sample-rate");
 
   if (argc == 1) {
     po.PrintUsage();
@@ -18,20 +18,25 @@ def _assert_file_exists(f: str):
 
 
 class OfflineRecognizer(object):
-    """A class for offline speech recognition."""
+    """A class for offline speech recognition.
+
+    Please refer to the following files for usages
+     - https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_offline_recognizer.py
+     - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/offline-decode-files.py
+    """
 
     @classmethod
     def from_transducer(
-        cls,
-        encoder: str,
-        decoder: str,
-        joiner: str,
-        tokens: str,
-        num_threads: int,
-        sample_rate: int = 16000,
-        feature_dim: int = 80,
-        decoding_method: str = "greedy_search",
-        debug: bool = False,
+        cls,
+        encoder: str,
+        decoder: str,
+        joiner: str,
+        tokens: str,
+        num_threads: int,
+        sample_rate: int = 16000,
+        feature_dim: int = 80,
+        decoding_method: str = "greedy_search",
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -59,7 +64,7 @@ class OfflineRecognizer(object):
           feature_dim:
             Dimension of the feature used to train the model.
           decoding_method:
-            Valid values are greedy_search, modified_beam_search.
+            Support only greedy_search for now.
           debug:
             True to show debug messages.
         """
@@ -68,14 +73,12 @@ class OfflineRecognizer(object):
             transducer=OfflineTransducerModelConfig(
                 encoder_filename=encoder,
                 decoder_filename=decoder,
-                joiner_filename=joiner
-            ),
-            paraformer=OfflineParaformerModelConfig(
-                model=""
+                joiner_filename=joiner,
             ),
+            paraformer=OfflineParaformerModelConfig(model=""),
             tokens=tokens,
             num_threads=num_threads,
-            debug=debug
+            debug=debug,
         )
 
         feat_config = OfflineFeatureExtractorConfig(
@@ -93,14 +96,14 @@ class OfflineRecognizer(object):
 
     @classmethod
     def from_paraformer(
-        cls,
-        paraformer: str,
-        tokens: str,
-        num_threads: int,
-        sample_rate: int = 16000,
-        feature_dim: int = 80,
-        decoding_method: str = "greedy_search",
-        debug: bool = False,
+        cls,
+        paraformer: str,
+        tokens: str,
+        num_threads: int,
+        sample_rate: int = 16000,
+        feature_dim: int = 80,
+        decoding_method: str = "greedy_search",
+        debug: bool = False,
     ):
         """
         Please refer to
@@ -131,16 +134,12 @@ class OfflineRecognizer(object):
         self = cls.__new__(cls)
         model_config = OfflineModelConfig(
             transducer=OfflineTransducerModelConfig(
-                encoder_filename="",
-                decoder_filename="",
-                joiner_filename=""
-            ),
-            paraformer=OfflineParaformerModelConfig(
-                model=paraformer
+                encoder_filename="", decoder_filename="", joiner_filename=""
             ),
+            paraformer=OfflineParaformerModelConfig(model=paraformer),
             tokens=tokens,
             num_threads=num_threads,
-            debug=debug
+            debug=debug,
         )
 
         feat_config = OfflineFeatureExtractorConfig(
@@ -164,4 +163,3 @@ class OfflineRecognizer(object):
 
     def decode_streams(self, ss: List[OfflineStream]):
         self.recognizer.decode_streams(ss)
-
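
As a quick usage sketch of the non-streaming API documented above (distilled from the new test_offline_recognizer.py, and assuming the sherpa-onnx-zipformer-en-2023-04-01 model has already been downloaded to /tmp/icefall-models as done by the CI script):

# Minimal sketch, not part of the commit; paths below are the assumed download location.
import wave

import numpy as np
import sherpa_onnx

d = "/tmp/icefall-models/sherpa-onnx-zipformer-en-2023-04-01"  # assumed location

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder=f"{d}/encoder-epoch-99-avg-1.onnx",
    decoder=f"{d}/decoder-epoch-99-avg-1.onnx",
    joiner=f"{d}/joiner-epoch-99-avg-1.onnx",
    tokens=f"{d}/tokens.txt",
    num_threads=1,
)

with wave.open(f"{d}/test_wavs/0.wav") as f:
    # 16-bit mono PCM -> float32 in [-1, 1]; resampling is done inside the feature extractor
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768
    sample_rate = f.getframerate()

s = recognizer.create_stream()  # one OfflineStream per utterance
s.accept_waveform(sample_rate, samples)
recognizer.decode_stream(s)
print(s.result.text)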
@@ -17,7 +17,12 @@ def _assert_file_exists(f: str):
 
 
 class OnlineRecognizer(object):
-    """A class for streaming speech recognition."""
+    """A class for streaming speech recognition.
+
+    Please refer to the following files for usages
+     - https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_online_recognizer.py
+     - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py
+    """
 
     def __init__(
         self,
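
Similarly, a minimal sketch of the streaming API referenced above (from the new test_online_recognizer.py, assuming the sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 model is already under /tmp/icefall-models):

# Minimal sketch, not part of the commit; paths below are the assumed download location.
import wave

import numpy as np
import sherpa_onnx

d = "/tmp/icefall-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"

recognizer = sherpa_onnx.OnlineRecognizer(
    encoder=f"{d}/encoder-epoch-99-avg-1.onnx",
    decoder=f"{d}/decoder-epoch-99-avg-1.onnx",
    joiner=f"{d}/joiner-epoch-99-avg-1.onnx",
    tokens=f"{d}/tokens.txt",
    num_threads=1,
    decoding_method="greedy_search",
)

with wave.open(f"{d}/test_wavs/0.wav") as f:
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768
    sample_rate = f.getframerate()

s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)  # in a real app, feed chunks as they arrive

# Flush the model with a bit of trailing silence and signal end of input.
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()

while recognizer.is_ready(s):
    recognizer.decode_stream(s)
print(recognizer.get_result(s))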
@@ -18,6 +18,8 @@ endfunction()
 # please sort the files in alphabetic order
 set(py_test_files
   test_feature_extractor_config.py
+  test_offline_recognizer.py
+  test_online_recognizer.py
   test_online_transducer_model_config.py
 )
 
+# sherpa-onnx/python/tests/test_offline_recognizer.py
+#
+# Copyright (c) 2023 Xiaomi Corporation
+#
+# To run this single test, use
+#
+#  ctest --verbose -R test_offline_recognizer_py
+
+import unittest
+import wave
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import sherpa_onnx
+
+d = "/tmp/icefall-models"
+# Please refer to
+# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
+# and
+# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
+# to download pre-trained models for testing
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and each sample should
+        be 16-bit. Its sample rate does not need to be 16kHz.
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples, which are
+         normalized to the range [-1, 1].
+       - sample rate of the wave file
+    """
+
+    with wave.open(wave_filename) as f:
+        assert f.getnchannels() == 1, f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_float32 = samples_int16.astype(np.float32)
+
+        samples_float32 = samples_float32 / 32768
+        return samples_float32, f.getframerate()
+
+
+class TestOfflineRecognizer(unittest.TestCase):
+    def test_transducer_single_file(self):
+        for use_int8 in [True, False]:
+            if use_int8:
+                encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx"
+                decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.int8.onnx"
+                joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx"
+            else:
+                encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx"
+                decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx"
+                joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx"
+
+            tokens = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/tokens.txt"
+            wave0 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/0.wav"
+
+            if not Path(encoder).is_file():
+                print("skipping test_transducer_single_file()")
+                return
+
+            recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
+                encoder=encoder,
+                decoder=decoder,
+                joiner=joiner,
+                tokens=tokens,
+                num_threads=1,
+            )
+
+            s = recognizer.create_stream()
+            samples, sample_rate = read_wave(wave0)
+            s.accept_waveform(sample_rate, samples)
+            recognizer.decode_stream(s)
+            print(s.result.text)
+
+    def test_transducer_multiple_files(self):
+        for use_int8 in [True, False]:
+            if use_int8:
+                encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx"
+                decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.int8.onnx"
+                joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx"
+            else:
+                encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx"
+                decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx"
+                joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx"
+
+            tokens = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/tokens.txt"
+            wave0 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/0.wav"
+            wave1 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/1.wav"
+            wave2 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/8k.wav"
+
+            if not Path(encoder).is_file():
+                print("skipping test_transducer_multiple_files()")
+                return
+
+            recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
+                encoder=encoder,
+                decoder=decoder,
+                joiner=joiner,
+                tokens=tokens,
+                num_threads=1,
+            )
+
+            s0 = recognizer.create_stream()
+            samples0, sample_rate0 = read_wave(wave0)
+            s0.accept_waveform(sample_rate0, samples0)
+
+            s1 = recognizer.create_stream()
+            samples1, sample_rate1 = read_wave(wave1)
+            s1.accept_waveform(sample_rate1, samples1)
+
+            s2 = recognizer.create_stream()
+            samples2, sample_rate2 = read_wave(wave2)
+            s2.accept_waveform(sample_rate2, samples2)
+
+            recognizer.decode_streams([s0, s1, s2])
+            print(s0.result.text)
+            print(s1.result.text)
+            print(s2.result.text)
+
+    def test_paraformer_single_file(self):
+        for use_int8 in [True, False]:
+            if use_int8:
+                model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx"
+            else:
+                model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx"
+
+            tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt"
+            wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav"
+
+            if not Path(model).is_file():
+                print("skipping test_paraformer_single_file()")
+                return
+
+            recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
+                paraformer=model,
+                tokens=tokens,
+                num_threads=1,
+            )
+
+            s = recognizer.create_stream()
+            samples, sample_rate = read_wave(wave0)
+            s.accept_waveform(sample_rate, samples)
+            recognizer.decode_stream(s)
+            print(s.result.text)
+
+    def test_paraformer_multiple_files(self):
+        for use_int8 in [True, False]:
+            if use_int8:
+                model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx"
+            else:
+                model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx"
+
+            tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt"
+            wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav"
+            wave1 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav"
+            wave2 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav"
+            wave3 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav"
+
+            if not Path(model).is_file():
+                print("skipping test_paraformer_multiple_files()")
+                return
+
+            recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
+                paraformer=model,
+                tokens=tokens,
+                num_threads=1,
+            )
+
+            s0 = recognizer.create_stream()
+            samples0, sample_rate0 = read_wave(wave0)
+            s0.accept_waveform(sample_rate0, samples0)
+
+            s1 = recognizer.create_stream()
+            samples1, sample_rate1 = read_wave(wave1)
+            s1.accept_waveform(sample_rate1, samples1)
+
+            s2 = recognizer.create_stream()
+            samples2, sample_rate2 = read_wave(wave2)
+            s2.accept_waveform(sample_rate2, samples2)
+
+            s3 = recognizer.create_stream()
+            samples3, sample_rate3 = read_wave(wave3)
+            s3.accept_waveform(sample_rate3, samples3)
+
+            recognizer.decode_streams([s0, s1, s2, s3])
+            print(s0.result.text)
+            print(s1.result.text)
+            print(s2.result.text)
+            print(s3.result.text)
+
+
+if __name__ == "__main__":
+    unittest.main()
+# sherpa-onnx/python/tests/test_online_recognizer.py
+#
+# Copyright (c) 2023 Xiaomi Corporation
+#
+# To run this single test, use
+#
+#  ctest --verbose -R test_online_recognizer_py
+
+import unittest
+import wave
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import sherpa_onnx
+
+d = "/tmp/icefall-models"
+# Please refer to
+# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html
+# to download pre-trained models for testing
+
+
+def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+    """
+    Args:
+      wave_filename:
+        Path to a wave file. It should be single channel and each sample should
+        be 16-bit. Its sample rate does not need to be 16kHz.
+    Returns:
+      Return a tuple containing:
+       - A 1-D array of dtype np.float32 containing the samples, which are
+         normalized to the range [-1, 1].
+       - sample rate of the wave file
+    """
+
+    with wave.open(wave_filename) as f:
+        assert f.getnchannels() == 1, f.getnchannels()
+        assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+        num_samples = f.getnframes()
+        samples = f.readframes(num_samples)
+        samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_float32 = samples_int16.astype(np.float32)
+
+        samples_float32 = samples_float32 / 32768
+        return samples_float32, f.getframerate()
+
+
+class TestOnlineRecognizer(unittest.TestCase):
+    def test_transducer_single_file(self):
+        for use_int8 in [True, False]:
+            if use_int8:
+                encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx"
+                decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.int8.onnx"
+                joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx"
+            else:
+                encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx"
+                decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"
+                joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx"
+
+            tokens = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"
+            wave0 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav"
+
+            if not Path(encoder).is_file():
+                print("skipping test_transducer_single_file()")
+                return
+
+            for decoding_method in ["greedy_search", "modified_beam_search"]:
+                recognizer = sherpa_onnx.OnlineRecognizer(
+                    encoder=encoder,
+                    decoder=decoder,
+                    joiner=joiner,
+                    tokens=tokens,
+                    num_threads=1,
+                    decoding_method=decoding_method,
+                )
+                s = recognizer.create_stream()
+                samples, sample_rate = read_wave(wave0)
+                s.accept_waveform(sample_rate, samples)
+
+                tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
+                s.accept_waveform(sample_rate, tail_paddings)
+
+                s.input_finished()
+                while recognizer.is_ready(s):
+                    recognizer.decode_stream(s)
+                print(recognizer.get_result(s))
+
+    def test_transducer_multiple_files(self):
+        for use_int8 in [True, False]:
+            if use_int8:
+                encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx"
+                decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.int8.onnx"
+                joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx"
+            else:
+                encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx"
+                decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"
+                joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx"
+
+            tokens = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"
+            wave0 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav"
+            wave1 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav"
+            wave2 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/2.wav"
+            wave3 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/3.wav"
+            wave4 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/8k.wav"
+
+            if not Path(encoder).is_file():
+                print("skipping test_transducer_multiple_files()")
+                return
+
+            for decoding_method in ["greedy_search", "modified_beam_search"]:
+                recognizer = sherpa_onnx.OnlineRecognizer(
+                    encoder=encoder,
+                    decoder=decoder,
+                    joiner=joiner,
+                    tokens=tokens,
+                    num_threads=1,
+                    decoding_method=decoding_method,
+                )
+                streams = []
+                waves = [wave0, wave1, wave2, wave3, wave4]
+                for wave in waves:
+                    s = recognizer.create_stream()
+                    samples, sample_rate = read_wave(wave)
+                    s.accept_waveform(sample_rate, samples)
+
+                    tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
+                    s.accept_waveform(sample_rate, tail_paddings)
+                    s.input_finished()
+                    streams.append(s)
+
+                while True:
+                    ready_list = []
+                    for s in streams:
+                        if recognizer.is_ready(s):
+                            ready_list.append(s)
+                    if len(ready_list) == 0:
+                        break
+                    recognizer.decode_streams(ready_list)
+                results = [recognizer.get_result(s) for s in streams]
+                for wave_filename, result in zip(waves, results):
+                    print(f"{wave_filename}\n{result}")
+                    print("-" * 10)
+
+
+if __name__ == "__main__":
+    unittest.main()