Wei Kang
Committed by GitHub

Add Python API for keyword spotting (#576)

* Add alsa & microphone support for keyword spotting

* Add python wrapper
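
For a quick picture of what the new Python API looks like, here is a minimal
sketch distilled from the wrapper added in this PR (the model paths are
placeholders; see the pre-trained models linked in the examples below, and
`samples` is assumed to be audio you have already loaded):

    import sherpa_onnx

    # Construct the spotter from a streaming transducer model (placeholder paths).
    spotter = sherpa_onnx.KeywordSpotter(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        keywords_file="keywords.txt",
        num_threads=1,
    )

    stream = spotter.create_stream()
    # samples is a 1-D float32 array in [-1, 1]; see read_wave() below.
    stream.accept_waveform(16000, samples)
    stream.input_finished()
    while spotter.is_ready(stream):
        spotter.decode_stream(stream)
    print(spotter.get_result(stream))  # the detected keyword, or "" if none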
@@ -293,3 +293,61 @@ git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data
293 python3 sherpa-onnx/python/tests/test_text2token.py --verbose
294
295 rm -rf /tmp/sherpa-test-data
  296 +
  297 +mkdir -p /tmp/onnx-models
  298 +dir=/tmp/onnx-models
  299 +
  300 +log "Test keyword spotting models"
  301 +
  302 +python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
  303 +sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)")
  304 +
  305 +echo "sherpa_onnx version: $sherpa_onnx_version"
  306 +
  307 +pwd
  308 +ls -lh
  309 +
  310 +repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
  311 +log "Start testing ${repo}"
  312 +
  313 +pushd $dir
  314 +wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz2
  315 +tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz2
  316 +popd
  317 +
  318 +repo=$dir/$repo
  319 +ls -lh $repo
  320 +
  321 +python3 ./python-api-examples/keyword-spotter.py \
  322 + --tokens=$repo/tokens.txt \
  323 + --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
  324 + --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
  325 + --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
  326 + --keywords-file=$repo/test_wavs/test_keywords.txt \
  327 + $repo/test_wavs/0.wav \
  328 + $repo/test_wavs/1.wav
  329 +
  330 +repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
  331 +log "Start testing ${repo}"
  332 +
  333 +pushd $dir
  334 +wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
  335 +tar xf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
  336 +popd
  337 +
  338 +repo=$dir/$repo
  339 +ls -lh $repo
  340 +
  341 +python3 ./python-api-examples/keyword-spotter.py \
  342 + --tokens=$repo/tokens.txt \
  343 + --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
  344 + --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
  345 + --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
  346 + --keywords-file=$repo/test_wavs/test_keywords.txt \
  347 + $repo/test_wavs/3.wav \
  348 + $repo/test_wavs/4.wav \
  349 + $repo/test_wavs/5.wav
  350 +
  351 +python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose
  352 +
  353 +rm -r $dir
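
For reference, each keywords file (e.g. test_keywords.txt above) lists one
keyword phrase per line, with the phrase pre-tokenized into the model's
bpe/cjkchar/pinyin tokens separated by spaces. The two lines below are the
illustrative examples from the --keywords-file help text, not the actual
contents of the test files:

    ▁HE LL O ▁WORLD
    x iǎo ài t óng x ué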
  1 +#!/usr/bin/env python3
  2 +
  3 +# Real-time keyword spotting from a microphone with sherpa-onnx Python API
  4 +#
  5 +# Please refer to
  6 +# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
  7 +# to download pre-trained models
  8 +
  9 +import argparse
  10 +import sys
  11 +from pathlib import Path
  12 +
  13 +from typing import List
  14 +
  15 +try:
  16 + import sounddevice as sd
  17 +except ImportError:
  18 + print("Please install sounddevice first. You can use")
  19 + print()
  20 + print(" pip install sounddevice")
  21 + print()
  22 + print("to install it")
  23 + sys.exit(-1)
  24 +
  25 +import sherpa_onnx
  26 +
  27 +
  28 +def assert_file_exists(filename: str):
  29 + assert Path(filename).is_file(), (
  30 + f"{filename} does not exist!\n"
  31 + "Please refer to "
  32 + "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
  33 + )
  34 +
  35 +
  36 +def get_args():
  37 + parser = argparse.ArgumentParser(
  38 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  39 + )
  40 +
  41 + parser.add_argument(
  42 + "--tokens",
  43 + type=str,
  44 + help="Path to tokens.txt",
  45 + )
  46 +
  47 + parser.add_argument(
  48 + "--encoder",
  49 + type=str,
  50 + help="Path to the transducer encoder model",
  51 + )
  52 +
  53 + parser.add_argument(
  54 + "--decoder",
  55 + type=str,
  56 + help="Path to the transducer decoder model",
  57 + )
  58 +
  59 + parser.add_argument(
  60 + "--joiner",
  61 + type=str,
  62 + help="Path to the transducer joiner model",
  63 + )
  64 +
  65 + parser.add_argument(
  66 + "--num-threads",
  67 + type=int,
  68 + default=1,
  69 + help="Number of threads for neural network computation",
  70 + )
  71 +
  72 + parser.add_argument(
  73 + "--provider",
  74 + type=str,
  75 + default="cpu",
  76 + help="Valid values: cpu, cuda, coreml",
  77 + )
  78 +
  79 + parser.add_argument(
  80 + "--max-active-paths",
  81 + type=int,
  82 + default=4,
  83 + help="""
  84 + It specifies the number of active paths to keep during decoding.
  85 + """,
  86 + )
  87 +
  88 + parser.add_argument(
  89 + "--num-trailing-blanks",
  90 + type=int,
  91 + default=1,
  92 + help="""The number of trailing blanks a keyword should be followed by. Set it
  93 + to a larger value (e.g. 8) when your keywords have overlapping tokens
  94 + with each other.
  95 + """,
  96 + )
  97 +
  98 + parser.add_argument(
  99 + "--keywords-file",
  100 + type=str,
  101 + help="""
  102 + The file containing keywords, one word/phrase per line, and for each
  103 + phrase the bpe/cjkchar/pinyin are separated by a space. For example:
  104 +
  105 + ▁HE LL O ▁WORLD
  106 + x iǎo ài t óng x ué
  107 + """,
  108 + )
  109 +
  110 + parser.add_argument(
  111 + "--keywords-score",
  112 + type=float,
  113 + default=1.0,
  114 + help="""
  115 + The boosting score of each token for keywords. The larger the score,
  116 + the easier it is for a keyword to survive beam search.
  117 + """,
  118 + )
  119 +
  120 + parser.add_argument(
  121 + "--keywords-threshold",
  122 + type=float,
  123 + default=0.25,
  124 + help="""
  125 + The trigger threshold (i.e., probability) of the keyword. The larger
  126 + the threshold, the harder it is to trigger.
  127 + """,
  128 + )
  129 +
  130 + return parser.parse_args()
  131 +
  132 +
  133 +def main():
  134 + args = get_args()
  135 +
  136 + devices = sd.query_devices()
  137 + if len(devices) == 0:
  138 + print("No microphone devices found")
  139 + sys.exit(0)
  140 +
  141 + print(devices)
  142 + default_input_device_idx = sd.default.device[0]
  143 + print(f'Use default device: {devices[default_input_device_idx]["name"]}')
  144 +
  145 + assert_file_exists(args.tokens)
  146 + assert_file_exists(args.encoder)
  147 + assert_file_exists(args.decoder)
  148 + assert_file_exists(args.joiner)
  149 +
  150 + assert Path(
  151 + args.keywords_file
  152 + ).is_file(), (
  153 + f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
  154 + )
  155 +
  156 + keyword_spotter = sherpa_onnx.KeywordSpotter(
  157 + tokens=args.tokens,
  158 + encoder=args.encoder,
  159 + decoder=args.decoder,
  160 + joiner=args.joiner,
  161 + num_threads=args.num_threads,
  162 + max_active_paths=args.max_active_paths,
  163 + keywords_file=args.keywords_file,
  164 + keywords_score=args.keywords_score,
  165 + keywords_threshold=args.keywords_threshold,
  166 + num_trailing_blanks=args.num_trailing_blanks,
  167 + provider=args.provider,
  168 + )
  169 +
  170 + print("Started! Please speak")
  171 +
  172 + sample_rate = 16000
  173 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
  174 + stream = keyword_spotter.create_stream()
  175 + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
  176 + while True:
  177 + samples, _ = s.read(samples_per_read) # a blocking read
  178 + samples = samples.reshape(-1)
  179 + stream.accept_waveform(sample_rate, samples)
  180 + while keyword_spotter.is_ready(stream):
  181 + keyword_spotter.decode_stream(stream)
  182 + result = keyword_spotter.get_result(stream)
  183 + if result:
  184 + print("\r{}".format(result), end="", flush=True)
  185 +
  186 +
  187 +if __name__ == "__main__":
  188 + try:
  189 + main()
  190 + except KeyboardInterrupt:
  191 + print("\nCaught Ctrl + C. Exiting")
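
A usage note for the microphone example above: besides printing
sd.query_devices() from the script, you can run python3 -m sounddevice from
the command line to list the available audio devices and check which input
device sd.default.device[0] will pick.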
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This file demonstrates how to use sherpa-onnx Python API to do keyword spotting
  5 +from wave file(s).
  6 +
  7 +Please refer to
  8 +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
  9 +to download pre-trained models.
  10 +"""
  11 +import argparse
  12 +import time
  13 +import wave
  14 +from pathlib import Path
  15 +from typing import List, Tuple
  16 +
  17 +import numpy as np
  18 +import sherpa_onnx
  19 +
  20 +
  21 +def get_args():
  22 + parser = argparse.ArgumentParser(
  23 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  24 + )
  25 +
  26 + parser.add_argument(
  27 + "--tokens",
  28 + type=str,
  29 + help="Path to tokens.txt",
  30 + )
  31 +
  32 + parser.add_argument(
  33 + "--encoder",
  34 + type=str,
  35 + help="Path to the transducer encoder model",
  36 + )
  37 +
  38 + parser.add_argument(
  39 + "--decoder",
  40 + type=str,
  41 + help="Path to the transducer decoder model",
  42 + )
  43 +
  44 + parser.add_argument(
  45 + "--joiner",
  46 + type=str,
  47 + help="Path to the transducer joiner model",
  48 + )
  49 +
  50 + parser.add_argument(
  51 + "--num-threads",
  52 + type=int,
  53 + default=1,
  54 + help="Number of threads for neural network computation",
  55 + )
  56 +
  57 + parser.add_argument(
  58 + "--provider",
  59 + type=str,
  60 + default="cpu",
  61 + help="Valid values: cpu, cuda, coreml",
  62 + )
  63 +
  64 + parser.add_argument(
  65 + "--max-active-paths",
  66 + type=int,
  67 + default=4,
  68 + help="""
  69 + It specifies the number of active paths to keep during decoding.
  70 + """,
  71 + )
  72 +
  73 + parser.add_argument(
  74 + "--num-trailing-blanks",
  75 + type=int,
  76 + default=1,
  77 + help="""The number of trailing blanks a keyword should be followed by. Set it
  78 + to a larger value (e.g. 8) when your keywords have overlapping tokens
  79 + with each other.
  80 + """,
  81 + )
  82 +
  83 + parser.add_argument(
  84 + "--keywords-file",
  85 + type=str,
  86 + help="""
  87 + The file containing keywords, one word/phrase per line, and for each
  88 + phrase the bpe/cjkchar/pinyin are separated by a space. For example:
  89 +
  90 + ▁HE LL O ▁WORLD
  91 + x iǎo ài t óng x ué
  92 + """,
  93 + )
  94 +
  95 + parser.add_argument(
  96 + "--keywords-score",
  97 + type=float,
  98 + default=1.0,
  99 + help="""
  100 + The boosting score of each token for keywords. The larger the score,
  101 + the easier it is for a keyword to survive beam search.
  102 + """,
  103 + )
  104 +
  105 + parser.add_argument(
  106 + "--keywords-threshold",
  107 + type=float,
  108 + default=0.25,
  109 + help="""
  110 + The trigger threshold (i.e., probability) of the keyword. The larger
  111 + the threshold, the harder it is to trigger.
  112 + """,
  113 + )
  114 +
  115 + parser.add_argument(
  116 + "sound_files",
  117 + type=str,
  118 + nargs="+",
  119 + help="The input sound file(s) to decode. Each file must be of WAVE"
  120 + "format with a single channel, and each sample has 16-bit, "
  121 + "i.e., int16_t. "
  122 + "The sample rate of the file can be arbitrary and does not need to "
  123 + "be 16 kHz",
  124 + )
  125 +
  126 + return parser.parse_args()
  127 +
  128 +
  129 +def assert_file_exists(filename: str):
  130 + assert Path(filename).is_file(), (
  131 + f"{filename} does not exist!\n"
  132 + "Please refer to "
  133 + "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
  134 + )
  135 +
  136 +
  137 +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
  138 + """
  139 + Args:
  140 + wave_filename:
  141 + Path to a wave file. It should be single channel and each sample should
  142 + be 16-bit. Its sample rate does not need to be 16kHz.
  143 + Returns:
  144 + Return a tuple containing:
  145 + - A 1-D array of dtype np.float32 containing the samples, which are
  146 + normalized to the range [-1, 1].
  147 + - sample rate of the wave file
  148 + """
  149 +
  150 + with wave.open(wave_filename) as f:
  151 + assert f.getnchannels() == 1, f.getnchannels()
  152 + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  153 + num_samples = f.getnframes()
  154 + samples = f.readframes(num_samples)
  155 + samples_int16 = np.frombuffer(samples, dtype=np.int16)
  156 + samples_float32 = samples_int16.astype(np.float32)
  157 +
  158 + samples_float32 = samples_float32 / 32768
  159 + return samples_float32, f.getframerate()
  160 +
  161 +
  162 +def main():
  163 + args = get_args()
  164 + assert_file_exists(args.tokens)
  165 + assert_file_exists(args.encoder)
  166 + assert_file_exists(args.decoder)
  167 + assert_file_exists(args.joiner)
  168 +
  169 + assert Path(
  170 + args.keywords_file
  171 + ).is_file(), (
  172 + f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
  173 + )
  174 +
  175 + keyword_spotter = sherpa_onnx.KeywordSpotter(
  176 + tokens=args.tokens,
  177 + encoder=args.encoder,
  178 + decoder=args.decoder,
  179 + joiner=args.joiner,
  180 + num_threads=args.num_threads,
  181 + max_active_paths=args.max_active_paths,
  182 + keywords_file=args.keywords_file,
  183 + keywords_score=args.keywords_score,
  184 + keywords_threshold=args.keywords_threshold,
  185 + num_trailing_blanks=args.num_trailing_blanks,
  186 + provider=args.provider,
  187 + )
  188 +
  189 + print("Started!")
  190 + start_time = time.time()
  191 +
  192 + streams = []
  193 + total_duration = 0
  194 + for wave_filename in args.sound_files:
  195 + assert_file_exists(wave_filename)
  196 + samples, sample_rate = read_wave(wave_filename)
  197 + duration = len(samples) / sample_rate
  198 + total_duration += duration
  199 +
  200 + s = keyword_spotter.create_stream()
  201 +
  202 + s.accept_waveform(sample_rate, samples)
  203 +
  204 + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
  205 + s.accept_waveform(sample_rate, tail_paddings)
  206 +
  207 + s.input_finished()
  208 +
  209 + streams.append(s)
  210 +
  211 + results = [""] * len(streams)
  212 + while True:
  213 + ready_list = []
  214 + for i, s in enumerate(streams):
  215 + if keyword_spotter.is_ready(s):
  216 + ready_list.append(s)
  217 + r = keyword_spotter.get_result(s)
  218 + if r:
  219 + results[i] += f"{r}/"
  220 + print(f"{r} is detected.")
  221 + if len(ready_list) == 0:
  222 + break
  223 + keyword_spotter.decode_streams(ready_list)
  224 + end_time = time.time()
  225 + print("Done!")
  226 +
  227 + for wave_filename, result in zip(args.sound_files, results):
  228 + print(f"{wave_filename}\n{result}")
  229 + print("-" * 10)
  230 +
  231 + elapsed_seconds = end_time - start_time
  232 + rtf = elapsed_seconds / total_duration
  233 + print(f"num_threads: {args.num_threads}")
  234 + print(f"Wave duration: {total_duration:.3f} s")
  235 + print(f"Elapsed time: {elapsed_seconds:.3f} s")
  236 + print(
  237 + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
  238 + )
  239 +
  240 +
  241 +if __name__ == "__main__":
  242 + main()
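
A note on the statistics printed above: the real time factor (RTF) is the
elapsed wall-clock time divided by the total audio duration, so values below
1 mean faster than real time. For example, decoding 10.0 s of audio in 0.5 s
gives RTF = 0.5/10.0 = 0.05.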
@@ -230,12 +230,14 @@ endif()
230
231 if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY)
232 add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc)
  233 + add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc)
234 add_executable(sherpa-onnx-offline-tts-play-alsa sherpa-onnx-offline-tts-play-alsa.cc alsa-play.cc)
235 add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc)
236 add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc)
237
238 set(exes
239 sherpa-onnx-alsa
  240 + sherpa-onnx-keyword-spotter-alsa
241 sherpa-onnx-alsa-offline
242 sherpa-onnx-offline-tts-play-alsa
243 sherpa-onnx-alsa-offline-speaker-identification
@@ -278,6 +280,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
280 microphone.cc
281 )
282
  283 + add_executable(sherpa-onnx-keyword-spotter-microphone
  284 + sherpa-onnx-keyword-spotter-microphone.cc
  285 + microphone.cc
  286 + )
  287 +
288 add_executable(sherpa-onnx-microphone
289 sherpa-onnx-microphone.cc
290 microphone.cc
@@ -311,6 +318,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
318
319 set(exes
320 sherpa-onnx-microphone
  321 + sherpa-onnx-keyword-spotter-microphone
322 sherpa-onnx-microphone-offline
323 sherpa-onnx-microphone-offline-speaker-identification
324 sherpa-onnx-offline-tts-play
  1 +// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#include <signal.h>
  5 +#include <stdio.h>
  6 +#include <stdlib.h>
  7 +
  8 +#include <algorithm>
  9 +#include <cstdint>
  10 +
  11 +#include "sherpa-onnx/csrc/alsa.h"
  12 +#include "sherpa-onnx/csrc/display.h"
  13 +#include "sherpa-onnx/csrc/keyword-spotter.h"
  14 +#include "sherpa-onnx/csrc/parse-options.h"
  15 +
  16 +bool stop = false;
  17 +
  18 +static void Handler(int sig) {
  19 + stop = true;
  20 + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
  21 +}
  22 +
  23 +int main(int32_t argc, char *argv[]) {
  24 + signal(SIGINT, Handler);
  25 +
  26 + const char *kUsageMessage = R"usage(
  27 +Usage:
  28 + ./bin/sherpa-onnx-keyword-spotter-alsa \
  29 + --tokens=/path/to/tokens.txt \
  30 + --encoder=/path/to/encoder.onnx \
  31 + --decoder=/path/to/decoder.onnx \
  32 + --joiner=/path/to/joiner.onnx \
  33 + --provider=cpu \
  34 + --num-threads=2 \
  35 + --keywords-file=keywords.txt \
  36 + device_name
  37 +
  38 +Please refer to
  39 +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
  40 +for a list of pre-trained models to download.
  41 +
  42 +The device name specifies which microphone to use in case there are several
  43 +on your system. You can use
  44 +
  45 + arecord -l
  46 +
  47 +to find all available microphones on your computer. For instance, if it outputs
  48 +
  49 +**** List of CAPTURE Hardware Devices ****
  50 +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
  51 + Subdevices: 1/1
  52 + Subdevice #0: subdevice #0
  53 +
  54 +and if you want to select card 3 and device 0 on that card, please use:
  55 +
  56 + hw:3,0
  57 +
  58 +or
  59 +
  60 + plughw:3,0
  61 +
  62 +as the device_name.
  63 +)usage";
  64 + sherpa_onnx::ParseOptions po(kUsageMessage);
  65 + sherpa_onnx::KeywordSpotterConfig config;
  66 +
  67 + config.Register(&po);
  68 +
  69 + po.Read(argc, argv);
  70 + if (po.NumArgs() != 1) {
  71 + fprintf(stderr, "Please provide only 1 argument: the device name\n");
  72 + po.PrintUsage();
  73 + exit(EXIT_FAILURE);
  74 + }
  75 +
  76 + fprintf(stderr, "%s\n", config.ToString().c_str());
  77 +
  78 + if (!config.Validate()) {
  79 + fprintf(stderr, "Errors in config!\n");
  80 + return -1;
  81 + }
  82 + sherpa_onnx::KeywordSpotter spotter(config);
  83 +
  84 + int32_t expected_sample_rate = config.feat_config.sampling_rate;
  85 +
  86 + std::string device_name = po.GetArg(1);
  87 + sherpa_onnx::Alsa alsa(device_name.c_str());
  88 + fprintf(stderr, "Use recording device: %s\n", device_name.c_str());
  89 +
  90 + if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
  91 + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
  92 + expected_sample_rate);
  93 + exit(-1);
  94 + }
  95 +
  96 + int32_t chunk = 0.1 * alsa.GetActualSampleRate();
  97 +
  98 + std::string last_text;
  99 +
  100 + auto stream = spotter.CreateStream();
  101 +
  102 + sherpa_onnx::Display display;
  103 +
  104 + int32_t keyword_index = 0;
  105 + while (!stop) {
  106 + const std::vector<float> &samples = alsa.Read(chunk);
  107 +
  108 + stream->AcceptWaveform(expected_sample_rate, samples.data(),
  109 + samples.size());
  110 +
  111 + while (spotter.IsReady(stream.get())) {
  112 + spotter.DecodeStream(stream.get());
  113 + }
  114 +
  115 + const auto r = spotter.GetResult(stream.get());
  116 + if (!r.keyword.empty()) {
  117 + display.Print(keyword_index, r.AsJsonString());
  118 + fflush(stderr);
  119 + keyword_index++;
  120 + }
  121 + }
  122 +
  123 + return 0;
  124 +}
  1 +// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include <signal.h>
  6 +#include <stdio.h>
  7 +#include <stdlib.h>
  8 +
  9 +#include <algorithm>
  10 +
  11 +#include "portaudio.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/display.h"
  13 +#include "sherpa-onnx/csrc/microphone.h"
  14 +#include "sherpa-onnx/csrc/keyword-spotter.h"
  15 +
  16 +bool stop = false;
  17 +
  18 +static int32_t RecordCallback(const void *input_buffer,
  19 + void * /*output_buffer*/,
  20 + unsigned long frames_per_buffer, // NOLINT
  21 + const PaStreamCallbackTimeInfo * /*time_info*/,
  22 + PaStreamCallbackFlags /*status_flags*/,
  23 + void *user_data) {
  24 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(user_data);
  25 +
  26 + stream->AcceptWaveform(16000, reinterpret_cast<const float *>(input_buffer),
  27 + frames_per_buffer);
  28 +
  29 + return stop ? paComplete : paContinue;
  30 +}
  31 +
  32 +static void Handler(int32_t sig) {
  33 + stop = true;
  34 + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
  35 +}
  36 +
  37 +int32_t main(int32_t argc, char *argv[]) {
  38 + signal(SIGINT, Handler);
  39 +
  40 + const char *kUsageMessage = R"usage(
  41 +This program uses streaming models with microphone for keyword spotting.
  42 +Usage:
  43 +
  44 + ./bin/sherpa-onnx-keyword-spotter-microphone \
  45 + --tokens=/path/to/tokens.txt \
  46 + --encoder=/path/to/encoder.onnx \
  47 + --decoder=/path/to/decoder.onnx \
  48 + --joiner=/path/to/joiner.onnx \
  49 + --provider=cpu \
  50 + --num-threads=1 \
  51 + --keywords-file=keywords.txt
  52 +
  53 +Please refer to
  54 +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
  55 +for a list of pre-trained models to download.
  56 +)usage";
  57 +
  58 + sherpa_onnx::ParseOptions po(kUsageMessage);
  59 + sherpa_onnx::KeywordSpotterConfig config;
  60 +
  61 + config.Register(&po);
  62 + po.Read(argc, argv);
  63 + if (po.NumArgs() != 0) {
  64 + po.PrintUsage();
  65 + exit(EXIT_FAILURE);
  66 + }
  67 +
  68 + fprintf(stderr, "%s\n", config.ToString().c_str());
  69 +
  70 + if (!config.Validate()) {
  71 + fprintf(stderr, "Errors in config!\n");
  72 + return -1;
  73 + }
  74 +
  75 + sherpa_onnx::KeywordSpotter spotter(config);
  76 + auto s = spotter.CreateStream();
  77 +
  78 + sherpa_onnx::Microphone mic;
  79 +
  80 + PaDeviceIndex num_devices = Pa_GetDeviceCount();
  81 + fprintf(stderr, "Num devices: %d\n", num_devices);
  82 +
  83 + PaStreamParameters param;
  84 +
  85 + param.device = Pa_GetDefaultInputDevice();
  86 + if (param.device == paNoDevice) {
  87 + fprintf(stderr, "No default input device found\n");
  88 + exit(EXIT_FAILURE);
  89 + }
  90 + fprintf(stderr, "Use default device: %d\n", param.device);
  91 +
  92 + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
  93 + fprintf(stderr, " Name: %s\n", info->name);
  94 + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels);
  95 +
  96 + param.channelCount = 1;
  97 + param.sampleFormat = paFloat32;
  98 +
  99 + param.suggestedLatency = info->defaultLowInputLatency;
  100 + param.hostApiSpecificStreamInfo = nullptr;
  101 + float sample_rate = 16000;
  102 +
  103 + PaStream *stream;
  104 + PaError err =
  105 + Pa_OpenStream(&stream, &param, nullptr, /* &outputParameters, */
  106 + sample_rate,
  107 + 0, // frames per buffer
  108 + paClipOff, // we won't output out of range samples
  109 + // so don't bother clipping them
  110 + RecordCallback, s.get());
  111 + if (err != paNoError) {
  112 + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
  113 + exit(EXIT_FAILURE);
  114 + }
  115 +
  116 + err = Pa_StartStream(stream);
  117 + fprintf(stderr, "Started\n");
  118 +
  119 + if (err != paNoError) {
  120 + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
  121 + exit(EXIT_FAILURE);
  122 + }
  123 +
  124 + int32_t keyword_index = 0;
  125 + sherpa_onnx::Display display;
  126 + while (!stop) {
  127 + while (spotter.IsReady(s.get())) {
  128 + spotter.DecodeStream(s.get());
  129 + }
  130 +
  131 + const auto r = spotter.GetResult(s.get());
  132 + if (!r.keyword.empty()) {
  133 + display.Print(keyword_index, r.AsJsonString());
  134 + fflush(stderr);
  135 + keyword_index++;
  136 + }
  137 +
  138 + Pa_Sleep(20); // sleep for 20ms
  139 + }
  140 +
  141 + err = Pa_CloseStream(stream);
  142 + if (err != paNoError) {
  143 + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
  144 + exit(EXIT_FAILURE);
  145 + }
  146 +
  147 + return 0;
  148 +}
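
Design note on the loop above: the PortAudio callback thread only pushes
samples into the stream via AcceptWaveform, while the main thread polls every
20 ms (Pa_Sleep) and runs the actual decoding. Keeping the callback cheap like
this avoids doing neural-network work on the real-time audio thread.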
@@ -12,7 +12,6 @@
12 #include "sherpa-onnx/csrc/keyword-spotter.h"
13 #include "sherpa-onnx/csrc/online-stream.h"
14 #include "sherpa-onnx/csrc/parse-options.h"
15 -#include "sherpa-onnx/csrc/symbol-table.h"
15 #include "sherpa-onnx/csrc/wave-reader.h"
16
17 typedef struct {
@@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
5 display.cc
6 endpoint.cc
7 features.cc
  8 + keyword-spotter.cc
9 offline-ctc-fst-decoder-config.cc
10 offline-lm-config.cc
11 offline-model-config.cc
  1 +// sherpa-onnx/python/csrc/keyword-spotter.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/keyword-spotter.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/keyword-spotter.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +static void PybindKeywordResult(py::module *m) {
  15 + using PyClass = KeywordResult;
  16 + py::class_<PyClass>(*m, "KeywordResult")
  17 + .def_property_readonly(
  18 + "keyword",
  19 + [](PyClass &self) -> py::str {
  20 + return py::str(PyUnicode_DecodeUTF8(self.keyword.c_str(),
  21 + self.keyword.size(), "ignore"));
  22 + })
  23 + .def_property_readonly(
  24 + "tokens",
  25 + [](PyClass &self) -> std::vector<std::string> { return self.tokens; })
  26 + .def_property_readonly(
  27 + "timestamps",
  28 + [](PyClass &self) -> std::vector<float> { return self.timestamps; });
  29 +}
  30 +
  31 +static void PybindKeywordSpotterConfig(py::module *m) {
  32 + using PyClass = KeywordSpotterConfig;
  33 + py::class_<PyClass>(*m, "KeywordSpotterConfig")
  34 + .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
  35 + int32_t, int32_t, float, float, const std::string &>(),
  36 + py::arg("feat_config"), py::arg("model_config"),
  37 + py::arg("max_active_paths") = 4, py::arg("num_trailing_blanks") = 1,
  38 + py::arg("keywords_score") = 1.0,
  39 + py::arg("keywords_threshold") = 0.25, py::arg("keywords_file") = "")
  40 + .def_readwrite("feat_config", &PyClass::feat_config)
  41 + .def_readwrite("model_config", &PyClass::model_config)
  42 + .def_readwrite("max_active_paths", &PyClass::max_active_paths)
  43 + .def_readwrite("num_trailing_blanks", &PyClass::num_trailing_blanks)
  44 + .def_readwrite("keywords_score", &PyClass::keywords_score)
  45 + .def_readwrite("keywords_threshold", &PyClass::keywords_threshold)
  46 + .def_readwrite("keywords_file", &PyClass::keywords_file)
  47 + .def("__str__", &PyClass::ToString);
  48 +}
  49 +
  50 +void PybindKeywordSpotter(py::module *m) {
  51 + PybindKeywordResult(m);
  52 + PybindKeywordSpotterConfig(m);
  53 +
  54 + using PyClass = KeywordSpotter;
  55 + py::class_<PyClass>(*m, "KeywordSpotter")
  56 + .def(py::init<const KeywordSpotterConfig &>(), py::arg("config"),
  57 + py::call_guard<py::gil_scoped_release>())
  58 + .def(
  59 + "create_stream",
  60 + [](const PyClass &self) { return self.CreateStream(); },
  61 + py::call_guard<py::gil_scoped_release>())
  62 + .def(
  63 + "create_stream",
  64 + [](PyClass &self, const std::string &keywords) {
  65 + return self.CreateStream(keywords);
  66 + },
  67 + py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
  68 + .def("is_ready", &PyClass::IsReady,
  69 + py::call_guard<py::gil_scoped_release>())
  70 + .def("decode_stream", &PyClass::DecodeStream,
  71 + py::call_guard<py::gil_scoped_release>())
  72 + .def(
  73 + "decode_streams",
  74 + [](PyClass &self, std::vector<OnlineStream *> ss) {
  75 + self.DecodeStreams(ss.data(), ss.size());
  76 + },
  77 + py::call_guard<py::gil_scoped_release>())
  78 + .def("get_result", &PyClass::GetResult,
  79 + py::call_guard<py::gil_scoped_release>());
  80 +}
  81 +
  82 +} // namespace sherpa_onnx
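
A note on the bindings above: each method is wrapped with
py::call_guard<py::gil_scoped_release>, so the GIL is released while the C++
code runs. This lets one Python thread keep feeding audio into a stream while
another thread is blocked inside decode_stream or decode_streams.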
  1 +// sherpa-onnx/python/csrc/keyword-spotter.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindKeywordSpotter(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
@@ -8,6 +8,7 @@
8 #include "sherpa-onnx/python/csrc/display.h"
9 #include "sherpa-onnx/python/csrc/endpoint.h"
10 #include "sherpa-onnx/python/csrc/features.h"
  11 +#include "sherpa-onnx/python/csrc/keyword-spotter.h"
12 #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
13 #include "sherpa-onnx/python/csrc/offline-lm-config.h"
14 #include "sherpa-onnx/python/csrc/offline-model-config.h"
@@ -35,6 +36,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
36 PybindOnlineStream(&m);
37 PybindEndpoint(&m);
38 PybindOnlineRecognizer(&m);
  39 + PybindKeywordSpotter(&m);
40
41 PybindDisplay(&m);
42
@@ -17,6 +17,7 @@ from _sherpa_onnx import (
17 VoiceActivityDetector,
18 )
19
  20 +from .keyword_spotter import KeywordSpotter
21 from .offline_recognizer import OfflineRecognizer
22 from .online_recognizer import OnlineRecognizer
23 from .utils import text2token
  1 +# Copyright (c) 2023 Xiaomi Corporation
  2 +
  3 +from pathlib import Path
  4 +from typing import List, Optional
  5 +
  6 +from _sherpa_onnx import (
  7 + FeatureExtractorConfig,
  8 + KeywordSpotterConfig,
  9 + OnlineModelConfig,
  10 + OnlineTransducerModelConfig,
  11 + OnlineStream,
  12 +)
  13 +
  14 +from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
  15 +
  16 +
  17 +def _assert_file_exists(f: str):
  18 + assert Path(f).is_file(), f"{f} does not exist"
  19 +
  20 +
  21 +class KeywordSpotter(object):
  22 + """A class for keyword spotting.
  23 +
  24 + Please refer to the following files for usage examples
  25 + - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter.py
  26 + - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter-from-microphone.py
  27 + """
  28 +
  29 + def __init__(
  30 + self,
  31 + tokens: str,
  32 + encoder: str,
  33 + decoder: str,
  34 + joiner: str,
  35 + keywords_file: str,
  36 + num_threads: int = 2,
  37 + sample_rate: float = 16000,
  38 + feature_dim: int = 80,
  39 + max_active_paths: int = 4,
  40 + keywords_score: float = 1.0,
  41 + keywords_threshold: float = 0.25,
  42 + num_trailing_blanks: int = 1,
  43 + provider: str = "cpu",
  44 + ):
  45 + """
  46 + Please refer to
  47 + `<https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html>`_
  48 + to download pre-trained models for different languages, e.g., Chinese,
  49 + English, etc.
  50 +
  51 + Args:
  52 + tokens:
  53 + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
  54 + columns::
  55 +
  56 + symbol integer_id
  57 +
  58 + encoder:
  59 + Path to ``encoder.onnx``.
  60 + decoder:
  61 + Path to ``decoder.onnx``.
  62 + joiner:
  63 + Path to ``joiner.onnx``.
  64 + keywords_file:
  65 + The file containing keywords, one word/phrase per line, and for each
  66 + phrase the bpe/cjkchar/pinyin are separated by a space.
  67 + num_threads:
  68 + Number of threads for neural network computation.
  69 + sample_rate:
  70 + Sample rate of the training data used to train the model.
  71 + feature_dim:
  72 + Dimension of the feature used to train the model.
  73 + max_active_paths:
  74 + It specifies the maximum number of active paths to keep during
  75 + beam search.
  76 + keywords_score:
  77 + The boosting score of each token for keywords. The larger the score,
  78 + the easier it is for a keyword to survive beam search.
  79 + keywords_threshold:
  80 + The trigger threshold (i.e., probability) of the keyword. The larger
  81 + the threshold, the harder it is to trigger.
  82 + num_trailing_blanks:
  83 + The number of trailing blanks a keyword should be followed by. Set it
  84 + to a larger value (e.g. 8) when your keywords have overlapping tokens
  85 + with each other.
  86 + provider:
  87 + The onnxruntime execution provider. Valid values are: cpu, cuda, coreml.
  88 + """
  89 + _assert_file_exists(tokens)
  90 + _assert_file_exists(encoder)
  91 + _assert_file_exists(decoder)
  92 + _assert_file_exists(joiner)
  93 +
  94 + assert num_threads > 0, num_threads
  95 +
  96 + transducer_config = OnlineTransducerModelConfig(
  97 + encoder=encoder,
  98 + decoder=decoder,
  99 + joiner=joiner,
  100 + )
  101 +
  102 + model_config = OnlineModelConfig(
  103 + transducer=transducer_config,
  104 + tokens=tokens,
  105 + num_threads=num_threads,
  106 + provider=provider,
  107 + )
  108 +
  109 + feat_config = FeatureExtractorConfig(
  110 + sampling_rate=sample_rate,
  111 + feature_dim=feature_dim,
  112 + )
  113 +
  114 + keyword_spotter_config = KeywordSpotterConfig(
  115 + feat_config=feat_config,
  116 + model_config=model_config,
  117 + max_active_paths=max_active_paths,
  118 + num_trailing_blanks=num_trailing_blanks,
  119 + keywords_score=keywords_score,
  120 + keywords_threshold=keywords_threshold,
  121 + keywords_file=keywords_file,
  122 + )
  123 + self.keyword_spotter = _KeywordSpotter(keyword_spotter_config)
  124 +
  125 + def create_stream(self, keywords: Optional[str] = None):
  126 + if keywords is None:
  127 + return self.keyword_spotter.create_stream()
  128 + else:
  129 + return self.keyword_spotter.create_stream(keywords)
  130 +
  131 + def decode_stream(self, s: OnlineStream):
  132 + self.keyword_spotter.decode_stream(s)
  133 +
  134 + def decode_streams(self, ss: List[OnlineStream]):
  135 + self.keyword_spotter.decode_streams(ss)
  136 +
  137 + def is_ready(self, s: OnlineStream) -> bool:
  138 + return self.keyword_spotter.is_ready(s)
  139 +
  140 + def get_result(self, s: OnlineStream) -> str:
  141 + return self.keyword_spotter.get_result(s).keyword.strip()
  142 +
  143 + def tokens(self, s: OnlineStream) -> List[str]:
  144 + return self.keyword_spotter.get_result(s).tokens
  145 +
  146 + def timestamps(self, s: OnlineStream) -> List[float]:
  147 + return self.keyword_spotter.get_result(s).timestamps
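
The wrapper also exposes the per-stream keywords overload of create_stream
from the bindings above. A minimal sketch, assuming spotter is a constructed
KeywordSpotter and using the tokenized keyword from the help text as a
placeholder:

    # Spot a stream-specific keyword instead of the global keywords_file;
    # the string uses the same tokenized format as a keywords-file line.
    stream = spotter.create_stream(keywords="▁HE LL O ▁WORLD")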
@@ -20,6 +20,7 @@ endfunction()
20 # please sort the files in alphabetic order
21 set(py_test_files
22 test_feature_extractor_config.py
  23 + test_keyword_spotter.py
24 test_offline_recognizer.py
25 test_online_recognizer.py
26 test_online_transducer_model_config.py
  1 +# sherpa-onnx/python/tests/test_keyword_spotter.py
  2 +#
  3 +# Copyright (c) 2024 Xiaomi Corporation
  4 +#
  5 +# To run this single test, use
  6 +#
  7 +# ctest --verbose -R test_keyword_spotter_py
  8 +
  9 +import unittest
  10 +import wave
  11 +from pathlib import Path
  12 +from typing import Tuple
  13 +
  14 +import numpy as np
  15 +import sherpa_onnx
  16 +
  17 +d = "/tmp/onnx-models"
  18 +# Please refer to
  19 +# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
  20 +# to download pre-trained models for testing
  21 +
  22 +
  23 +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
  24 + """
  25 + Args:
  26 + wave_filename:
  27 + Path to a wave file. It should be single channel and each sample should
  28 + be 16-bit. Its sample rate does not need to be 16kHz.
  29 + Returns:
  30 + Return a tuple containing:
  31 + - A 1-D array of dtype np.float32 containing the samples, which are
  32 + normalized to the range [-1, 1].
  33 + - sample rate of the wave file
  34 + """
  35 +
  36 + with wave.open(wave_filename) as f:
  37 + assert f.getnchannels() == 1, f.getnchannels()
  38 + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  39 + num_samples = f.getnframes()
  40 + samples = f.readframes(num_samples)
  41 + samples_int16 = np.frombuffer(samples, dtype=np.int16)
  42 + samples_float32 = samples_int16.astype(np.float32)
  43 +
  44 + samples_float32 = samples_float32 / 32768
  45 + return samples_float32, f.getframerate()
  46 +
  47 +
  48 +class TestKeywordSpotter(unittest.TestCase):
  49 + def test_zipformer_transducer_en(self):
  50 + for use_int8 in [True, False]:
  51 + if use_int8:
  52 + encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
  53 + decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
  54 + joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
  55 + else:
  56 + encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx"
  57 + decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx"
  58 + joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx"
  59 +
  60 + tokens = (
  61 + f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt"
  62 + )
  63 + keywords_file = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
  64 + wave0 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav"
  65 + wave1 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav"
  66 +
  67 + if not Path(encoder).is_file():
  68 + print("skipping test_zipformer_transducer_en()")
  69 + return
  70 + keyword_spotter = sherpa_onnx.KeywordSpotter(
  71 + encoder=encoder,
  72 + decoder=decoder,
  73 + joiner=joiner,
  74 + tokens=tokens,
  75 + num_threads=1,
  76 + keywords_file=keywords_file,
  77 + provider="cpu",
  78 + )
  79 + streams = []
  80 + waves = [wave0, wave1]
  81 + for wave_filename in waves:
  82 + s = keyword_spotter.create_stream()
  83 + samples, sample_rate = read_wave(wave_filename)
  84 + s.accept_waveform(sample_rate, samples)
  85 +
  86 + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
  87 + s.accept_waveform(sample_rate, tail_paddings)
  88 + s.input_finished()
  89 + streams.append(s)
  90 +
  91 + results = [""] * len(streams)
  92 + while True:
  93 + ready_list = []
  94 + for i, s in enumerate(streams):
  95 + if keyword_spotter.is_ready(s):
  96 + ready_list.append(s)
  97 + r = keyword_spotter.get_result(s)
  98 + if r:
  99 + print(f"{r} is detected.")
  100 + results[i] += f"{r}/"
  101 + if len(ready_list) == 0:
  102 + break
  103 + keyword_spotter.decode_streams(ready_list)
  104 + for wave_filename, result in zip(waves, results):
  105 + print(f"{wave_filename}\n{result[0:-1]}")
  106 + print("-" * 10)
  107 +
  108 + def test_zipformer_transducer_cn(self):
  109 + for use_int8 in [True, False]:
  110 + if use_int8:
  111 + encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
  112 + decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
  113 + joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
  114 + else:
  115 + encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx"
  116 + decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx"
  117 + joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx"
  118 +
  119 + tokens = (
  120 + f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"
  121 + )
  122 + keywords_file = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
  123 + wave0 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
  124 + wave1 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav"
  125 + wave2 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/5.wav"
  126 +
  127 + if not Path(encoder).is_file():
  128 + print("skipping test_zipformer_transducer_cn()")
  129 + return
  130 + keyword_spotter = sherpa_onnx.KeywordSpotter(
  131 + encoder=encoder,
  132 + decoder=decoder,
  133 + joiner=joiner,
  134 + tokens=tokens,
  135 + num_threads=1,
  136 + keywords_file=keywords_file,
  137 + provider="cpu",
  138 + )
  139 + streams = []
  140 + waves = [wave0, wave1, wave2]
  141 + for wave_filename in waves:
  142 + s = keyword_spotter.create_stream()
  143 + samples, sample_rate = read_wave(wave_filename)
  144 + s.accept_waveform(sample_rate, samples)
  145 +
  146 + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
  147 + s.accept_waveform(sample_rate, tail_paddings)
  148 + s.input_finished()
  149 + streams.append(s)
  150 +
  151 + results = [""] * len(streams)
  152 + while True:
  153 + ready_list = []
  154 + for i, s in enumerate(streams):
  155 + if keyword_spotter.is_ready(s):
  156 + ready_list.append(s)
  157 + r = keyword_spotter.get_result(s)
  158 + if r:
  159 + print(f"{r} is detected.")
  160 + results[i] += f"{r}/"
  161 + if len(ready_list) == 0:
  162 + break
  163 + keyword_spotter.decode_streams(ready_list)
  164 + for wave_filename, result in zip(waves, results):
  165 + print(f"{wave_filename}\n{result[0:-1]}")
  166 + print("-" * 10)
  167 +
  168 +
  169 +if __name__ == "__main__":
  170 + unittest.main()