Committed by GitHub

Add Python API for keyword spotting (#576)

* Add ALSA & microphone support for keyword spotting
* Add Python wrapper
Showing 15 changed files with 1,191 additions and 1 deletion.
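For readers who want the shape of the new Python API before diving into the diff, here is a minimal sketch of decoding one utterance. It only uses calls that appear in the wrapper added below (`KeywordSpotter`, `create_stream`, `accept_waveform`, `input_finished`, `is_ready`, `decode_stream`, `get_result`); the model paths are placeholders, not files shipped with this PR.

```python
# Minimal sketch of the new keyword-spotting API (model paths are placeholders).
import numpy as np
import sherpa_onnx

spotter = sherpa_onnx.KeywordSpotter(
    tokens="tokens.txt",
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    keywords_file="keywords.txt",
    num_threads=1,
)

stream = spotter.create_stream()
samples = np.zeros(16000, dtype=np.float32)  # 1 s of silence as a stand-in
stream.accept_waveform(16000, samples)       # float32 samples in [-1, 1]
stream.input_finished()

while spotter.is_ready(stream):
    spotter.decode_stream(stream)
    result = spotter.get_result(stream)      # "" until a keyword fires
    if result:
        print(f"{result} is detected")
```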
| @@ -293,3 +293,61 @@ git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data | ||
| 293 | python3 sherpa-onnx/python/tests/test_text2token.py --verbose | ||
| 294 | | ||
| 295 | rm -rf /tmp/sherpa-test-data | ||
| 296 | + | ||
| 297 | +mkdir -p /tmp/onnx-models | ||
| 298 | +dir=/tmp/onnx-models | ||
| 299 | + | ||
| 300 | +log "Test keyword spotting models" | ||
| 301 | + | ||
| 302 | +python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)" | ||
| 303 | +sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)") | ||
| 304 | + | ||
| 305 | +echo "sherpa_onnx version: $sherpa_onnx_version" | ||
| 306 | + | ||
| 307 | +pwd | ||
| 308 | +ls -lh | ||
| 309 | + | ||
| 310 | +repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 | ||
| 311 | +log "Start testing ${repo}" | ||
| 312 | + | ||
| 313 | +pushd $dir | ||
| 314 | +wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz2 | ||
| 315 | +tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz2 | ||
| 316 | +popd | ||
| 317 | + | ||
| 318 | +repo=$dir/$repo | ||
| 319 | +ls -lh $repo | ||
| 320 | + | ||
| 321 | +python3 ./python-api-examples/keyword-spotter.py \ | ||
| 322 | + --tokens=$repo/tokens.txt \ | ||
| 323 | + --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 324 | + --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 325 | + --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 326 | + --keywords-file=$repo/test_wavs/test_keywords.txt \ | ||
| 327 | + $repo/test_wavs/0.wav \ | ||
| 328 | + $repo/test_wavs/1.wav | ||
| 329 | + | ||
| 330 | +repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 | ||
| 331 | +log "Start testing ${repo}" | ||
| 332 | + | ||
| 333 | +pushd $dir | ||
| 334 | +wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 | ||
| 335 | +tar xf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 | ||
| 336 | +popd | ||
| 337 | + | ||
| 338 | +repo=$dir/$repo | ||
| 339 | +ls -lh $repo | ||
| 340 | + | ||
| 341 | +python3 ./python-api-examples/keyword-spotter.py \ | ||
| 342 | + --tokens=$repo/tokens.txt \ | ||
| 343 | + --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 344 | + --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 345 | + --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 346 | + --keywords-file=$repo/test_wavs/test_keywords.txt \ | ||
| 347 | + $repo/test_wavs/3.wav \ | ||
| 348 | + $repo/test_wavs/4.wav \ | ||
| 349 | + $repo/test_wavs/5.wav | ||
| 350 | + | ||
| 351 | +python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose | ||
| 352 | + | ||
| 353 | +rm -r $dir |
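For reference, the `--keywords-file` passed above (and consumed by the new examples that follow, starting with `python-api-examples/keyword-spotter-from-microphone.py`) holds one keyword or phrase per line, with the modeling units (BPE pieces for English, cjkchar/pinyin for Chinese) separated by spaces. The two lines below are the illustrative examples from the scripts' help text, not the actual contents of `test_keywords.txt`:

```
▁HE LL O ▁WORLD
x iǎo ài t óng x ué
```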
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +# Real-time keyword spotting from a microphone with sherpa-onnx Python API | ||
| 4 | +# | ||
| 5 | +# Please refer to | ||
| 6 | +# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html | ||
| 7 | +# to download pre-trained models | ||
| 8 | + | ||
| 9 | +import argparse | ||
| 10 | +import sys | ||
| 11 | +from pathlib import Path | ||
| 12 | + | ||
| 13 | +from typing import List | ||
| 14 | + | ||
| 15 | +try: | ||
| 16 | + import sounddevice as sd | ||
| 17 | +except ImportError: | ||
| 18 | + print("Please install sounddevice first. You can use") | ||
| 19 | + print() | ||
| 20 | + print(" pip install sounddevice") | ||
| 21 | + print() | ||
| 22 | + print("to install it") | ||
| 23 | + sys.exit(-1) | ||
| 24 | + | ||
| 25 | +import sherpa_onnx | ||
| 26 | + | ||
| 27 | + | ||
| 28 | +def assert_file_exists(filename: str): | ||
| 29 | + assert Path(filename).is_file(), ( | ||
| 30 | + f"{filename} does not exist!\n" | ||
| 31 | + "Please refer to " | ||
| 32 | + "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it" | ||
| 33 | + ) | ||
| 34 | + | ||
| 35 | + | ||
| 36 | +def get_args(): | ||
| 37 | + parser = argparse.ArgumentParser( | ||
| 38 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 39 | + ) | ||
| 40 | + | ||
| 41 | + parser.add_argument( | ||
| 42 | + "--tokens", | ||
| 43 | + type=str, | ||
| 44 | + help="Path to tokens.txt", | ||
| 45 | + ) | ||
| 46 | + | ||
| 47 | + parser.add_argument( | ||
| 48 | + "--encoder", | ||
| 49 | + type=str, | ||
| 50 | + help="Path to the transducer encoder model", | ||
| 51 | + ) | ||
| 52 | + | ||
| 53 | + parser.add_argument( | ||
| 54 | + "--decoder", | ||
| 55 | + type=str, | ||
| 56 | + help="Path to the transducer decoder model", | ||
| 57 | + ) | ||
| 58 | + | ||
| 59 | + parser.add_argument( | ||
| 60 | + "--joiner", | ||
| 61 | + type=str, | ||
| 62 | + help="Path to the transducer joiner model", | ||
| 63 | + ) | ||
| 64 | + | ||
| 65 | + parser.add_argument( | ||
| 66 | + "--num-threads", | ||
| 67 | + type=int, | ||
| 68 | + default=1, | ||
| 69 | + help="Number of threads for neural network computation", | ||
| 70 | + ) | ||
| 71 | + | ||
| 72 | + parser.add_argument( | ||
| 73 | + "--provider", | ||
| 74 | + type=str, | ||
| 75 | + default="cpu", | ||
| 76 | + help="Valid values: cpu, cuda, coreml", | ||
| 77 | + ) | ||
| 78 | + | ||
| 79 | + parser.add_argument( | ||
| 80 | + "--max-active-paths", | ||
| 81 | + type=int, | ||
| 82 | + default=4, | ||
| 83 | + help=""" | ||
| 84 | + It specifies the number of active paths to keep during decoding. | ||
| 85 | + """, | ||
| 86 | + ) | ||
| 87 | + | ||
| 88 | + parser.add_argument( | ||
| 89 | + "--num-trailing-blanks", | ||
| 90 | + type=int, | ||
| 91 | + default=1, | ||
| 92 | + help="""The number of trailing blanks a keyword should be followed by. Set | ||
| 93 | + it to a larger value (e.g., 8) when your keywords have overlapping | ||
| 94 | + tokens. | ||
| 95 | + """, | ||
| 96 | + ) | ||
| 97 | + | ||
| 98 | + parser.add_argument( | ||
| 99 | + "--keywords-file", | ||
| 100 | + type=str, | ||
| 101 | + help=""" | ||
| 102 | + The file containing keywords, one word/phrase per line, and for each | ||
| 103 | + phrase the bpe/cjkchar/pinyin units are separated by a space. For example: | ||
| 104 | + | ||
| 105 | + ▁HE LL O ▁WORLD | ||
| 106 | + x iǎo ài t óng x ué | ||
| 107 | + """, | ||
| 108 | + ) | ||
| 109 | + | ||
| 110 | + parser.add_argument( | ||
| 111 | + "--keywords-score", | ||
| 112 | + type=float, | ||
| 113 | + default=1.0, | ||
| 114 | + help=""" | ||
| 115 | + The boosting score of each token for keywords. The larger it is, the | ||
| 116 | + easier it is for a keyword to survive beam search. | ||
| 117 | + """, | ||
| 118 | + ) | ||
| 119 | + | ||
| 120 | + parser.add_argument( | ||
| 121 | + "--keywords-threshold", | ||
| 122 | + type=float, | ||
| 123 | + default=0.25, | ||
| 124 | + help=""" | ||
| 125 | + The trigger threshold (i.e., probability) of the keyword. The larger it | ||
| 126 | + is, the harder the keyword is to trigger. | ||
| 127 | + """, | ||
| 128 | + ) | ||
| 129 | + | ||
| 130 | + return parser.parse_args() | ||
| 131 | + | ||
| 132 | + | ||
| 133 | +def main(): | ||
| 134 | + args = get_args() | ||
| 135 | + | ||
| 136 | + devices = sd.query_devices() | ||
| 137 | + if len(devices) == 0: | ||
| 138 | + print("No microphone devices found") | ||
| 139 | + sys.exit(0) | ||
| 140 | + | ||
| 141 | + print(devices) | ||
| 142 | + default_input_device_idx = sd.default.device[0] | ||
| 143 | + print(f'Use default device: {devices[default_input_device_idx]["name"]}') | ||
| 144 | + | ||
| 145 | + assert_file_exists(args.tokens) | ||
| 146 | + assert_file_exists(args.encoder) | ||
| 147 | + assert_file_exists(args.decoder) | ||
| 148 | + assert_file_exists(args.joiner) | ||
| 149 | + | ||
| 150 | + assert Path( | ||
| 151 | + args.keywords_file | ||
| 152 | + ).is_file(), ( | ||
| 153 | + f"keywords_file : {args.keywords_file} not exist, please provide a valid path." | ||
| 154 | + ) | ||
| 155 | + | ||
| 156 | + keyword_spotter = sherpa_onnx.KeywordSpotter( | ||
| 157 | + tokens=args.tokens, | ||
| 158 | + encoder=args.encoder, | ||
| 159 | + decoder=args.decoder, | ||
| 160 | + joiner=args.joiner, | ||
| 161 | + num_threads=args.num_threads, | ||
| 162 | + max_active_paths=args.max_active_paths, | ||
| 163 | + keywords_file=args.keywords_file, | ||
| 164 | + keywords_score=args.keywords_score, | ||
| 165 | + keywords_threshold=args.keywords_threshold, | ||
| 166 | + num_trailing_blanks=args.num_trailing_blanks, | ||
| 167 | + provider=args.provider, | ||
| 168 | + ) | ||
| 169 | + | ||
| 170 | + print("Started! Please speak") | ||
| 171 | + | ||
| 172 | + sample_rate = 16000 | ||
| 173 | + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms | ||
| 174 | + stream = keyword_spotter.create_stream() | ||
| 175 | + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: | ||
| 176 | + while True: | ||
| 177 | + samples, _ = s.read(samples_per_read) # a blocking read | ||
| 178 | + samples = samples.reshape(-1) | ||
| 179 | + stream.accept_waveform(sample_rate, samples) | ||
| 180 | + while keyword_spotter.is_ready(stream): | ||
| 181 | + keyword_spotter.decode_stream(stream) | ||
| 182 | + result = keyword_spotter.get_result(stream) | ||
| 183 | + if result: | ||
| 184 | + print("\r{}".format(result), end="", flush=True) | ||
| 185 | + | ||
| 186 | + | ||
| 187 | +if __name__ == "__main__": | ||
| 188 | + try: | ||
| 189 | + main() | ||
| 190 | + except KeyboardInterrupt: | ||
| 191 | + print("\nCaught Ctrl + C. Exiting") |
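Note that this example always takes its keywords from `--keywords-file`, but the bindings added further down also expose a `create_stream(keywords)` overload, so keywords can be customized per stream without rebuilding the spotter. A minimal sketch follows; the exact syntax of the keywords string beyond "same units as the keywords file" is an assumption here, since it is not spelled out in this diff:

```python
# Hypothetical per-stream keyword override; `spotter` is constructed as above.
stream = spotter.create_stream("▁HE LL O ▁WORLD")  # units for one custom keyword
```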
python-api-examples/keyword-spotter.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This file demonstrates how to use sherpa-onnx Python API to do keyword spotting | ||
| 5 | +from wave file(s). | ||
| 6 | + | ||
| 7 | +Please refer to | ||
| 8 | +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html | ||
| 9 | +to download pre-trained models. | ||
| 10 | +""" | ||
| 11 | +import argparse | ||
| 12 | +import time | ||
| 13 | +import wave | ||
| 14 | +from pathlib import Path | ||
| 15 | +from typing import List, Tuple | ||
| 16 | + | ||
| 17 | +import numpy as np | ||
| 18 | +import sherpa_onnx | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +def get_args(): | ||
| 22 | + parser = argparse.ArgumentParser( | ||
| 23 | + formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 24 | + ) | ||
| 25 | + | ||
| 26 | + parser.add_argument( | ||
| 27 | + "--tokens", | ||
| 28 | + type=str, | ||
| 29 | + help="Path to tokens.txt", | ||
| 30 | + ) | ||
| 31 | + | ||
| 32 | + parser.add_argument( | ||
| 33 | + "--encoder", | ||
| 34 | + type=str, | ||
| 35 | + help="Path to the transducer encoder model", | ||
| 36 | + ) | ||
| 37 | + | ||
| 38 | + parser.add_argument( | ||
| 39 | + "--decoder", | ||
| 40 | + type=str, | ||
| 41 | + help="Path to the transducer decoder model", | ||
| 42 | + ) | ||
| 43 | + | ||
| 44 | + parser.add_argument( | ||
| 45 | + "--joiner", | ||
| 46 | + type=str, | ||
| 47 | + help="Path to the transducer joiner model", | ||
| 48 | + ) | ||
| 49 | + | ||
| 50 | + parser.add_argument( | ||
| 51 | + "--num-threads", | ||
| 52 | + type=int, | ||
| 53 | + default=1, | ||
| 54 | + help="Number of threads for neural network computation", | ||
| 55 | + ) | ||
| 56 | + | ||
| 57 | + parser.add_argument( | ||
| 58 | + "--provider", | ||
| 59 | + type=str, | ||
| 60 | + default="cpu", | ||
| 61 | + help="Valid values: cpu, cuda, coreml", | ||
| 62 | + ) | ||
| 63 | + | ||
| 64 | + parser.add_argument( | ||
| 65 | + "--max-active-paths", | ||
| 66 | + type=int, | ||
| 67 | + default=4, | ||
| 68 | + help=""" | ||
| 69 | + It specifies the number of active paths to keep during decoding. | ||
| 70 | + """, | ||
| 71 | + ) | ||
| 72 | + | ||
| 73 | + parser.add_argument( | ||
| 74 | + "--num-trailing-blanks", | ||
| 75 | + type=int, | ||
| 76 | + default=1, | ||
| 77 | + help="""The number of trailing blanks a keyword should be followed by. Set | ||
| 78 | + it to a larger value (e.g., 8) when your keywords have overlapping | ||
| 79 | + tokens. | ||
| 80 | + """, | ||
| 81 | + ) | ||
| 82 | + | ||
| 83 | + parser.add_argument( | ||
| 84 | + "--keywords-file", | ||
| 85 | + type=str, | ||
| 86 | + help=""" | ||
| 87 | + The file containing keywords, one word/phrase per line, and for each | ||
| 88 | + phrase the bpe/cjkchar/pinyin units are separated by a space. For example: | ||
| 89 | + | ||
| 90 | + ▁HE LL O ▁WORLD | ||
| 91 | + x iǎo ài t óng x ué | ||
| 92 | + """, | ||
| 93 | + ) | ||
| 94 | + | ||
| 95 | + parser.add_argument( | ||
| 96 | + "--keywords-score", | ||
| 97 | + type=float, | ||
| 98 | + default=1.0, | ||
| 99 | + help=""" | ||
| 100 | + The boosting score of each token for keywords. The larger it is, the | ||
| 101 | + easier it is for a keyword to survive beam search. | ||
| 102 | + """, | ||
| 103 | + ) | ||
| 104 | + | ||
| 105 | + parser.add_argument( | ||
| 106 | + "--keywords-threshold", | ||
| 107 | + type=float, | ||
| 108 | + default=0.25, | ||
| 109 | + help=""" | ||
| 110 | + The trigger threshold (i.e., probability) of the keyword. The larger it | ||
| 111 | + is, the harder the keyword is to trigger. | ||
| 112 | + """, | ||
| 113 | + ) | ||
| 114 | + | ||
| 115 | + parser.add_argument( | ||
| 116 | + "sound_files", | ||
| 117 | + type=str, | ||
| 118 | + nargs="+", | ||
| 119 | + help="The input sound file(s) to decode. Each file must be of WAVE" | ||
| 120 | + "format with a single channel, and each sample has 16-bit, " | ||
| 121 | + "i.e., int16_t. " | ||
| 122 | + "The sample rate of the file can be arbitrary and does not need to " | ||
| 123 | + "be 16 kHz", | ||
| 124 | + ) | ||
| 125 | + | ||
| 126 | + return parser.parse_args() | ||
| 127 | + | ||
| 128 | + | ||
| 129 | +def assert_file_exists(filename: str): | ||
| 130 | + assert Path(filename).is_file(), ( | ||
| 131 | + f"{filename} does not exist!\n" | ||
| 132 | + "Please refer to " | ||
| 133 | + "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it" | ||
| 134 | + ) | ||
| 135 | + | ||
| 136 | + | ||
| 137 | +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 138 | + """ | ||
| 139 | + Args: | ||
| 140 | + wave_filename: | ||
| 141 | + Path to a wave file. It should be single channel and each sample should | ||
| 142 | + be 16-bit. Its sample rate does not need to be 16kHz. | ||
| 143 | + Returns: | ||
| 144 | + Return a tuple containing: | ||
| 145 | + - A 1-D array of dtype np.float32 containing the samples, which are | ||
| 146 | + normalized to the range [-1, 1]. | ||
| 147 | + - sample rate of the wave file | ||
| 148 | + """ | ||
| 149 | + | ||
| 150 | + with wave.open(wave_filename) as f: | ||
| 151 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 152 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 153 | + num_samples = f.getnframes() | ||
| 154 | + samples = f.readframes(num_samples) | ||
| 155 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 156 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 157 | + | ||
| 158 | + samples_float32 = samples_float32 / 32768 | ||
| 159 | + return samples_float32, f.getframerate() | ||
| 160 | + | ||
| 161 | + | ||
| 162 | +def main(): | ||
| 163 | + args = get_args() | ||
| 164 | + assert_file_exists(args.tokens) | ||
| 165 | + assert_file_exists(args.encoder) | ||
| 166 | + assert_file_exists(args.decoder) | ||
| 167 | + assert_file_exists(args.joiner) | ||
| 168 | + | ||
| 169 | + assert Path( | ||
| 170 | + args.keywords_file | ||
| 171 | + ).is_file(), ( | ||
| 172 | + f"keywords_file : {args.keywords_file} not exist, please provide a valid path." | ||
| 173 | + ) | ||
| 174 | + | ||
| 175 | + keyword_spotter = sherpa_onnx.KeywordSpotter( | ||
| 176 | + tokens=args.tokens, | ||
| 177 | + encoder=args.encoder, | ||
| 178 | + decoder=args.decoder, | ||
| 179 | + joiner=args.joiner, | ||
| 180 | + num_threads=args.num_threads, | ||
| 181 | + max_active_paths=args.max_active_paths, | ||
| 182 | + keywords_file=args.keywords_file, | ||
| 183 | + keywords_score=args.keywords_score, | ||
| 184 | + keywords_threshold=args.keywords_threshold, | ||
| 185 | + num_trailing_blanks=args.num_trailing_blanks, | ||
| 186 | + provider=args.provider, | ||
| 187 | + ) | ||
| 188 | + | ||
| 189 | + print("Started!") | ||
| 190 | + start_time = time.time() | ||
| 191 | + | ||
| 192 | + streams = [] | ||
| 193 | + total_duration = 0 | ||
| 194 | + for wave_filename in args.sound_files: | ||
| 195 | + assert_file_exists(wave_filename) | ||
| 196 | + samples, sample_rate = read_wave(wave_filename) | ||
| 197 | + duration = len(samples) / sample_rate | ||
| 198 | + total_duration += duration | ||
| 199 | + | ||
| 200 | + s = keyword_spotter.create_stream() | ||
| 201 | + | ||
| 202 | + s.accept_waveform(sample_rate, samples) | ||
| 203 | + | ||
| 204 | + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) | ||
| 205 | + s.accept_waveform(sample_rate, tail_paddings) | ||
| 206 | + | ||
| 207 | + s.input_finished() | ||
| 208 | + | ||
| 209 | + streams.append(s) | ||
| 210 | + | ||
| 211 | + results = [""] * len(streams) | ||
| 212 | + while True: | ||
| 213 | + ready_list = [] | ||
| 214 | + for i, s in enumerate(streams): | ||
| 215 | + if keyword_spotter.is_ready(s): | ||
| 216 | + ready_list.append(s) | ||
| 217 | + r = keyword_spotter.get_result(s) | ||
| 218 | + if r: | ||
| 219 | + results[i] += f"{r}/" | ||
| 220 | + print(f"{r} is detected.") | ||
| 221 | + if len(ready_list) == 0: | ||
| 222 | + break | ||
| 223 | + keyword_spotter.decode_streams(ready_list) | ||
| 224 | + end_time = time.time() | ||
| 225 | + print("Done!") | ||
| 226 | + | ||
| 227 | + for wave_filename, result in zip(args.sound_files, results): | ||
| 228 | + print(f"{wave_filename}\n{result}") | ||
| 229 | + print("-" * 10) | ||
| 230 | + | ||
| 231 | + elapsed_seconds = end_time - start_time | ||
| 232 | + rtf = elapsed_seconds / total_duration | ||
| 233 | + print(f"num_threads: {args.num_threads}") | ||
| 234 | + print(f"Wave duration: {total_duration:.3f} s") | ||
| 235 | + print(f"Elapsed time: {elapsed_seconds:.3f} s") | ||
| 236 | + print( | ||
| 237 | + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" | ||
| 238 | + ) | ||
| 239 | + | ||
| 240 | + | ||
| 241 | +if __name__ == "__main__": | ||
| 242 | + main() |
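`read_wave()` above maps 16-bit PCM to float32 in `[-1, 1)` by dividing by 32768; a quick standalone check of that scaling, independent of any model:

```python
import numpy as np

# int16 extremes and a midpoint, scaled the same way read_wave() does it.
samples_int16 = np.array([-32768, 0, 16384, 32767], dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32) / 32768
print(samples_float32)  # [-1.  0.  0.5  0.99996948]
```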
| @@ -230,12 +230,14 @@ endif() | ||
| 230 | | ||
| 231 | if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY) | ||
| 232 | add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc) | ||
| 233 | + add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc) | ||
| 234 | add_executable(sherpa-onnx-offline-tts-play-alsa sherpa-onnx-offline-tts-play-alsa.cc alsa-play.cc) | ||
| 235 | add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc) | ||
| 236 | add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc) | ||
| 237 | | ||
| 238 | set(exes | ||
| 239 | sherpa-onnx-alsa | ||
| 240 | + sherpa-onnx-keyword-spotter-alsa | ||
| 241 | sherpa-onnx-alsa-offline | ||
| 242 | sherpa-onnx-offline-tts-play-alsa | ||
| 243 | sherpa-onnx-alsa-offline-speaker-identification | ||
| @@ -278,6 +280,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY) | ||
| 280 | microphone.cc | ||
| 281 | ) | ||
| 282 | | ||
| 283 | + add_executable(sherpa-onnx-keyword-spotter-microphone | ||
| 284 | + sherpa-onnx-keyword-spotter-microphone.cc | ||
| 285 | + microphone.cc | ||
| 286 | + ) | ||
| 287 | + | ||
| 288 | add_executable(sherpa-onnx-microphone | ||
| 289 | sherpa-onnx-microphone.cc | ||
| 290 | microphone.cc | ||
| @@ -311,6 +318,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY) | ||
| 318 | | ||
| 319 | set(exes | ||
| 320 | sherpa-onnx-microphone | ||
| 321 | + sherpa-onnx-keyword-spotter-microphone | ||
| 322 | sherpa-onnx-microphone-offline | ||
| 323 | sherpa-onnx-microphone-offline-speaker-identification | ||
| 324 | sherpa-onnx-offline-tts-play | ||
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#include <signal.h> | ||
| 5 | +#include <stdio.h> | ||
| 6 | +#include <stdlib.h> | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <cstdint> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/alsa.h" | ||
| 12 | +#include "sherpa-onnx/csrc/display.h" | ||
| 13 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 14 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 15 | + | ||
| 16 | +bool stop = false; | ||
| 17 | + | ||
| 18 | +static void Handler(int sig) { | ||
| 19 | + stop = true; | ||
| 20 | + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +int main(int32_t argc, char *argv[]) { | ||
| 24 | + signal(SIGINT, Handler); | ||
| 25 | + | ||
| 26 | + const char *kUsageMessage = R"usage( | ||
| 27 | +Usage: | ||
| 28 | + ./bin/sherpa-onnx-keyword-spotter-alsa \ | ||
| 29 | + --tokens=/path/to/tokens.txt \ | ||
| 30 | + --encoder=/path/to/encoder.onnx \ | ||
| 31 | + --decoder=/path/to/decoder.onnx \ | ||
| 32 | + --joiner=/path/to/joiner.onnx \ | ||
| 33 | + --provider=cpu \ | ||
| 34 | + --num-threads=2 \ | ||
| 35 | + --keywords-file=keywords.txt \ | ||
| 36 | + device_name | ||
| 37 | + | ||
| 38 | +Please refer to | ||
| 39 | +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html | ||
| 40 | +for a list of pre-trained models to download. | ||
| 41 | + | ||
| 42 | +The device name specifies which microphone to use in case there are several | ||
| 43 | +on your system. You can use | ||
| 44 | + | ||
| 45 | + arecord -l | ||
| 46 | + | ||
| 47 | +to find all available microphones on your computer. For instance, if it outputs | ||
| 48 | + | ||
| 49 | +**** List of CAPTURE Hardware Devices **** | ||
| 50 | +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] | ||
| 51 | + Subdevices: 1/1 | ||
| 52 | + Subdevice #0: subdevice #0 | ||
| 53 | + | ||
| 54 | +and you want to select card 3 and device 0 on that card, please use: | ||
| 55 | + | ||
| 56 | + hw:3,0 | ||
| 57 | + | ||
| 58 | +or | ||
| 59 | + | ||
| 60 | + plughw:3,0 | ||
| 61 | + | ||
| 62 | +as the device_name. | ||
| 63 | +)usage"; | ||
| 64 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 65 | + sherpa_onnx::KeywordSpotterConfig config; | ||
| 66 | + | ||
| 67 | + config.Register(&po); | ||
| 68 | + | ||
| 69 | + po.Read(argc, argv); | ||
| 70 | + if (po.NumArgs() != 1) { | ||
| 71 | + fprintf(stderr, "Please provide only 1 argument: the device name\n"); | ||
| 72 | + po.PrintUsage(); | ||
| 73 | + exit(EXIT_FAILURE); | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + fprintf(stderr, "%s\n", config.ToString().c_str()); | ||
| 77 | + | ||
| 78 | + if (!config.Validate()) { | ||
| 79 | + fprintf(stderr, "Errors in config!\n"); | ||
| 80 | + return -1; | ||
| 81 | + } | ||
| 82 | + sherpa_onnx::KeywordSpotter spotter(config); | ||
| 83 | + | ||
| 84 | + int32_t expected_sample_rate = config.feat_config.sampling_rate; | ||
| 85 | + | ||
| 86 | + std::string device_name = po.GetArg(1); | ||
| 87 | + sherpa_onnx::Alsa alsa(device_name.c_str()); | ||
| 88 | + fprintf(stderr, "Use recording device: %s\n", device_name.c_str()); | ||
| 89 | + | ||
| 90 | + if (alsa.GetExpectedSampleRate() != expected_sample_rate) { | ||
| 91 | + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), | ||
| 92 | + expected_sample_rate); | ||
| 93 | + exit(-1); | ||
| 94 | + } | ||
| 95 | + | ||
| 96 | + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); | ||
| 97 | + | ||
| 98 | + std::string last_text; | ||
| 99 | + | ||
| 100 | + auto stream = spotter.CreateStream(); | ||
| 101 | + | ||
| 102 | + sherpa_onnx::Display display; | ||
| 103 | + | ||
| 104 | + int32_t keyword_index = 0; | ||
| 105 | + while (!stop) { | ||
| 106 | + const std::vector<float> &samples = alsa.Read(chunk); | ||
| 107 | + | ||
| 108 | + stream->AcceptWaveform(expected_sample_rate, samples.data(), | ||
| 109 | + samples.size()); | ||
| 110 | + | ||
| 111 | + while (spotter.IsReady(stream.get())) { | ||
| 112 | + spotter.DecodeStream(stream.get()); | ||
| 113 | + } | ||
| 114 | + | ||
| 115 | + const auto r = spotter.GetResult(stream.get()); | ||
| 116 | + if (!r.keyword.empty()) { | ||
| 117 | + display.Print(keyword_index, r.AsJsonString()); | ||
| 118 | + fflush(stderr); | ||
| 119 | + keyword_index++; | ||
| 120 | + } | ||
| 121 | + } | ||
| 122 | + | ||
| 123 | + return 0; | ||
| 124 | +} |
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include <signal.h> | ||
| 6 | +#include <stdio.h> | ||
| 7 | +#include <stdlib.h> | ||
| 8 | + | ||
| 9 | +#include <algorithm> | ||
| 10 | + | ||
| 11 | +#include "portaudio.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/display.h" | ||
| 13 | +#include "sherpa-onnx/csrc/microphone.h" | ||
| 14 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 15 | + | ||
| 16 | +bool stop = false; | ||
| 17 | + | ||
| 18 | +static int32_t RecordCallback(const void *input_buffer, | ||
| 19 | + void * /*output_buffer*/, | ||
| 20 | + unsigned long frames_per_buffer, // NOLINT | ||
| 21 | + const PaStreamCallbackTimeInfo * /*time_info*/, | ||
| 22 | + PaStreamCallbackFlags /*status_flags*/, | ||
| 23 | + void *user_data) { | ||
| 24 | + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(user_data); | ||
| 25 | + | ||
| 26 | + stream->AcceptWaveform(16000, reinterpret_cast<const float *>(input_buffer), | ||
| 27 | + frames_per_buffer); | ||
| 28 | + | ||
| 29 | + return stop ? paComplete : paContinue; | ||
| 30 | +} | ||
| 31 | + | ||
| 32 | +static void Handler(int32_t sig) { | ||
| 33 | + stop = true; | ||
| 34 | + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); | ||
| 35 | +} | ||
| 36 | + | ||
| 37 | +int32_t main(int32_t argc, char *argv[]) { | ||
| 38 | + signal(SIGINT, Handler); | ||
| 39 | + | ||
| 40 | + const char *kUsageMessage = R"usage( | ||
| 41 | +This program uses a streaming model with a microphone for keyword spotting. | ||
| 42 | +Usage: | ||
| 43 | + | ||
| 44 | + ./bin/sherpa-onnx-keyword-spotter-microphone \ | ||
| 45 | + --tokens=/path/to/tokens.txt \ | ||
| 46 | + --encoder=/path/to/encoder.onnx \ | ||
| 47 | + --decoder=/path/to/decoder.onnx \ | ||
| 48 | + --joiner=/path/to/joiner.onnx \ | ||
| 49 | + --provider=cpu \ | ||
| 50 | + --num-threads=1 \ | ||
| 51 | + --keywords-file=keywords.txt | ||
| 52 | + | ||
| 53 | +Please refer to | ||
| 54 | +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html | ||
| 55 | +for a list of pre-trained models to download. | ||
| 56 | +)usage"; | ||
| 57 | + | ||
| 58 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 59 | + sherpa_onnx::KeywordSpotterConfig config; | ||
| 60 | + | ||
| 61 | + config.Register(&po); | ||
| 62 | + po.Read(argc, argv); | ||
| 63 | + if (po.NumArgs() != 0) { | ||
| 64 | + po.PrintUsage(); | ||
| 65 | + exit(EXIT_FAILURE); | ||
| 66 | + } | ||
| 67 | + | ||
| 68 | + fprintf(stderr, "%s\n", config.ToString().c_str()); | ||
| 69 | + | ||
| 70 | + if (!config.Validate()) { | ||
| 71 | + fprintf(stderr, "Errors in config!\n"); | ||
| 72 | + return -1; | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + sherpa_onnx::KeywordSpotter spotter(config); | ||
| 76 | + auto s = spotter.CreateStream(); | ||
| 77 | + | ||
| 78 | + sherpa_onnx::Microphone mic; | ||
| 79 | + | ||
| 80 | + PaDeviceIndex num_devices = Pa_GetDeviceCount(); | ||
| 81 | + fprintf(stderr, "Num devices: %d\n", num_devices); | ||
| 82 | + | ||
| 83 | + PaStreamParameters param; | ||
| 84 | + | ||
| 85 | + param.device = Pa_GetDefaultInputDevice(); | ||
| 86 | + if (param.device == paNoDevice) { | ||
| 87 | + fprintf(stderr, "No default input device found\n"); | ||
| 88 | + exit(EXIT_FAILURE); | ||
| 89 | + } | ||
| 90 | + fprintf(stderr, "Use default device: %d\n", param.device); | ||
| 91 | + | ||
| 92 | + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); | ||
| 93 | + fprintf(stderr, " Name: %s\n", info->name); | ||
| 94 | + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels); | ||
| 95 | + | ||
| 96 | + param.channelCount = 1; | ||
| 97 | + param.sampleFormat = paFloat32; | ||
| 98 | + | ||
| 99 | + param.suggestedLatency = info->defaultLowInputLatency; | ||
| 100 | + param.hostApiSpecificStreamInfo = nullptr; | ||
| 101 | + float sample_rate = 16000; | ||
| 102 | + | ||
| 103 | + PaStream *stream; | ||
| 104 | + PaError err = | ||
| 105 | + Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */ | ||
| 106 | + sample_rate, | ||
| 107 | + 0, // frames per buffer | ||
| 108 | + paClipOff, // we won't output out of range samples | ||
| 109 | + // so don't bother clipping them | ||
| 110 | + RecordCallback, s.get()); | ||
| 111 | + if (err != paNoError) { | ||
| 112 | + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); | ||
| 113 | + exit(EXIT_FAILURE); | ||
| 114 | + } | ||
| 115 | + | ||
| 116 | + err = Pa_StartStream(stream); | ||
| 117 | + fprintf(stderr, "Started\n"); | ||
| 118 | + | ||
| 119 | + if (err != paNoError) { | ||
| 120 | + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); | ||
| 121 | + exit(EXIT_FAILURE); | ||
| 122 | + } | ||
| 123 | + | ||
| 124 | + int32_t keyword_index = 0; | ||
| 125 | + sherpa_onnx::Display display; | ||
| 126 | + while (!stop) { | ||
| 127 | + while (spotter.IsReady(s.get())) { | ||
| 128 | + spotter.DecodeStream(s.get()); | ||
| 129 | + } | ||
| 130 | + | ||
| 131 | + const auto r = spotter.GetResult(s.get()); | ||
| 132 | + if (!r.keyword.empty()) { | ||
| 133 | + display.Print(keyword_index, r.AsJsonString()); | ||
| 134 | + fflush(stderr); | ||
| 135 | + keyword_index++; | ||
| 136 | + } | ||
| 137 | + | ||
| 138 | + Pa_Sleep(20); // sleep for 20ms | ||
| 139 | + } | ||
| 140 | + | ||
| 141 | + err = Pa_CloseStream(stream); | ||
| 142 | + if (err != paNoError) { | ||
| 143 | + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); | ||
| 144 | + exit(EXIT_FAILURE); | ||
| 145 | + } | ||
| 146 | + | ||
| 147 | + return 0; | ||
| 148 | +} |
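The C++ microphone example splits the work between a PortAudio callback that feeds the stream and a main loop that polls the spotter. The same split can be reproduced in Python with sounddevice's callback mode rather than the blocking `s.read()` used in `keyword-spotter-from-microphone.py`; a rough sketch, assuming a `spotter` constructed as in the earlier examples:

```python
import queue

import sounddevice as sd

sample_rate = 16000
audio_queue = queue.Queue()  # hands blocks from the audio thread to the main thread

def record_callback(indata, frames, time_info, status):
    # Runs on sounddevice's audio thread; keep it cheap and copy the block.
    audio_queue.put(indata[:, 0].copy())

stream = spotter.create_stream()
with sd.InputStream(channels=1, dtype="float32",
                    samplerate=sample_rate, callback=record_callback):
    while True:
        samples = audio_queue.get()  # one block of float32 samples
        stream.accept_waveform(sample_rate, samples)
        while spotter.is_ready(stream):
            spotter.decode_stream(stream)
        result = spotter.get_result(stream)
        if result:
            print(f"{result} is detected")
```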
| @@ -12,7 +12,6 @@ | ||
| 12 | #include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 13 | #include "sherpa-onnx/csrc/online-stream.h" | ||
| 14 | #include "sherpa-onnx/csrc/parse-options.h" | ||
| 15 | -#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 15 | #include "sherpa-onnx/csrc/wave-reader.h" | ||
| 16 | | ||
| 17 | typedef struct { | ||
| @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 5 | display.cc | ||
| 6 | endpoint.cc | ||
| 7 | features.cc | ||
| 8 | + keyword-spotter.cc | ||
| 9 | offline-ctc-fst-decoder-config.cc | ||
| 10 | offline-lm-config.cc | ||
| 11 | offline-model-config.cc | ||
sherpa-onnx/python/csrc/keyword-spotter.cc
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/keyword-spotter.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/keyword-spotter.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +static void PybindKeywordResult(py::module *m) { | ||
| 15 | + using PyClass = KeywordResult; | ||
| 16 | + py::class_<PyClass>(*m, "KeywordResult") | ||
| 17 | + .def_property_readonly( | ||
| 18 | + "keyword", | ||
| 19 | + [](PyClass &self) -> py::str { | ||
| 20 | + return py::str(PyUnicode_DecodeUTF8(self.keyword.c_str(), | ||
| 21 | + self.keyword.size(), "ignore")); | ||
| 22 | + }) | ||
| 23 | + .def_property_readonly( | ||
| 24 | + "tokens", | ||
| 25 | + [](PyClass &self) -> std::vector<std::string> { return self.tokens; }) | ||
| 26 | + .def_property_readonly( | ||
| 27 | + "timestamps", | ||
| 28 | + [](PyClass &self) -> std::vector<float> { return self.timestamps; }); | ||
| 29 | +} | ||
| 30 | + | ||
| 31 | +static void PybindKeywordSpotterConfig(py::module *m) { | ||
| 32 | + using PyClass = KeywordSpotterConfig; | ||
| 33 | + py::class_<PyClass>(*m, "KeywordSpotterConfig") | ||
| 34 | + .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, | ||
| 35 | + int32_t, int32_t, float, float, const std::string &>(), | ||
| 36 | + py::arg("feat_config"), py::arg("model_config"), | ||
| 37 | + py::arg("max_active_paths") = 4, py::arg("num_trailing_blanks") = 1, | ||
| 38 | + py::arg("keywords_score") = 1.0, | ||
| 39 | + py::arg("keywords_threshold") = 0.25, py::arg("keywords_file") = "") | ||
| 40 | + .def_readwrite("feat_config", &PyClass::feat_config) | ||
| 41 | + .def_readwrite("model_config", &PyClass::model_config) | ||
| 42 | + .def_readwrite("max_active_paths", &PyClass::max_active_paths) | ||
| 43 | + .def_readwrite("num_trailing_blanks", &PyClass::num_trailing_blanks) | ||
| 44 | + .def_readwrite("keywords_score", &PyClass::keywords_score) | ||
| 45 | + .def_readwrite("keywords_threshold", &PyClass::keywords_threshold) | ||
| 46 | + .def_readwrite("keywords_file", &PyClass::keywords_file) | ||
| 47 | + .def("__str__", &PyClass::ToString); | ||
| 48 | +} | ||
| 49 | + | ||
| 50 | +void PybindKeywordSpotter(py::module *m) { | ||
| 51 | + PybindKeywordResult(m); | ||
| 52 | + PybindKeywordSpotterConfig(m); | ||
| 53 | + | ||
| 54 | + using PyClass = KeywordSpotter; | ||
| 55 | + py::class_<PyClass>(*m, "KeywordSpotter") | ||
| 56 | + .def(py::init<const KeywordSpotterConfig &>(), py::arg("config"), | ||
| 57 | + py::call_guard<py::gil_scoped_release>()) | ||
| 58 | + .def( | ||
| 59 | + "create_stream", | ||
| 60 | + [](const PyClass &self) { return self.CreateStream(); }, | ||
| 61 | + py::call_guard<py::gil_scoped_release>()) | ||
| 62 | + .def( | ||
| 63 | + "create_stream", | ||
| 64 | + [](PyClass &self, const std::string &keywords) { | ||
| 65 | + return self.CreateStream(keywords); | ||
| 66 | + }, | ||
| 67 | + py::arg("keywords"), py::call_guard<py::gil_scoped_release>()) | ||
| 68 | + .def("is_ready", &PyClass::IsReady, | ||
| 69 | + py::call_guard<py::gil_scoped_release>()) | ||
| 70 | + .def("decode_stream", &PyClass::DecodeStream, | ||
| 71 | + py::call_guard<py::gil_scoped_release>()) | ||
| 72 | + .def( | ||
| 73 | + "decode_streams", | ||
| 74 | + [](PyClass &self, std::vector<OnlineStream *> ss) { | ||
| 75 | + self.DecodeStreams(ss.data(), ss.size()); | ||
| 76 | + }, | ||
| 77 | + py::call_guard<py::gil_scoped_release>()) | ||
| 78 | + .def("get_result", &PyClass::GetResult, | ||
| 79 | + py::call_guard<py::gil_scoped_release>()); | ||
| 80 | +} | ||
| 81 | + | ||
| 82 | +} // namespace sherpa_onnx |
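Every binding above is wrapped in `py::call_guard<py::gil_scoped_release>()`, so construction and decoding release the GIL and can run in a Python background thread without stalling the interpreter. A minimal sketch of that pattern, assuming the `spotter` and `stream` from the earlier examples:

```python
import threading

def decode_loop():
    # decode_stream releases the GIL, so this does not block other threads.
    while spotter.is_ready(stream):
        spotter.decode_stream(stream)

t = threading.Thread(target=decode_loop)
t.start()
t.join()
```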
sherpa-onnx/python/csrc/keyword-spotter.h
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/keyword-spotter.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindKeywordSpotter(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_ |
| @@ -8,6 +8,7 @@ | ||
| 8 | #include "sherpa-onnx/python/csrc/display.h" | ||
| 9 | #include "sherpa-onnx/python/csrc/endpoint.h" | ||
| 10 | #include "sherpa-onnx/python/csrc/features.h" | ||
| 11 | +#include "sherpa-onnx/python/csrc/keyword-spotter.h" | ||
| 12 | #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" | ||
| 13 | #include "sherpa-onnx/python/csrc/offline-lm-config.h" | ||
| 14 | #include "sherpa-onnx/python/csrc/offline-model-config.h" | ||
| @@ -35,6 +36,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 36 | PybindOnlineStream(&m); | ||
| 37 | PybindEndpoint(&m); | ||
| 38 | PybindOnlineRecognizer(&m); | ||
| 39 | + PybindKeywordSpotter(&m); | ||
| 40 | | ||
| 41 | PybindDisplay(&m); | ||
| 42 | | ||
| @@ -17,6 +17,7 @@ from _sherpa_onnx import ( | ||
| 17 | VoiceActivityDetector, | ||
| 18 | ) | ||
| 19 | | ||
| 20 | +from .keyword_spotter import KeywordSpotter | ||
| 21 | from .offline_recognizer import OfflineRecognizer | ||
| 22 | from .online_recognizer import OnlineRecognizer | ||
| 23 | from .utils import text2token | ||
| 1 | +# Copyright (c) 2023 Xiaomi Corporation | ||
| 2 | + | ||
| 3 | +from pathlib import Path | ||
| 4 | +from typing import List, Optional | ||
| 5 | + | ||
| 6 | +from _sherpa_onnx import ( | ||
| 7 | + FeatureExtractorConfig, | ||
| 8 | + KeywordSpotterConfig, | ||
| 9 | + OnlineModelConfig, | ||
| 10 | + OnlineTransducerModelConfig, | ||
| 11 | + OnlineStream, | ||
| 12 | +) | ||
| 13 | + | ||
| 14 | +from _sherpa_onnx import KeywordSpotter as _KeywordSpotter | ||
| 15 | + | ||
| 16 | + | ||
| 17 | +def _assert_file_exists(f: str): | ||
| 18 | + assert Path(f).is_file(), f"{f} does not exist" | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +class KeywordSpotter(object): | ||
| 22 | + """A class for keyword spotting. | ||
| 23 | + | ||
| 24 | + Please refer to the following files for usage examples | ||
| 25 | + - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter.py | ||
| 26 | + - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter-from-microphone.py | ||
| 27 | + """ | ||
| 28 | + | ||
| 29 | + def __init__( | ||
| 30 | + self, | ||
| 31 | + tokens: str, | ||
| 32 | + encoder: str, | ||
| 33 | + decoder: str, | ||
| 34 | + joiner: str, | ||
| 35 | + keywords_file: str, | ||
| 36 | + num_threads: int = 2, | ||
| 37 | + sample_rate: float = 16000, | ||
| 38 | + feature_dim: int = 80, | ||
| 39 | + max_active_paths: int = 4, | ||
| 40 | + keywords_score: float = 1.0, | ||
| 41 | + keywords_threshold: float = 0.25, | ||
| 42 | + num_trailing_blanks: int = 1, | ||
| 43 | + provider: str = "cpu", | ||
| 44 | + ): | ||
| 45 | + """ | ||
| 46 | + Please refer to | ||
| 47 | + `<https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html>`_ | ||
| 48 | + to download pre-trained models for different languages, e.g., Chinese, | ||
| 49 | + English, etc. | ||
| 50 | + | ||
| 51 | + Args: | ||
| 52 | + tokens: | ||
| 53 | + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two | ||
| 54 | + columns:: | ||
| 55 | + | ||
| 56 | + symbol integer_id | ||
| 57 | + | ||
| 58 | + encoder: | ||
| 59 | + Path to ``encoder.onnx``. | ||
| 60 | + decoder: | ||
| 61 | + Path to ``decoder.onnx``. | ||
| 62 | + joiner: | ||
| 63 | + Path to ``joiner.onnx``. | ||
| 64 | + keywords_file: | ||
| 65 | + The file containing keywords, one word/phrase per line, and for each | ||
| 66 | + phrase the bpe/cjkchar/pinyin units are separated by a space. | ||
| 67 | + num_threads: | ||
| 68 | + Number of threads for neural network computation. | ||
| 69 | + sample_rate: | ||
| 70 | + Sample rate of the training data used to train the model. | ||
| 71 | + feature_dim: | ||
| 72 | + Dimension of the feature used to train the model. | ||
| 73 | + max_active_paths: | ||
| 74 | + It specifies the maximum number of active paths to keep | ||
| 75 | + during beam search. | ||
| 76 | + keywords_score: | ||
| 77 | + The boosting score of each token for keywords. The larger it is, the | ||
| 78 | + easier it is for a keyword to survive beam search. | ||
| 79 | + keywords_threshold: | ||
| 80 | + The trigger threshold (i.e., probability) of the keyword. The larger it | ||
| 81 | + is, the harder the keyword is to trigger. | ||
| 82 | + num_trailing_blanks: | ||
| 83 | + The number of trailing blanks a keyword should be followed by. Set | ||
| 84 | + it to a larger value (e.g., 8) when your keywords have overlapping | ||
| 85 | + tokens. | ||
| 86 | + provider: | ||
| 87 | + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | ||
| 88 | + """ | ||
| 89 | + _assert_file_exists(tokens) | ||
| 90 | + _assert_file_exists(encoder) | ||
| 91 | + _assert_file_exists(decoder) | ||
| 92 | + _assert_file_exists(joiner) | ||
| 93 | + | ||
| 94 | + assert num_threads > 0, num_threads | ||
| 95 | + | ||
| 96 | + transducer_config = OnlineTransducerModelConfig( | ||
| 97 | + encoder=encoder, | ||
| 98 | + decoder=decoder, | ||
| 99 | + joiner=joiner, | ||
| 100 | + ) | ||
| 101 | + | ||
| 102 | + model_config = OnlineModelConfig( | ||
| 103 | + transducer=transducer_config, | ||
| 104 | + tokens=tokens, | ||
| 105 | + num_threads=num_threads, | ||
| 106 | + provider=provider, | ||
| 107 | + ) | ||
| 108 | + | ||
| 109 | + feat_config = FeatureExtractorConfig( | ||
| 110 | + sampling_rate=sample_rate, | ||
| 111 | + feature_dim=feature_dim, | ||
| 112 | + ) | ||
| 113 | + | ||
| 114 | + keywords_spotter_config = KeywordSpotterConfig( | ||
| 115 | + feat_config=feat_config, | ||
| 116 | + model_config=model_config, | ||
| 117 | + max_active_paths=max_active_paths, | ||
| 118 | + num_trailing_blanks=num_trailing_blanks, | ||
| 119 | + keywords_score=keywords_score, | ||
| 120 | + keywords_threshold=keywords_threshold, | ||
| 121 | + keywords_file=keywords_file, | ||
| 122 | + ) | ||
| 123 | + self.keyword_spotter = _KeywordSpotter(keywords_spotter_config) | ||
| 124 | + | ||
| 125 | + def create_stream(self, keywords: Optional[str] = None): | ||
| 126 | + if keywords is None: | ||
| 127 | + return self.keyword_spotter.create_stream() | ||
| 128 | + else: | ||
| 129 | + return self.keyword_spotter.create_stream(keywords) | ||
| 130 | + | ||
| 131 | + def decode_stream(self, s: OnlineStream): | ||
| 132 | + self.keyword_spotter.decode_stream(s) | ||
| 133 | + | ||
| 134 | + def decode_streams(self, ss: List[OnlineStream]): | ||
| 135 | + self.keyword_spotter.decode_streams(ss) | ||
| 136 | + | ||
| 137 | + def is_ready(self, s: OnlineStream) -> bool: | ||
| 138 | + return self.keyword_spotter.is_ready(s) | ||
| 139 | + | ||
| 140 | + def get_result(self, s: OnlineStream) -> str: | ||
| 141 | + return self.keyword_spotter.get_result(s).keyword.strip() | ||
| 142 | + | ||
| 143 | + def tokens(self, s: OnlineStream) -> List[str]: | ||
| 144 | + return self.keyword_spotter.get_result(s).tokens | ||
| 145 | + | ||
| 146 | + def timestamps(self, s: OnlineStream) -> List[float]: | ||
| 147 | + return self.keyword_spotter.get_result(s).timestamps |
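Besides the keyword text returned by `get_result()`, the wrapper exposes the matched tokens and their timestamps through separate helpers, each of which re-reads the result from the stream, mirroring how the class itself is written. A short sketch with the `spotter` and `stream` from the earlier examples:

```python
# After a detection, inspect the tokens and their time alignment.
result = spotter.get_result(stream)
if result:
    for token, ts in zip(spotter.tokens(stream), spotter.timestamps(stream)):
        print(f"{ts:.2f}s  {token}")
```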
| @@ -20,6 +20,7 @@ endfunction() | ||
| 20 | # please sort the files in alphabetic order | ||
| 21 | set(py_test_files | ||
| 22 | test_feature_extractor_config.py | ||
| 23 | + test_keyword_spotter.py | ||
| 24 | test_offline_recognizer.py | ||
| 25 | test_online_recognizer.py | ||
| 26 | test_online_transducer_model_config.py | ||
| 1 | +# sherpa-onnx/python/tests/test_keyword_spotter.py | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +# | ||
| 5 | +# To run this single test, use | ||
| 6 | +# | ||
| 7 | +# ctest --verbose -R test_keyword_spotter_py | ||
| 8 | + | ||
| 9 | +import unittest | ||
| 10 | +import wave | ||
| 11 | +from pathlib import Path | ||
| 12 | +from typing import Tuple | ||
| 13 | + | ||
| 14 | +import numpy as np | ||
| 15 | +import sherpa_onnx | ||
| 16 | + | ||
| 17 | +d = "/tmp/onnx-models" | ||
| 18 | +# Please refer to | ||
| 19 | +# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html | ||
| 20 | +# to download pre-trained models for testing | ||
| 21 | + | ||
| 22 | + | ||
| 23 | +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 24 | + """ | ||
| 25 | + Args: | ||
| 26 | + wave_filename: | ||
| 27 | + Path to a wave file. It should be single channel and each sample should | ||
| 28 | + be 16-bit. Its sample rate does not need to be 16kHz. | ||
| 29 | + Returns: | ||
| 30 | + Return a tuple containing: | ||
| 31 | + - A 1-D array of dtype np.float32 containing the samples, which are | ||
| 32 | + normalized to the range [-1, 1]. | ||
| 33 | + - sample rate of the wave file | ||
| 34 | + """ | ||
| 35 | + | ||
| 36 | + with wave.open(wave_filename) as f: | ||
| 37 | + assert f.getnchannels() == 1, f.getnchannels() | ||
| 38 | + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | ||
| 39 | + num_samples = f.getnframes() | ||
| 40 | + samples = f.readframes(num_samples) | ||
| 41 | + samples_int16 = np.frombuffer(samples, dtype=np.int16) | ||
| 42 | + samples_float32 = samples_int16.astype(np.float32) | ||
| 43 | + | ||
| 44 | + samples_float32 = samples_float32 / 32768 | ||
| 45 | + return samples_float32, f.getframerate() | ||
| 46 | + | ||
| 47 | + | ||
| 48 | +class TestKeywordSpotter(unittest.TestCase): | ||
| 49 | + def test_zipformer_transducer_en(self): | ||
| 50 | + for use_int8 in [True, False]: | ||
| 51 | + if use_int8: | ||
| 52 | + encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" | ||
| 53 | + decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" | ||
| 54 | + joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx" | ||
| 55 | + else: | ||
| 56 | + encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx" | ||
| 57 | + decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx" | ||
| 58 | + joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx" | ||
| 59 | + | ||
| 60 | + tokens = ( | ||
| 61 | + f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt" | ||
| 62 | + ) | ||
| 63 | + keywords_file = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" | ||
| 64 | + wave0 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav" | ||
| 65 | + wave1 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav" | ||
| 66 | + | ||
| 67 | + if not Path(encoder).is_file(): | ||
| 68 | + print("skipping test_zipformer_transducer_en()") | ||
| 69 | + return | ||
| 70 | + keyword_spotter = sherpa_onnx.KeywordSpotter( | ||
| 71 | + encoder=encoder, | ||
| 72 | + decoder=decoder, | ||
| 73 | + joiner=joiner, | ||
| 74 | + tokens=tokens, | ||
| 75 | + num_threads=1, | ||
| 76 | + keywords_file=keywords_file, | ||
| 77 | + provider="cpu", | ||
| 78 | + ) | ||
| 79 | + streams = [] | ||
| 80 | + waves = [wave0, wave1] | ||
| 81 | + for wave in waves: | ||
| 82 | + s = keyword_spotter.create_stream() | ||
| 83 | + samples, sample_rate = read_wave(wave) | ||
| 84 | + s.accept_waveform(sample_rate, samples) | ||
| 85 | + | ||
| 86 | + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) | ||
| 87 | + s.accept_waveform(sample_rate, tail_paddings) | ||
| 88 | + s.input_finished() | ||
| 89 | + streams.append(s) | ||
| 90 | + | ||
| 91 | + results = [""] * len(streams) | ||
| 92 | + while True: | ||
| 93 | + ready_list = [] | ||
| 94 | + for i, s in enumerate(streams): | ||
| 95 | + if keyword_spotter.is_ready(s): | ||
| 96 | + ready_list.append(s) | ||
| 97 | + r = keyword_spotter.get_result(s) | ||
| 98 | + if r: | ||
| 99 | + print(f"{r} is detected.") | ||
| 100 | + results[i] += f"{r}/" | ||
| 101 | + if len(ready_list) == 0: | ||
| 102 | + break | ||
| 103 | + keyword_spotter.decode_streams(ready_list) | ||
| 104 | + for wave_filename, result in zip(waves, results): | ||
| 105 | + print(f"{wave_filename}\n{result[0:-1]}") | ||
| 106 | + print("-" * 10) | ||
| 107 | + | ||
| 108 | + def test_zipformer_transducer_cn(self): | ||
| 109 | + for use_int8 in [True, False]: | ||
| 110 | + if use_int8: | ||
| 111 | + encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" | ||
| 112 | + decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" | ||
| 113 | + joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx" | ||
| 114 | + else: | ||
| 115 | + encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx" | ||
| 116 | + decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx" | ||
| 117 | + joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx" | ||
| 118 | + | ||
| 119 | + tokens = ( | ||
| 120 | + f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt" | ||
| 121 | + ) | ||
| 122 | + keywords_file = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" | ||
| 123 | + wave0 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav" | ||
| 124 | + wave1 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav" | ||
| 125 | + wave2 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/5.wav" | ||
| 126 | + | ||
| 127 | + if not Path(encoder).is_file(): | ||
| 128 | + print("skipping test_zipformer_transducer_cn()") | ||
| 129 | + return | ||
| 130 | + keyword_spotter = sherpa_onnx.KeywordSpotter( | ||
| 131 | + encoder=encoder, | ||
| 132 | + decoder=decoder, | ||
| 133 | + joiner=joiner, | ||
| 134 | + tokens=tokens, | ||
| 135 | + num_threads=1, | ||
| 136 | + keywords_file=keywords_file, | ||
| 137 | + provider="cpu", | ||
| 138 | + ) | ||
| 139 | + streams = [] | ||
| 140 | + waves = [wave0, wave1, wave2] | ||
| 141 | + for wave in waves: | ||
| 142 | + s = keyword_spotter.create_stream() | ||
| 143 | + samples, sample_rate = read_wave(wave) | ||
| 144 | + s.accept_waveform(sample_rate, samples) | ||
| 145 | + | ||
| 146 | + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) | ||
| 147 | + s.accept_waveform(sample_rate, tail_paddings) | ||
| 148 | + s.input_finished() | ||
| 149 | + streams.append(s) | ||
| 150 | + | ||
| 151 | + results = [""] * len(streams) | ||
| 152 | + while True: | ||
| 153 | + ready_list = [] | ||
| 154 | + for i, s in enumerate(streams): | ||
| 155 | + if keyword_spotter.is_ready(s): | ||
| 156 | + ready_list.append(s) | ||
| 157 | + r = keyword_spotter.get_result(s) | ||
| 158 | + if r: | ||
| 159 | + print(f"{r} is detected.") | ||
| 160 | + results[i] += f"{r}/" | ||
| 161 | + if len(ready_list) == 0: | ||
| 162 | + break | ||
| 163 | + keyword_spotter.decode_streams(ready_list) | ||
| 164 | + for wave_filename, result in zip(waves, results): | ||
| 165 | + print(f"{wave_filename}\n{result[0:-1]}") | ||
| 166 | + print("-" * 10) | ||
| 167 | + | ||
| 168 | + | ||
| 169 | +if __name__ == "__main__": | ||
| 170 | + unittest.main() |
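One detail worth calling out in both the tests above and `keyword-spotter.py`: a block of zero samples (0.2 s here, 0.66 s in the example script) is appended before `input_finished()`. The padding presumably gives the streaming model enough right context to flush frames buffered inside the feature extractor and encoder, so a keyword spoken at the very end of a file can still be detected; without it, the last chunks may never become ready for decoding.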