正在显示
11 个修改的文件
包含
470 行增加
和
30 行删除
| @@ -8,15 +8,20 @@ log() { | @@ -8,15 +8,20 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +mkdir -p /tmp/icefall-models | ||
| 12 | +dir=/tmp/icefall-models | ||
| 11 | 13 | ||
| 14 | +log "Test streaming transducer models" | ||
| 15 | + | ||
| 16 | +pushd $dir | ||
| 12 | repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 | 17 | repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 |
| 13 | 18 | ||
| 14 | log "Start testing ${repo_url}" | 19 | log "Start testing ${repo_url}" |
| 15 | -repo=$(basename $repo_url) | 20 | +repo=$dir/$(basename $repo_url) |
| 16 | log "Download pretrained model and test-data from $repo_url" | 21 | log "Download pretrained model and test-data from $repo_url" |
| 17 | 22 | ||
| 18 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 23 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 19 | -pushd $repo | 24 | +cd $repo |
| 20 | git lfs pull --include "*.onnx" | 25 | git lfs pull --include "*.onnx" |
| 21 | popd | 26 | popd |
| 22 | 27 | ||
| @@ -38,4 +43,88 @@ python3 ./python-api-examples/online-decode-files.py \ | @@ -38,4 +43,88 @@ python3 ./python-api-examples/online-decode-files.py \ | ||
| 38 | $repo/test_wavs/0.wav \ | 43 | $repo/test_wavs/0.wav \ |
| 39 | $repo/test_wavs/1.wav \ | 44 | $repo/test_wavs/1.wav \ |
| 40 | $repo/test_wavs/2.wav \ | 45 | $repo/test_wavs/2.wav \ |
| 41 | - $repo/test_wavs/3.wav | 46 | + $repo/test_wavs/3.wav \ |
| 47 | + $repo/test_wavs/8k.wav | ||
| 48 | + | ||
| 49 | +python3 ./python-api-examples/online-decode-files.py \ | ||
| 50 | + --tokens=$repo/tokens.txt \ | ||
| 51 | + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 52 | + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 53 | + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 54 | + $repo/test_wavs/0.wav \ | ||
| 55 | + $repo/test_wavs/1.wav \ | ||
| 56 | + $repo/test_wavs/2.wav \ | ||
| 57 | + $repo/test_wavs/3.wav \ | ||
| 58 | + $repo/test_wavs/8k.wav | ||
| 59 | + | ||
| 60 | +python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose | ||
| 61 | + | ||
| 62 | +log "Test non-streaming transducer models" | ||
| 63 | + | ||
| 64 | +pushd $dir | ||
| 65 | +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-en-2023-04-01 | ||
| 66 | + | ||
| 67 | +log "Start testing ${repo_url}" | ||
| 68 | +repo=$dir/$(basename $repo_url) | ||
| 69 | +log "Download pretrained model and test-data from $repo_url" | ||
| 70 | + | ||
| 71 | +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 72 | +cd $repo | ||
| 73 | +git lfs pull --include "*.onnx" | ||
| 74 | +popd | ||
| 75 | + | ||
| 76 | +ls -lh $repo | ||
| 77 | + | ||
| 78 | +python3 ./python-api-examples/offline-decode-files.py \ | ||
| 79 | + --tokens=$repo/tokens.txt \ | ||
| 80 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 81 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 82 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 83 | + $repo/test_wavs/0.wav \ | ||
| 84 | + $repo/test_wavs/1.wav \ | ||
| 85 | + $repo/test_wavs/8k.wav | ||
| 86 | + | ||
| 87 | +python3 ./python-api-examples/offline-decode-files.py \ | ||
| 88 | + --tokens=$repo/tokens.txt \ | ||
| 89 | + --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ | ||
| 90 | + --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ | ||
| 91 | + --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ | ||
| 92 | + $repo/test_wavs/0.wav \ | ||
| 93 | + $repo/test_wavs/1.wav \ | ||
| 94 | + $repo/test_wavs/8k.wav | ||
| 95 | + | ||
| 96 | +python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose | ||
| 97 | + | ||
| 98 | +log "Test non-streaming paraformer models" | ||
| 99 | + | ||
| 100 | +pushd $dir | ||
| 101 | +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 | ||
| 102 | + | ||
| 103 | +log "Start testing ${repo_url}" | ||
| 104 | +repo=$dir/$(basename $repo_url) | ||
| 105 | +log "Download pretrained model and test-data from $repo_url" | ||
| 106 | + | ||
| 107 | +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 108 | +cd $repo | ||
| 109 | +git lfs pull --include "*.onnx" | ||
| 110 | +popd | ||
| 111 | + | ||
| 112 | +ls -lh $repo | ||
| 113 | + | ||
| 114 | +python3 ./python-api-examples/offline-decode-files.py \ | ||
| 115 | + --tokens=$repo/tokens.txt \ | ||
| 116 | + --paraformer=$repo/model.onnx \ | ||
| 117 | + $repo/test_wavs/0.wav \ | ||
| 118 | + $repo/test_wavs/1.wav \ | ||
| 119 | + $repo/test_wavs/2.wav \ | ||
| 120 | + $repo/test_wavs/8k.wav | ||
| 121 | + | ||
| 122 | +python3 ./python-api-examples/offline-decode-files.py \ | ||
| 123 | + --tokens=$repo/tokens.txt \ | ||
| 124 | + --paraformer=$repo/model.int8.onnx \ | ||
| 125 | + $repo/test_wavs/0.wav \ | ||
| 126 | + $repo/test_wavs/1.wav \ | ||
| 127 | + $repo/test_wavs/2.wav \ | ||
| 128 | + $repo/test_wavs/8k.wav | ||
| 129 | + | ||
| 130 | +python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose |
python-api-examples/offline-decode-files.py
100644 → 100755
| @@ -46,6 +46,7 @@ from typing import Tuple | @@ -46,6 +46,7 @@ from typing import Tuple | ||
| 46 | import numpy as np | 46 | import numpy as np |
| 47 | import sherpa_onnx | 47 | import sherpa_onnx |
| 48 | 48 | ||
| 49 | + | ||
| 49 | def get_args(): | 50 | def get_args(): |
| 50 | parser = argparse.ArgumentParser( | 51 | parser = argparse.ArgumentParser( |
| 51 | formatter_class=argparse.ArgumentDefaultsHelpFormatter | 52 | formatter_class=argparse.ArgumentDefaultsHelpFormatter |
| @@ -165,6 +166,7 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | @@ -165,6 +166,7 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 165 | samples_float32 = samples_float32 / 32768 | 166 | samples_float32 = samples_float32 / 32768 |
| 166 | return samples_float32, f.getframerate() | 167 | return samples_float32, f.getframerate() |
| 167 | 168 | ||
| 169 | + | ||
| 168 | def main(): | 170 | def main(): |
| 169 | args = get_args() | 171 | args = get_args() |
| 170 | assert_file_exists(args.tokens) | 172 | assert_file_exists(args.tokens) |
| @@ -183,7 +185,7 @@ def main(): | @@ -183,7 +185,7 @@ def main(): | ||
| 183 | sample_rate=args.sample_rate, | 185 | sample_rate=args.sample_rate, |
| 184 | feature_dim=args.feature_dim, | 186 | feature_dim=args.feature_dim, |
| 185 | decoding_method=args.decoding_method, | 187 | decoding_method=args.decoding_method, |
| 186 | - debug=args.debug | 188 | + debug=args.debug, |
| 187 | ) | 189 | ) |
| 188 | else: | 190 | else: |
| 189 | assert_file_exists(args.paraformer) | 191 | assert_file_exists(args.paraformer) |
| @@ -194,10 +196,9 @@ def main(): | @@ -194,10 +196,9 @@ def main(): | ||
| 194 | sample_rate=args.sample_rate, | 196 | sample_rate=args.sample_rate, |
| 195 | feature_dim=args.feature_dim, | 197 | feature_dim=args.feature_dim, |
| 196 | decoding_method=args.decoding_method, | 198 | decoding_method=args.decoding_method, |
| 197 | - debug=args.debug | 199 | + debug=args.debug, |
| 198 | ) | 200 | ) |
| 199 | 201 | ||
| 200 | - | ||
| 201 | print("Started!") | 202 | print("Started!") |
| 202 | start_time = time.time() | 203 | start_time = time.time() |
| 203 | 204 | ||
| @@ -212,12 +213,8 @@ def main(): | @@ -212,12 +213,8 @@ def main(): | ||
| 212 | s = recognizer.create_stream() | 213 | s = recognizer.create_stream() |
| 213 | s.accept_waveform(sample_rate, samples) | 214 | s.accept_waveform(sample_rate, samples) |
| 214 | 215 | ||
| 215 | - tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) | ||
| 216 | - s.accept_waveform(sample_rate, tail_paddings) | ||
| 217 | - | ||
| 218 | streams.append(s) | 216 | streams.append(s) |
| 219 | 217 | ||
| 220 | - | ||
| 221 | recognizer.decode_streams(streams) | 218 | recognizer.decode_streams(streams) |
| 222 | results = [s.result.text for s in streams] | 219 | results = [s.result.text for s in streams] |
| 223 | end_time = time.time() | 220 | end_time = time.time() |
| @@ -18,8 +18,8 @@ namespace sherpa_onnx { | @@ -18,8 +18,8 @@ namespace sherpa_onnx { | ||
| 18 | 18 | ||
| 19 | void FeatureExtractorConfig::Register(ParseOptions *po) { | 19 | void FeatureExtractorConfig::Register(ParseOptions *po) { |
| 20 | po->Register("sample-rate", &sampling_rate, | 20 | po->Register("sample-rate", &sampling_rate, |
| 21 | - "Sampling rate of the input waveform. Must match the one " | ||
| 22 | - "expected by the model. Note: You can have a different " | 21 | + "Sampling rate of the input waveform. " |
| 22 | + "Note: You can have a different " | ||
| 23 | "sample rate for the input waveform. We will do resampling " | 23 | "sample rate for the input waveform. We will do resampling " |
| 24 | "inside the feature extractor"); | 24 | "inside the feature extractor"); |
| 25 | 25 |
| @@ -17,8 +17,8 @@ namespace sherpa_onnx { | @@ -17,8 +17,8 @@ namespace sherpa_onnx { | ||
| 17 | 17 | ||
| 18 | void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { | 18 | void OfflineFeatureExtractorConfig::Register(ParseOptions *po) { |
| 19 | po->Register("sample-rate", &sampling_rate, | 19 | po->Register("sample-rate", &sampling_rate, |
| 20 | - "Sampling rate of the input waveform. Must match the one " | ||
| 21 | - "expected by the model. Note: You can have a different " | 20 | + "Sampling rate of the input waveform. " |
| 21 | + "Note: You can have a different " | ||
| 22 | "sample rate for the input waveform. We will do resampling " | 22 | "sample rate for the input waveform. We will do resampling " |
| 23 | "inside the feature extractor"); | 23 | "inside the feature extractor"); |
| 24 | 24 |
| @@ -65,6 +65,7 @@ int32_t main(int32_t argc, char *argv[]) { | @@ -65,6 +65,7 @@ int32_t main(int32_t argc, char *argv[]) { | ||
| 65 | po.Register("port", &port, "The port on which the server will listen."); | 65 | po.Register("port", &port, "The port on which the server will listen."); |
| 66 | 66 | ||
| 67 | config.Register(&po); | 67 | config.Register(&po); |
| 68 | + po.DisableOption("sample-rate"); | ||
| 68 | 69 | ||
| 69 | if (argc == 1) { | 70 | if (argc == 1) { |
| 70 | po.PrintUsage(); | 71 | po.PrintUsage(); |
| @@ -18,7 +18,12 @@ def _assert_file_exists(f: str): | @@ -18,7 +18,12 @@ def _assert_file_exists(f: str): | ||
| 18 | 18 | ||
| 19 | 19 | ||
| 20 | class OfflineRecognizer(object): | 20 | class OfflineRecognizer(object): |
| 21 | - """A class for offline speech recognition.""" | 21 | + """A class for offline speech recognition. |
| 22 | + | ||
| 23 | + Please refer to the following files for usages | ||
| 24 | + - https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_offline_recognizer.py | ||
| 25 | + - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/offline-decode-files.py | ||
| 26 | + """ | ||
| 22 | 27 | ||
| 23 | @classmethod | 28 | @classmethod |
| 24 | def from_transducer( | 29 | def from_transducer( |
| @@ -59,7 +64,7 @@ class OfflineRecognizer(object): | @@ -59,7 +64,7 @@ class OfflineRecognizer(object): | ||
| 59 | feature_dim: | 64 | feature_dim: |
| 60 | Dimension of the feature used to train the model. | 65 | Dimension of the feature used to train the model. |
| 61 | decoding_method: | 66 | decoding_method: |
| 62 | - Valid values are greedy_search, modified_beam_search. | 67 | + Support only greedy_search for now. |
| 63 | debug: | 68 | debug: |
| 64 | True to show debug messages. | 69 | True to show debug messages. |
| 65 | """ | 70 | """ |
| @@ -68,14 +73,12 @@ class OfflineRecognizer(object): | @@ -68,14 +73,12 @@ class OfflineRecognizer(object): | ||
| 68 | transducer=OfflineTransducerModelConfig( | 73 | transducer=OfflineTransducerModelConfig( |
| 69 | encoder_filename=encoder, | 74 | encoder_filename=encoder, |
| 70 | decoder_filename=decoder, | 75 | decoder_filename=decoder, |
| 71 | - joiner_filename=joiner | ||
| 72 | - ), | ||
| 73 | - paraformer=OfflineParaformerModelConfig( | ||
| 74 | - model="" | 76 | + joiner_filename=joiner, |
| 75 | ), | 77 | ), |
| 78 | + paraformer=OfflineParaformerModelConfig(model=""), | ||
| 76 | tokens=tokens, | 79 | tokens=tokens, |
| 77 | num_threads=num_threads, | 80 | num_threads=num_threads, |
| 78 | - debug=debug | 81 | + debug=debug, |
| 79 | ) | 82 | ) |
| 80 | 83 | ||
| 81 | feat_config = OfflineFeatureExtractorConfig( | 84 | feat_config = OfflineFeatureExtractorConfig( |
| @@ -131,16 +134,12 @@ class OfflineRecognizer(object): | @@ -131,16 +134,12 @@ class OfflineRecognizer(object): | ||
| 131 | self = cls.__new__(cls) | 134 | self = cls.__new__(cls) |
| 132 | model_config = OfflineModelConfig( | 135 | model_config = OfflineModelConfig( |
| 133 | transducer=OfflineTransducerModelConfig( | 136 | transducer=OfflineTransducerModelConfig( |
| 134 | - encoder_filename="", | ||
| 135 | - decoder_filename="", | ||
| 136 | - joiner_filename="" | ||
| 137 | - ), | ||
| 138 | - paraformer=OfflineParaformerModelConfig( | ||
| 139 | - model=paraformer | 137 | + encoder_filename="", decoder_filename="", joiner_filename="" |
| 140 | ), | 138 | ), |
| 139 | + paraformer=OfflineParaformerModelConfig(model=paraformer), | ||
| 141 | tokens=tokens, | 140 | tokens=tokens, |
| 142 | num_threads=num_threads, | 141 | num_threads=num_threads, |
| 143 | - debug=debug | 142 | + debug=debug, |
| 144 | ) | 143 | ) |
| 145 | 144 | ||
| 146 | feat_config = OfflineFeatureExtractorConfig( | 145 | feat_config = OfflineFeatureExtractorConfig( |
| @@ -164,4 +163,3 @@ class OfflineRecognizer(object): | @@ -164,4 +163,3 @@ class OfflineRecognizer(object): | ||
| 164 | 163 | ||
| 165 | def decode_streams(self, ss: List[OfflineStream]): | 164 | def decode_streams(self, ss: List[OfflineStream]): |
| 166 | self.recognizer.decode_streams(ss) | 165 | self.recognizer.decode_streams(ss) |
| 167 | - |
| @@ -17,7 +17,12 @@ def _assert_file_exists(f: str): | @@ -17,7 +17,12 @@ def _assert_file_exists(f: str): | ||
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | class OnlineRecognizer(object): | 19 | class OnlineRecognizer(object): |
| 20 | - """A class for streaming speech recognition.""" | 20 | + """A class for streaming speech recognition. |
| 21 | + | ||
| 22 | + Please refer to the following files for usages | ||
| 23 | + - https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_online_recognizer.py | ||
| 24 | + - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py | ||
| 25 | + """ | ||
| 21 | 26 | ||
| 22 | def __init__( | 27 | def __init__( |
| 23 | self, | 28 | self, |
| @@ -18,6 +18,8 @@ endfunction() | @@ -18,6 +18,8 @@ endfunction() | ||
| 18 | # please sort the files in alphabetic order | 18 | # please sort the files in alphabetic order |
| 19 | set(py_test_files | 19 | set(py_test_files |
| 20 | test_feature_extractor_config.py | 20 | test_feature_extractor_config.py |
| 21 | + test_offline_recognizer.py | ||
| 22 | + test_online_recognizer.py | ||
| 21 | test_online_transducer_model_config.py | 23 | test_online_transducer_model_config.py |
| 22 | ) | 24 | ) |
| 23 | 25 |
| 1 | +# sherpa-onnx/python/tests/test_offline_recognizer.py | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +# | ||
| 5 | +# To run this single test, use | ||
| 6 | +# | ||
| 7 | +# ctest --verbose -R test_offline_recognizer_py | ||
| 8 | + | ||
| 9 | +import unittest | ||
| 10 | +import wave | ||
| 11 | +from pathlib import Path | ||
| 12 | +from typing import Tuple | ||
| 13 | + | ||
| 14 | +import numpy as np | ||
| 15 | +import sherpa_onnx | ||
| 16 | + | ||
| 17 | +d = "/tmp/icefall-models" | ||
| 18 | +# Please refer to | ||
| 19 | +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html | ||
| 20 | +# and | ||
| 21 | +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html | ||
| 22 | +# to download pre-trained models for testing | ||
| 23 | + | ||
| 24 | + | ||
| 25 | +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 26 | + """ | ||
| 27 | + Args: | ||
| 28 | + wave_filename: | ||
| 29 | + Path to a wave file. It should be single channel and each sample should | ||
| 30 | + be 16-bit. Its sample rate does not need to be 16kHz. | ||
| 31 | + Returns: | ||
| 32 | + Return a tuple containing: | ||
| 33 | + - A 1-D array of dtype np.float32 containing the samples, which are | ||
| 34 | + normalized to the range [-1, 1]. | ||
| 35 | + - sample rate of the wave file | ||
| 36 | + """ | ||
| 37 | + | ||
| 38 | + with wave.open(wave_filename) as f: | ||
| 39 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 40 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 41 | + num_samples = f.getnframes() | ||
| 42 | + samples = f.readframes(num_samples) | ||
| 43 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 44 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 45 | + | ||
| 46 | + samples_float32 = samples_float32 / 32768 | ||
| 47 | + return samples_float32, f.getframerate() | ||
| 48 | + | ||
| 49 | + | ||
| 50 | +class TestOfflineRecognizer(unittest.TestCase): | ||
| 51 | + def test_transducer_single_file(self): | ||
| 52 | + for use_int8 in [True, False]: | ||
| 53 | + if use_int8: | ||
| 54 | + encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx" | ||
| 55 | + decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.int8.onnx" | ||
| 56 | + joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx" | ||
| 57 | + else: | ||
| 58 | + encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx" | ||
| 59 | + decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx" | ||
| 60 | + joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx" | ||
| 61 | + | ||
| 62 | + tokens = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/tokens.txt" | ||
| 63 | + wave0 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/0.wav" | ||
| 64 | + | ||
| 65 | + if not Path(encoder).is_file(): | ||
| 66 | + print("skipping test_transducer_single_file()") | ||
| 67 | + return | ||
| 68 | + | ||
| 69 | + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( | ||
| 70 | + encoder=encoder, | ||
| 71 | + decoder=decoder, | ||
| 72 | + joiner=joiner, | ||
| 73 | + tokens=tokens, | ||
| 74 | + num_threads=1, | ||
| 75 | + ) | ||
| 76 | + | ||
| 77 | + s = recognizer.create_stream() | ||
| 78 | + samples, sample_rate = read_wave(wave0) | ||
| 79 | + s.accept_waveform(sample_rate, samples) | ||
| 80 | + recognizer.decode_stream(s) | ||
| 81 | + print(s.result.text) | ||
| 82 | + | ||
| 83 | + def test_transducer_multiple_files(self): | ||
| 84 | + for use_int8 in [True, False]: | ||
| 85 | + if use_int8: | ||
| 86 | + encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx" | ||
| 87 | + decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.int8.onnx" | ||
| 88 | + joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx" | ||
| 89 | + else: | ||
| 90 | + encoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx" | ||
| 91 | + decoder = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx" | ||
| 92 | + joiner = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx" | ||
| 93 | + | ||
| 94 | + tokens = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/tokens.txt" | ||
| 95 | + wave0 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/0.wav" | ||
| 96 | + wave1 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/1.wav" | ||
| 97 | + wave2 = f"{d}/sherpa-onnx-zipformer-en-2023-04-01/test_wavs/8k.wav" | ||
| 98 | + | ||
| 99 | + if not Path(encoder).is_file(): | ||
| 100 | + print("skipping test_transducer_multiple_files()") | ||
| 101 | + return | ||
| 102 | + | ||
| 103 | + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer( | ||
| 104 | + encoder=encoder, | ||
| 105 | + decoder=decoder, | ||
| 106 | + joiner=joiner, | ||
| 107 | + tokens=tokens, | ||
| 108 | + num_threads=1, | ||
| 109 | + ) | ||
| 110 | + | ||
| 111 | + s0 = recognizer.create_stream() | ||
| 112 | + samples0, sample_rate0 = read_wave(wave0) | ||
| 113 | + s0.accept_waveform(sample_rate0, samples0) | ||
| 114 | + | ||
| 115 | + s1 = recognizer.create_stream() | ||
| 116 | + samples1, sample_rate1 = read_wave(wave1) | ||
| 117 | + s1.accept_waveform(sample_rate1, samples1) | ||
| 118 | + | ||
| 119 | + s2 = recognizer.create_stream() | ||
| 120 | + samples2, sample_rate2 = read_wave(wave2) | ||
| 121 | + s2.accept_waveform(sample_rate2, samples2) | ||
| 122 | + | ||
| 123 | + recognizer.decode_streams([s0, s1, s2]) | ||
| 124 | + print(s0.result.text) | ||
| 125 | + print(s1.result.text) | ||
| 126 | + print(s2.result.text) | ||
| 127 | + | ||
| 128 | + def test_paraformer_single_file(self): | ||
| 129 | + for use_int8 in [True, False]: | ||
| 130 | + if use_int8: | ||
| 131 | + model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx" | ||
| 132 | + else: | ||
| 133 | + model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx" | ||
| 134 | + | ||
| 135 | + tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt" | ||
| 136 | + wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav" | ||
| 137 | + | ||
| 138 | + if not Path(model).is_file(): | ||
| 139 | + print("skipping test_paraformer_single_file()") | ||
| 140 | + return | ||
| 141 | + | ||
| 142 | + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( | ||
| 143 | + paraformer=model, | ||
| 144 | + tokens=tokens, | ||
| 145 | + num_threads=1, | ||
| 146 | + ) | ||
| 147 | + | ||
| 148 | + s = recognizer.create_stream() | ||
| 149 | + samples, sample_rate = read_wave(wave0) | ||
| 150 | + s.accept_waveform(sample_rate, samples) | ||
| 151 | + recognizer.decode_stream(s) | ||
| 152 | + print(s.result.text) | ||
| 153 | + | ||
| 154 | + def test_paraformer_multiple_files(self): | ||
| 155 | + for use_int8 in [True, False]: | ||
| 156 | + if use_int8: | ||
| 157 | + model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx" | ||
| 158 | + else: | ||
| 159 | + model = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/model.onnx" | ||
| 160 | + | ||
| 161 | + tokens = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt" | ||
| 162 | + wave0 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/0.wav" | ||
| 163 | + wave1 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/1.wav" | ||
| 164 | + wave2 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/2.wav" | ||
| 165 | + wave3 = f"{d}/sherpa-onnx-paraformer-zh-2023-03-28/test_wavs/8k.wav" | ||
| 166 | + | ||
| 167 | + if not Path(model).is_file(): | ||
| 168 | + print("skipping test_paraformer_multiple_files()") | ||
| 169 | + return | ||
| 170 | + | ||
| 171 | + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( | ||
| 172 | + paraformer=model, | ||
| 173 | + tokens=tokens, | ||
| 174 | + num_threads=1, | ||
| 175 | + ) | ||
| 176 | + | ||
| 177 | + s0 = recognizer.create_stream() | ||
| 178 | + samples0, sample_rate0 = read_wave(wave0) | ||
| 179 | + s0.accept_waveform(sample_rate0, samples0) | ||
| 180 | + | ||
| 181 | + s1 = recognizer.create_stream() | ||
| 182 | + samples1, sample_rate1 = read_wave(wave1) | ||
| 183 | + s1.accept_waveform(sample_rate1, samples1) | ||
| 184 | + | ||
| 185 | + s2 = recognizer.create_stream() | ||
| 186 | + samples2, sample_rate2 = read_wave(wave2) | ||
| 187 | + s2.accept_waveform(sample_rate2, samples2) | ||
| 188 | + | ||
| 189 | + s3 = recognizer.create_stream() | ||
| 190 | + samples3, sample_rate3 = read_wave(wave3) | ||
| 191 | + s3.accept_waveform(sample_rate3, samples3) | ||
| 192 | + | ||
| 193 | + recognizer.decode_streams([s0, s1, s2, s3]) | ||
| 194 | + print(s0.result.text) | ||
| 195 | + print(s1.result.text) | ||
| 196 | + print(s2.result.text) | ||
| 197 | + print(s3.result.text) | ||
| 198 | + | ||
| 199 | + | ||
| 200 | +if __name__ == "__main__": | ||
| 201 | + unittest.main() |
| 1 | +# sherpa-onnx/python/tests/test_online_recognizer.py | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +# | ||
| 5 | +# To run this single test, use | ||
| 6 | +# | ||
| 7 | +# ctest --verbose -R test_online_recognizer_py | ||
| 8 | + | ||
| 9 | +import unittest | ||
| 10 | +import wave | ||
| 11 | +from pathlib import Path | ||
| 12 | +from typing import Tuple | ||
| 13 | + | ||
| 14 | +import numpy as np | ||
| 15 | +import sherpa_onnx | ||
| 16 | + | ||
| 17 | +d = "/tmp/icefall-models" | ||
| 18 | +# Please refer to | ||
| 19 | +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html | ||
| 20 | +# to download pre-trained models for testing | ||
| 21 | + | ||
| 22 | + | ||
| 23 | +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 24 | + """ | ||
| 25 | + Args: | ||
| 26 | + wave_filename: | ||
| 27 | + Path to a wave file. It should be single channel and each sample should | ||
| 28 | + be 16-bit. Its sample rate does not need to be 16kHz. | ||
| 29 | + Returns: | ||
| 30 | + Return a tuple containing: | ||
| 31 | + - A 1-D array of dtype np.float32 containing the samples, which are | ||
| 32 | + normalized to the range [-1, 1]. | ||
| 33 | + - sample rate of the wave file | ||
| 34 | + """ | ||
| 35 | + | ||
| 36 | + with wave.open(wave_filename) as f: | ||
| 37 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 38 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 39 | + num_samples = f.getnframes() | ||
| 40 | + samples = f.readframes(num_samples) | ||
| 41 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 42 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 43 | + | ||
| 44 | + samples_float32 = samples_float32 / 32768 | ||
| 45 | + return samples_float32, f.getframerate() | ||
| 46 | + | ||
| 47 | + | ||
| 48 | +class TestOnlineRecognizer(unittest.TestCase): | ||
| 49 | + def test_transducer_single_file(self): | ||
| 50 | + for use_int8 in [True, False]: | ||
| 51 | + if use_int8: | ||
| 52 | + encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx" | ||
| 53 | + decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.int8.onnx" | ||
| 54 | + joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx" | ||
| 55 | + else: | ||
| 56 | + encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" | ||
| 57 | + decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" | ||
| 58 | + joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx" | ||
| 59 | + | ||
| 60 | + tokens = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" | ||
| 61 | + wave0 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav" | ||
| 62 | + | ||
| 63 | + if not Path(encoder).is_file(): | ||
| 64 | + print("skipping test_transducer_single_file()") | ||
| 65 | + return | ||
| 66 | + | ||
| 67 | + for decoding_method in ["greedy_search", "modified_beam_search"]: | ||
| 68 | + recognizer = sherpa_onnx.OnlineRecognizer( | ||
| 69 | + encoder=encoder, | ||
| 70 | + decoder=decoder, | ||
| 71 | + joiner=joiner, | ||
| 72 | + tokens=tokens, | ||
| 73 | + num_threads=1, | ||
| 74 | + decoding_method=decoding_method, | ||
| 75 | + ) | ||
| 76 | + s = recognizer.create_stream() | ||
| 77 | + samples, sample_rate = read_wave(wave0) | ||
| 78 | + s.accept_waveform(sample_rate, samples) | ||
| 79 | + | ||
| 80 | + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) | ||
| 81 | + s.accept_waveform(sample_rate, tail_paddings) | ||
| 82 | + | ||
| 83 | + s.input_finished() | ||
| 84 | + while recognizer.is_ready(s): | ||
| 85 | + recognizer.decode_stream(s) | ||
| 86 | + print(recognizer.get_result(s)) | ||
| 87 | + | ||
| 88 | + def test_transducer_multiple_files(self): | ||
| 89 | + for use_int8 in [True, False]: | ||
| 90 | + if use_int8: | ||
| 91 | + encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx" | ||
| 92 | + decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.int8.onnx" | ||
| 93 | + joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx" | ||
| 94 | + else: | ||
| 95 | + encoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" | ||
| 96 | + decoder = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" | ||
| 97 | + joiner = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx" | ||
| 98 | + | ||
| 99 | + tokens = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" | ||
| 100 | + wave0 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav" | ||
| 101 | + wave1 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav" | ||
| 102 | + wave2 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/2.wav" | ||
| 103 | + wave3 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/3.wav" | ||
| 104 | + wave4 = f"{d}/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/8k.wav" | ||
| 105 | + | ||
| 106 | + if not Path(encoder).is_file(): | ||
| 107 | + print("skipping test_transducer_multiple_files()") | ||
| 108 | + return | ||
| 109 | + | ||
| 110 | + for decoding_method in ["greedy_search", "modified_beam_search"]: | ||
| 111 | + recognizer = sherpa_onnx.OnlineRecognizer( | ||
| 112 | + encoder=encoder, | ||
| 113 | + decoder=decoder, | ||
| 114 | + joiner=joiner, | ||
| 115 | + tokens=tokens, | ||
| 116 | + num_threads=1, | ||
| 117 | + decoding_method=decoding_method, | ||
| 118 | + ) | ||
| 119 | + streams = [] | ||
| 120 | + waves = [wave0, wave1, wave2, wave3, wave4] | ||
| 121 | + for wave in waves: | ||
| 122 | + s = recognizer.create_stream() | ||
| 123 | + samples, sample_rate = read_wave(wave) | ||
| 124 | + s.accept_waveform(sample_rate, samples) | ||
| 125 | + | ||
| 126 | + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) | ||
| 127 | + s.accept_waveform(sample_rate, tail_paddings) | ||
| 128 | + s.input_finished() | ||
| 129 | + streams.append(s) | ||
| 130 | + | ||
| 131 | + while True: | ||
| 132 | + ready_list = [] | ||
| 133 | + for s in streams: | ||
| 134 | + if recognizer.is_ready(s): | ||
| 135 | + ready_list.append(s) | ||
| 136 | + if len(ready_list) == 0: | ||
| 137 | + break | ||
| 138 | + recognizer.decode_streams(ready_list) | ||
| 139 | + results = [recognizer.get_result(s) for s in streams] | ||
| 140 | + for wave_filename, result in zip(waves, results): | ||
| 141 | + print(f"{wave_filename}\n{result}") | ||
| 142 | + print("-" * 10) | ||
| 143 | + | ||
| 144 | + | ||
| 145 | +if __name__ == "__main__": | ||
| 146 | + unittest.main() |
-
请 注册 或 登录 后发表评论