Wei Kang
Committed by GitHub

Add Python API for keyword spotting (#576)

* Add alsa & microphone support for keyword spotting

* Add python wrapper
... ... @@ -293,3 +293,61 @@ git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data
python3 sherpa-onnx/python/tests/test_text2token.py --verbose
rm -rf /tmp/sherpa-test-data
mkdir -p /tmp/onnx-models
dir=/tmp/onnx-models
log "Test keyword spotting models"
python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
sherpa_onnx_version=$(python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)")
echo "sherpa_onnx version: $sherpa_onnx_version"
pwd
ls -lh
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
log "Start testing ${repo}"
pushd $dir
wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz2
tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz2
popd
repo=$dir/$repo
ls -lh $repo
python3 ./python-api-examples/keyword-spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav
repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
log "Start testing ${repo}"
pushd $dir
wget -qq https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
tar xf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
popd
repo=$dir/$repo
ls -lh $repo
python3 ./python-api-examples/keyword-spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav \
$repo/test_wavs/5.wav
python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose
rm -r $dir
... ...
#!/usr/bin/env python3
# Real-time keyword spotting from a microphone with sherpa-onnx Python API
#
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
# to download pre-trained models
import argparse
import sys
from pathlib import Path
from typing import List
try:
import sounddevice as sd
except ImportError:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)
import sherpa_onnx
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
)
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--encoder",
type=str,
help="Path to the transducer encoder model",
)
parser.add_argument(
"--decoder",
type=str,
help="Path to the transducer decoder model",
)
parser.add_argument(
"--joiner",
type=str,
help="Path to the transducer joiner model",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)
parser.add_argument(
"--max-active-paths",
type=int,
default=4,
help="""
        It specifies the number of active paths to keep during decoding.
""",
)
parser.add_argument(
"--num-trailing-blanks",
type=int,
default=1,
help="""The number of trailing blanks a keyword should be followed. Setting
to a larger value (e.g. 8) when your keywords has overlapping tokens
between each other.
""",
)
parser.add_argument(
"--keywords-file",
type=str,
help="""
        The file containing keywords, one word/phrase per line. For each phrase,
        the bpe/cjkchar/pinyin tokens are separated by spaces. For example:
▁HE LL O ▁WORLD
x iǎo ài t óng x ué
""",
)
parser.add_argument(
"--keywords-score",
type=float,
default=1.0,
help="""
        The boosting score for each token of a keyword. The larger the score,
        the easier it is for the keyword to survive beam search.
""",
)
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.25,
help="""
        The trigger threshold (i.e., probability) of the keyword. The larger
        the threshold, the harder it is to trigger the keyword.
""",
)
return parser.parse_args()
def main():
args = get_args()
devices = sd.query_devices()
if len(devices) == 0:
print("No microphone devices found")
sys.exit(0)
print(devices)
default_input_device_idx = sd.default.device[0]
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
assert_file_exists(args.tokens)
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert Path(
args.keywords_file
).is_file(), (
f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
)
keyword_spotter = sherpa_onnx.KeywordSpotter(
tokens=args.tokens,
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=args.num_threads,
max_active_paths=args.max_active_paths,
keywords_file=args.keywords_file,
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
        num_trailing_blanks=args.num_trailing_blanks,
provider=args.provider,
)
print("Started! Please speak")
sample_rate = 16000
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
stream = keyword_spotter.create_stream()
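    # The loop below is the core streaming pattern: read ~100 ms of audio,
    # feed it to the stream, and decode whenever the spotter has enough
    # frames. get_result() returns a non-empty string once a keyword fires.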
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
samples, _ = s.read(samples_per_read) # a blocking read
samples = samples.reshape(-1)
stream.accept_waveform(sample_rate, samples)
while keyword_spotter.is_ready(stream):
keyword_spotter.decode_stream(stream)
result = keyword_spotter.get_result(stream)
if result:
print("\r{}".format(result), end="", flush=True)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")
... ...
#!/usr/bin/env python3
"""
This file demonstrates how to use sherpa-onnx Python API to do keyword spotting
from wave file(s).
Please refer to
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
to download pre-trained models.
"""
import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple
import numpy as np
import sherpa_onnx
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--encoder",
type=str,
help="Path to the transducer encoder model",
)
parser.add_argument(
"--decoder",
type=str,
help="Path to the transducer decoder model",
)
parser.add_argument(
"--joiner",
type=str,
help="Path to the transducer joiner model",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)
parser.add_argument(
"--max-active-paths",
type=int,
default=4,
help="""
        It specifies the number of active paths to keep during decoding.
""",
)
parser.add_argument(
"--num-trailing-blanks",
type=int,
default=1,
help="""The number of trailing blanks a keyword should be followed. Setting
to a larger value (e.g. 8) when your keywords has overlapping tokens
between each other.
""",
)
parser.add_argument(
"--keywords-file",
type=str,
help="""
        The file containing keywords, one word/phrase per line. For each phrase,
        the bpe/cjkchar/pinyin tokens are separated by spaces. For example:
▁HE LL O ▁WORLD
x iǎo ài t óng x ué
""",
)
parser.add_argument(
"--keywords-score",
type=float,
default=1.0,
help="""
        The boosting score for each token of a keyword. The larger the score,
        the easier it is for the keyword to survive beam search.
""",
)
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.25,
help="""
        The trigger threshold (i.e., probability) of the keyword. The larger
        the threshold, the harder it is to trigger the keyword.
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to decode. Each file must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
return parser.parse_args()
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with wave.open(wave_filename) as f:
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)
samples_float32 = samples_float32 / 32768
return samples_float32, f.getframerate()
def main():
args = get_args()
assert_file_exists(args.tokens)
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert Path(
args.keywords_file
).is_file(), (
f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
)
keyword_spotter = sherpa_onnx.KeywordSpotter(
tokens=args.tokens,
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=args.num_threads,
max_active_paths=args.max_active_paths,
keywords_file=args.keywords_file,
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_trailing_blanks=args.num_trailing_blanks,
provider=args.provider,
)
print("Started!")
start_time = time.time()
streams = []
total_duration = 0
for wave_filename in args.sound_files:
assert_file_exists(wave_filename)
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
total_duration += duration
s = keyword_spotter.create_stream()
s.accept_waveform(sample_rate, samples)
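        # Append some tail padding (silence) so the model's right context is
        # flushed and a keyword near the end of the file can still fire.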
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
streams.append(s)
results = [""] * len(streams)
while True:
ready_list = []
for i, s in enumerate(streams):
if keyword_spotter.is_ready(s):
ready_list.append(s)
r = keyword_spotter.get_result(s)
if r:
results[i] += f"{r}/"
print(f"{r} is detected.")
if len(ready_list) == 0:
break
keyword_spotter.decode_streams(ready_list)
end_time = time.time()
print("Done!")
for wave_filename, result in zip(args.sound_files, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
if __name__ == "__main__":
main()
... ...
... ... @@ -230,12 +230,14 @@ endif()
if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc)
add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc)
add_executable(sherpa-onnx-offline-tts-play-alsa sherpa-onnx-offline-tts-play-alsa.cc alsa-play.cc)
add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc)
add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc)
set(exes
sherpa-onnx-alsa
sherpa-onnx-keyword-spotter-alsa
sherpa-onnx-alsa-offline
sherpa-onnx-offline-tts-play-alsa
sherpa-onnx-alsa-offline-speaker-identification
... ... @@ -278,6 +280,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
microphone.cc
)
add_executable(sherpa-onnx-keyword-spotter-microphone
sherpa-onnx-keyword-spotter-microphone.cc
microphone.cc
)
add_executable(sherpa-onnx-microphone
sherpa-onnx-microphone.cc
microphone.cc
... ... @@ -311,6 +318,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
set(exes
sherpa-onnx-microphone
sherpa-onnx-keyword-spotter-microphone
sherpa-onnx-microphone-offline
sherpa-onnx-microphone-offline-speaker-identification
sherpa-onnx-offline-tts-play
... ...
// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-alsa.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <cstdint>
#include "sherpa-onnx/csrc/alsa.h"
#include "sherpa-onnx/csrc/display.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/parse-options.h"
bool stop = false;
static void Handler(int sig) {
stop = true;
fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
}
int main(int32_t argc, char *argv[]) {
signal(SIGINT, Handler);
const char *kUsageMessage = R"usage(
Usage:
./bin/sherpa-onnx-keyword-spotter-alsa \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--provider=cpu \
--num-threads=2 \
--keywords-file=keywords.txt \
device_name
Please refer to
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
for a list of pre-trained models to download.
The device name specifies which microphone to use in case there are several
on your system. You can use
arecord -l
to find all available microphones on your computer. For instance, if it outputs
**** List of CAPTURE Hardware Devices ****
card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
Subdevices: 1/1
Subdevice #0: subdevice #0
and if you want to select card 3 and device 0 on that card, please use:
hw:3,0
or
plughw:3,0
as the device_name.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::KeywordSpotterConfig config;
config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() != 1) {
fprintf(stderr, "Please provide only 1 argument: the device name\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
sherpa_onnx::KeywordSpotter spotter(config);
int32_t expected_sample_rate = config.feat_config.sampling_rate;
std::string device_name = po.GetArg(1);
sherpa_onnx::Alsa alsa(device_name.c_str());
fprintf(stderr, "Use recording device: %s\n", device_name.c_str());
if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
expected_sample_rate);
exit(-1);
}
int32_t chunk = 0.1 * alsa.GetActualSampleRate();
std::string last_text;
auto stream = spotter.CreateStream();
sherpa_onnx::Display display;
int32_t keyword_index = 0;
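  // Main loop: read ~100 ms of samples from ALSA per iteration, feed the
  // spotter, and print each detected keyword as a JSON string.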
while (!stop) {
const std::vector<float> &samples = alsa.Read(chunk);
stream->AcceptWaveform(expected_sample_rate, samples.data(),
samples.size());
while (spotter.IsReady(stream.get())) {
spotter.DecodeStream(stream.get());
}
const auto r = spotter.GetResult(stream.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
}
}
return 0;
}
... ...
// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter-microphone.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include "portaudio.h" // NOLINT
#include "sherpa-onnx/csrc/display.h"
#include "sherpa-onnx/csrc/microphone.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"
bool stop = false;
static int32_t RecordCallback(const void *input_buffer,
void * /*output_buffer*/,
unsigned long frames_per_buffer, // NOLINT
const PaStreamCallbackTimeInfo * /*time_info*/,
PaStreamCallbackFlags /*status_flags*/,
void *user_data) {
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(user_data);
stream->AcceptWaveform(16000, reinterpret_cast<const float *>(input_buffer),
frames_per_buffer);
return stop ? paComplete : paContinue;
}
static void Handler(int32_t sig) {
stop = true;
fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
}
int32_t main(int32_t argc, char *argv[]) {
signal(SIGINT, Handler);
const char *kUsageMessage = R"usage(
This program uses a streaming model with a microphone for keyword spotting.
Usage:
./bin/sherpa-onnx-keyword-spotter-microphone \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--provider=cpu \
--num-threads=1 \
--keywords-file=keywords.txt
Please refer to
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
for a list of pre-trained models to download.
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::KeywordSpotterConfig config;
config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() != 0) {
po.PrintUsage();
exit(EXIT_FAILURE);
}
fprintf(stderr, "%s\n", config.ToString().c_str());
if (!config.Validate()) {
fprintf(stderr, "Errors in config!\n");
return -1;
}
sherpa_onnx::KeywordSpotter spotter(config);
auto s = spotter.CreateStream();
sherpa_onnx::Microphone mic;
PaDeviceIndex num_devices = Pa_GetDeviceCount();
fprintf(stderr, "Num devices: %d\n", num_devices);
PaStreamParameters param;
param.device = Pa_GetDefaultInputDevice();
if (param.device == paNoDevice) {
fprintf(stderr, "No default input device found\n");
exit(EXIT_FAILURE);
}
fprintf(stderr, "Use default device: %d\n", param.device);
const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
fprintf(stderr, " Name: %s\n", info->name);
fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels);
param.channelCount = 1;
param.sampleFormat = paFloat32;
param.suggestedLatency = info->defaultLowInputLatency;
param.hostApiSpecificStreamInfo = nullptr;
float sample_rate = 16000;
PaStream *stream;
PaError err =
Pa_OpenStream(&stream, &param, nullptr, /* &outputParameters, */
sample_rate,
0, // frames per buffer
paClipOff, // we won't output out of range samples
// so don't bother clipping them
RecordCallback, s.get());
if (err != paNoError) {
fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
exit(EXIT_FAILURE);
}
err = Pa_StartStream(stream);
fprintf(stderr, "Started\n");
if (err != paNoError) {
fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
exit(EXIT_FAILURE);
}
int32_t keyword_index = 0;
sherpa_onnx::Display display;
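  // Audio arrives on PortAudio's callback thread (RecordCallback); this
  // loop just polls every 20 ms and decodes whatever frames are ready.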
while (!stop) {
while (spotter.IsReady(s.get())) {
spotter.DecodeStream(s.get());
}
const auto r = spotter.GetResult(s.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
}
Pa_Sleep(20); // sleep for 20ms
}
err = Pa_CloseStream(stream);
if (err != paNoError) {
fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
exit(EXIT_FAILURE);
}
return 0;
}
... ...
... ... @@ -12,7 +12,6 @@
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
typedef struct {
... ...
... ... @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
display.cc
endpoint.cc
features.cc
keyword-spotter.cc
offline-ctc-fst-decoder-config.cc
offline-lm-config.cc
offline-model-config.cc
... ...
// sherpa-onnx/python/csrc/keyword-spotter.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/keyword-spotter.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/keyword-spotter.h"
namespace sherpa_onnx {
static void PybindKeywordResult(py::module *m) {
using PyClass = KeywordResult;
py::class_<PyClass>(*m, "KeywordResult")
.def_property_readonly(
"keyword",
[](PyClass &self) -> py::str {
return py::str(PyUnicode_DecodeUTF8(self.keyword.c_str(),
self.keyword.size(), "ignore"));
})
.def_property_readonly(
"tokens",
[](PyClass &self) -> std::vector<std::string> { return self.tokens; })
.def_property_readonly(
"timestamps",
[](PyClass &self) -> std::vector<float> { return self.timestamps; });
}
static void PybindKeywordSpotterConfig(py::module *m) {
using PyClass = KeywordSpotterConfig;
py::class_<PyClass>(*m, "KeywordSpotterConfig")
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
int32_t, int32_t, float, float, const std::string &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("max_active_paths") = 4, py::arg("num_trailing_blanks") = 1,
py::arg("keywords_score") = 1.0,
py::arg("keywords_threshold") = 0.25, py::arg("keywords_file") = "")
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("num_trailing_blanks", &PyClass::num_trailing_blanks)
.def_readwrite("keywords_score", &PyClass::keywords_score)
.def_readwrite("keywords_threshold", &PyClass::keywords_threshold)
.def_readwrite("keywords_file", &PyClass::keywords_file)
.def("__str__", &PyClass::ToString);
}
void PybindKeywordSpotter(py::module *m) {
PybindKeywordResult(m);
PybindKeywordSpotterConfig(m);
using PyClass = KeywordSpotter;
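  // Note: py::call_guard<py::gil_scoped_release> releases the GIL while the
  // C++ code runs, so other Python threads (e.g., an audio-capture thread)
  // can keep running during decoding.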
py::class_<PyClass>(*m, "KeywordSpotter")
.def(py::init<const KeywordSpotterConfig &>(), py::arg("config"),
py::call_guard<py::gil_scoped_release>())
.def(
"create_stream",
[](const PyClass &self) { return self.CreateStream(); },
py::call_guard<py::gil_scoped_release>())
.def(
"create_stream",
[](PyClass &self, const std::string &keywords) {
return self.CreateStream(keywords);
},
py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
.def("is_ready", &PyClass::IsReady,
py::call_guard<py::gil_scoped_release>())
.def("decode_stream", &PyClass::DecodeStream,
py::call_guard<py::gil_scoped_release>())
.def(
"decode_streams",
[](PyClass &self, std::vector<OnlineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
},
py::call_guard<py::gil_scoped_release>())
.def("get_result", &PyClass::GetResult,
py::call_guard<py::gil_scoped_release>());
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/keyword-spotter.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
#define SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindKeywordSpotter(py::module *m);
}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_
... ...
... ... @@ -8,6 +8,7 @@
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/keyword-spotter.h"
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
#include "sherpa-onnx/python/csrc/offline-model-config.h"
... ... @@ -35,6 +36,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOnlineStream(&m);
PybindEndpoint(&m);
PybindOnlineRecognizer(&m);
PybindKeywordSpotter(&m);
PybindDisplay(&m);
... ...
... ... @@ -17,6 +17,7 @@ from _sherpa_onnx import (
VoiceActivityDetector,
)
from .keyword_spotter import KeywordSpotter
from .offline_recognizer import OfflineRecognizer
from .online_recognizer import OnlineRecognizer
from .utils import text2token
... ...
# Copyright (c) 2023 Xiaomi Corporation
from pathlib import Path
from typing import List, Optional
from _sherpa_onnx import (
FeatureExtractorConfig,
KeywordSpotterConfig,
OnlineModelConfig,
OnlineTransducerModelConfig,
OnlineStream,
)
from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
def _assert_file_exists(f: str):
assert Path(f).is_file(), f"{f} does not exist"
class KeywordSpotter(object):
"""A class for keyword spotting.
Please refer to the following files for usages
- https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter.py
- https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/keyword-spotter-from-microphone.py
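
    A minimal usage sketch (the model paths below are placeholders; see the
    examples above for complete programs)::

        spotter = KeywordSpotter(
            tokens="tokens.txt",
            encoder="encoder.onnx",
            decoder="decoder.onnx",
            joiner="joiner.onnx",
            keywords_file="keywords.txt",
        )
        stream = spotter.create_stream()
        # samples: a 1-D float32 numpy array normalized to [-1, 1]
        stream.accept_waveform(16000, samples)
        while spotter.is_ready(stream):
            spotter.decode_stream(stream)
        print(spotter.get_result(stream))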
"""
def __init__(
self,
tokens: str,
encoder: str,
decoder: str,
joiner: str,
keywords_file: str,
num_threads: int = 2,
sample_rate: float = 16000,
feature_dim: int = 80,
max_active_paths: int = 4,
keywords_score: float = 1.0,
keywords_threshold: float = 0.25,
num_trailing_blanks: int = 1,
provider: str = "cpu",
):
"""
Please refer to
`<https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html>`_
to download pre-trained models for different languages, e.g., Chinese,
English, etc.
Args:
tokens:
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
columns::
symbol integer_id
encoder:
Path to ``encoder.onnx``.
decoder:
Path to ``decoder.onnx``.
joiner:
Path to ``joiner.onnx``.
keywords_file:
            The file containing keywords, one word/phrase per line. For each
            phrase, the bpe/cjkchar/pinyin tokens are separated by spaces.
num_threads:
Number of threads for neural network computation.
sample_rate:
Sample rate of the training data used to train the model.
feature_dim:
Dimension of the feature used to train the model.
          max_active_paths:
            The maximum number of active paths to keep during beam search.
          keywords_score:
            The boosting score for each token of a keyword. The larger the
            score, the easier it is for the keyword to survive beam search.
          keywords_threshold:
            The trigger threshold (i.e., probability) of the keyword. The
            larger the threshold, the harder it is to trigger the keyword.
          num_trailing_blanks:
            The number of trailing blanks a keyword should be followed by.
            Set it to a larger value (e.g., 8) if your keywords have tokens
            that overlap with each other.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
_assert_file_exists(tokens)
_assert_file_exists(encoder)
_assert_file_exists(decoder)
_assert_file_exists(joiner)
assert num_threads > 0, num_threads
transducer_config = OnlineTransducerModelConfig(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
model_config = OnlineModelConfig(
transducer=transducer_config,
tokens=tokens,
num_threads=num_threads,
provider=provider,
)
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
)
keywords_spotter_config = KeywordSpotterConfig(
feat_config=feat_config,
model_config=model_config,
max_active_paths=max_active_paths,
num_trailing_blanks=num_trailing_blanks,
keywords_score=keywords_score,
keywords_threshold=keywords_threshold,
keywords_file=keywords_file,
)
self.keyword_spotter = _KeywordSpotter(keywords_spotter_config)
def create_stream(self, keywords: Optional[str] = None):
if keywords is None:
return self.keyword_spotter.create_stream()
else:
return self.keyword_spotter.create_stream(keywords)
def decode_stream(self, s: OnlineStream):
self.keyword_spotter.decode_stream(s)
def decode_streams(self, ss: List[OnlineStream]):
self.keyword_spotter.decode_streams(ss)
def is_ready(self, s: OnlineStream) -> bool:
return self.keyword_spotter.is_ready(s)
def get_result(self, s: OnlineStream) -> str:
return self.keyword_spotter.get_result(s).keyword.strip()
def tokens(self, s: OnlineStream) -> List[str]:
return self.keyword_spotter.get_result(s).tokens
def timestamps(self, s: OnlineStream) -> List[float]:
return self.keyword_spotter.get_result(s).timestamps
... ...
... ... @@ -20,6 +20,7 @@ endfunction()
# please sort the files in alphabetic order
set(py_test_files
test_feature_extractor_config.py
test_keyword_spotter.py
test_offline_recognizer.py
test_online_recognizer.py
test_online_transducer_model_config.py
... ...
# sherpa-onnx/python/tests/test_keyword_spotter.py
#
# Copyright (c) 2024 Xiaomi Corporation
#
# To run this single test, use
#
# ctest --verbose -R test_keyword_spotter_py
import unittest
import wave
from pathlib import Path
from typing import Tuple
import numpy as np
import sherpa_onnx
d = "/tmp/onnx-models"
# Please refer to
# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
# to download pre-trained models for testing
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""
with wave.open(wave_filename) as f:
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)
samples_float32 = samples_float32 / 32768
return samples_float32, f.getframerate()
class TestKeywordSpotter(unittest.TestCase):
def test_zipformer_transducer_en(self):
for use_int8 in [True, False]:
if use_int8:
encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
            else:
                encoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx"
                decoder = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx"
                joiner = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx"
tokens = (
f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt"
)
keywords_file = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
wave0 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav"
wave1 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav"
if not Path(encoder).is_file():
print("skipping test_zipformer_transducer_en()")
return
keyword_spotter = sherpa_onnx.KeywordSpotter(
encoder=encoder,
decoder=decoder,
joiner=joiner,
tokens=tokens,
num_threads=1,
keywords_file=keywords_file,
provider="cpu",
)
streams = []
waves = [wave0, wave1]
for wave in waves:
s = keyword_spotter.create_stream()
samples, sample_rate = read_wave(wave)
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
streams.append(s)
results = [""] * len(streams)
while True:
ready_list = []
for i, s in enumerate(streams):
if keyword_spotter.is_ready(s):
ready_list.append(s)
r = keyword_spotter.get_result(s)
if r:
print(f"{r} is detected.")
results[i] += f"{r}/"
if len(ready_list) == 0:
break
keyword_spotter.decode_streams(ready_list)
for wave_filename, result in zip(waves, results):
print(f"{wave_filename}\n{result[0:-1]}")
print("-" * 10)
def test_zipformer_transducer_cn(self):
for use_int8 in [True, False]:
if use_int8:
encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"
            else:
                encoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx"
                decoder = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx"
                joiner = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx"
tokens = (
f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"
)
keywords_file = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
wave0 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
wave1 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav"
wave2 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/5.wav"
if not Path(encoder).is_file():
print("skipping test_zipformer_transducer_cn()")
return
keyword_spotter = sherpa_onnx.KeywordSpotter(
encoder=encoder,
decoder=decoder,
joiner=joiner,
tokens=tokens,
num_threads=1,
keywords_file=keywords_file,
provider="cpu",
)
streams = []
waves = [wave0, wave1, wave2]
for wave in waves:
s = keyword_spotter.create_stream()
samples, sample_rate = read_wave(wave)
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
streams.append(s)
results = [""] * len(streams)
while True:
ready_list = []
for i, s in enumerate(streams):
if keyword_spotter.is_ready(s):
ready_list.append(s)
r = keyword_spotter.get_result(s)
if r:
print(f"{r} is detected.")
results[i] += f"{r}/"
if len(ready_list) == 0:
break
keyword_spotter.decode_streams(ready_list)
for wave_filename, result in zip(waves, results):
print(f"{wave_filename}\n{result[0:-1]}")
print("-" * 10)
if __name__ == "__main__":
unittest.main()
... ...