Fangjun Kuang
Committed by GitHub

Fix keyword spotting. (#1689)

Reset the stream right after detecting a keyword
正在显示 43 个修改的文件 包含 823 行增加303 行删除
... ... @@ -574,29 +574,6 @@ echo "sherpa_onnx version: $sherpa_onnx_version"
pwd
ls -lh
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
log "Start testing ${repo}"
pushd $dir
curl -LS -O https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
rm sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz
popd
repo=$dir/$repo
ls -lh $repo
python3 ./python-api-examples/keyword-spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav
rm -rf $repo
if [[ x$OS != x'windows-latest' ]]; then
echo "OS: $OS"
... ... @@ -612,15 +589,7 @@ if [[ x$OS != x'windows-latest' ]]; then
repo=$dir/$repo
ls -lh $repo
python3 ./python-api-examples/keyword-spotter.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
--keywords-file=$repo/test_wavs/test_keywords.txt \
$repo/test_wavs/3.wav \
$repo/test_wavs/4.wav \
$repo/test_wavs/5.wav
python3 ./python-api-examples/keyword-spotter.py
python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose
... ...
... ... @@ -79,6 +79,27 @@ jobs:
otool -L ./install/lib/libsherpa-onnx-c-api.dylib
fi
- name: Test kws (zh)
shell: bash
run: |
gcc -o kws-c-api ./c-api-examples/kws-c-api.c \
-I ./build/install/include \
-L ./build/install/lib/ \
-l sherpa-onnx-c-api \
-l onnxruntime
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
./kws-c-api
rm ./kws-c-api
rm -rf sherpa-onnx-kws-*
- name: Test Kokoro TTS (en)
shell: bash
run: |
... ...
... ... @@ -81,6 +81,28 @@ jobs:
otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib
fi
- name: Test KWS (zh)
shell: bash
run: |
g++ -std=c++17 -o kws-cxx-api ./cxx-api-examples/kws-cxx-api.cc \
-I ./build/install/include \
-L ./build/install/lib/ \
-l sherpa-onnx-cxx-api \
-l sherpa-onnx-c-api \
-l onnxruntime
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
./kws-cxx-api
rm kws-cxx-api
rm -rf sherpa-onnx-kws-*
- name: Test Kokoro TTS (en)
shell: bash
run: |
... ...
... ... @@ -151,24 +151,27 @@ class MainActivity : AppCompatActivity() {
stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
while (kws.isReady(stream)) {
kws.decode(stream)
}
val text = kws.getResult(stream).keyword
val text = kws.getResult(stream).keyword
var textToDisplay = lastText
var textToDisplay = lastText
if (text.isNotBlank()) {
// Remember to reset the stream right after detecting a keyword
if (text.isNotBlank()) {
if (lastText.isBlank()) {
textToDisplay = "$idx: $text"
} else {
textToDisplay = "$idx: $text\n$lastText"
kws.reset(stream)
if (lastText.isBlank()) {
textToDisplay = "$idx: $text"
} else {
textToDisplay = "$idx: $text\n$lastText"
}
lastText = "$idx: $text\n$lastText"
idx += 1
}
lastText = "$idx: $text\n$lastText"
idx += 1
}
runOnUiThread {
textView.text = textToDisplay
runOnUiThread {
textView.text = textToDisplay
}
}
}
}
... ...
... ... @@ -4,6 +4,9 @@ include_directories(${CMAKE_SOURCE_DIR})
add_executable(decode-file-c-api decode-file-c-api.c)
target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs)
add_executable(kws-c-api kws-c-api.c)
target_link_libraries(kws-c-api sherpa-onnx-c-api)
if(SHERPA_ONNX_ENABLE_TTS)
add_executable(offline-tts-c-api offline-tts-c-api.c)
target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs)
... ...
// c-api-examples/kws-c-api.c
//
// Copyright (c) 2025 Xiaomi Corporation
//
// This file demonstrates how to use the keyword spotter with sherpa-onnx's C
// API.
// clang-format off
//
// Usage
//
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
//
// ./kws-c-api
//
// clang-format on
#include <stdio.h>
#include <stdlib.h> // exit
#include <string.h> // memset
#include "sherpa-onnx/c-api/c-api.h"
// Feeds `wave` followed by `tail_paddings` into `stream`, decodes until the
// stream is no longer ready, and prints every detected keyword to stderr.
// The stream is destroyed before returning.
//
// Returns 0 on success and -1 if `stream` is NULL (i.e. it could not be
// created by the caller).
static int32_t DetectKeywords(const SherpaOnnxKeywordSpotter *kws,
                              const SherpaOnnxOnlineStream *stream,
                              const SherpaOnnxWave *wave,
                              const float *tail_paddings,
                              int32_t num_tail_paddings) {
  if (!stream) {
    fprintf(stderr, "Failed to create stream\n");
    return -1;
  }

  SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
                                       wave->num_samples);
  SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
                                       num_tail_paddings);
  SherpaOnnxOnlineStreamInputFinished(stream);

  while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
    SherpaOnnxDecodeKeywordStream(kws, stream);

    const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
    // Check every pointer that is dereferenced below, not only r->json.
    if (r && r->json && r->keyword && r->keyword[0] != '\0') {
      fprintf(stderr, "Detected keyword: %s\n", r->json);

      // Remember to reset the keyword stream right after a keyword is detected
      SherpaOnnxResetKeywordStream(kws, stream);
    }
    SherpaOnnxDestroyKeywordResult(r);
  }

  SherpaOnnxDestroyOnlineStream(stream);
  return 0;
}

// Runs the keyword spotter three times on the same test wave:
//   1. with the keywords from test_keywords.txt only,
//   2. with one extra keyword added at stream creation,
//   3. with two extra keywords added at stream creation.
//
// Returns 0 on success and -1 if the spotter, the wave, or a stream cannot be
// created. All acquired resources are released on every path.
int32_t main() {
  SherpaOnnxKeywordSpotterConfig config;

  memset(&config, 0, sizeof(config));

  // NOTE(review): the usage text at the top of this file downloads the
  // "-mobile" archive; confirm that it extracts to the directory used below.
  config.model_config.transducer.encoder =
      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
      "encoder-epoch-12-avg-2-chunk-16-left-64.onnx";

  config.model_config.transducer.decoder =
      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
      "decoder-epoch-12-avg-2-chunk-16-left-64.onnx";

  config.model_config.transducer.joiner =
      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
      "joiner-epoch-12-avg-2-chunk-16-left-64.onnx";

  config.model_config.tokens =
      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";

  config.model_config.provider = "cpu";
  config.model_config.num_threads = 1;
  config.model_config.debug = 1;

  config.keywords_file =
      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
      "test_keywords.txt";

  const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&config);
  if (!kws) {
    fprintf(stderr, "Please check your config\n");
    return -1;
  }

  fprintf(stderr,
          "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n");

  const char *wav_filename =
      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";

  // Trailing silence appended after the wave; 8000 samples is 0.5 seconds at
  // a 16 kHz sample rate.
  float tail_paddings[8000] = {0};
  const int32_t num_tail_paddings =
      (int32_t)(sizeof(tail_paddings) / sizeof(tail_paddings[0]));

  const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
  if (wave == NULL) {
    fprintf(stderr, "Failed to read %s\n", wav_filename);
    SherpaOnnxDestroyKeywordSpotter(kws);
    return -1;
  }

  int32_t ret = DetectKeywords(kws, SherpaOnnxCreateKeywordStream(kws), wave,
                               tail_paddings, num_tail_paddings);

  // --------------------------------------------------------------------------

  if (ret == 0) {
    fprintf(stderr, "--Use pre-defined keywords + add a new keyword--\n");

    ret = DetectKeywords(
        kws, SherpaOnnxCreateKeywordStreamWithKeywords(kws, "y ǎn y uán @演员"),
        wave, tail_paddings, num_tail_paddings);
  }

  // --------------------------------------------------------------------------

  if (ret == 0) {
    fprintf(stderr, "--Use pre-defined keywords + add two new keywords--\n");

    ret = DetectKeywords(kws,
                         SherpaOnnxCreateKeywordStreamWithKeywords(
                             kws, "y ǎn y uán @演员/zh ī m íng @知名"),
                         wave, tail_paddings, num_tail_paddings);
  }

  SherpaOnnxFreeWave(wave);
  SherpaOnnxDestroyKeywordSpotter(kws);

  return ret;
}
... ...
... ... @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)
add_executable(kws-cxx-api ./kws-cxx-api.cc)
target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api)
add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc)
target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api)
... ...
// cxx-api-examples/kws-cxx-api.cc
//
// Copyright (c) 2025 Xiaomi Corporation
//
// This file demonstrates how to use the keyword spotter with sherpa-onnx's
// C++ API.
// clang-format off
//
// Usage
//
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
//
// ./kws-cxx-api
//
// clang-format on
#include <array>
#include <iostream>
#include "sherpa-onnx/c-api/cxx-api.h"
int32_t main() {
using namespace sherpa_onnx::cxx; // NOLINT
KeywordSpotterConfig config;
config.model_config.transducer.encoder =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
"encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
config.model_config.transducer.decoder =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
"decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
config.model_config.transducer.joiner =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
"joiner-epoch-12-avg-2-chunk-16-left-64.onnx";
config.model_config.tokens =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";
config.model_config.provider = "cpu";
config.model_config.num_threads = 1;
config.model_config.debug = 1;
config.keywords_file =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
"test_keywords.txt";
KeywordSpotter kws = KeywordSpotter::Create(config);
if (!kws.Get()) {
std::cerr << "Please check your config\n";
return -1;
}
std::cout
<< "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n";
std::string wave_filename =
"./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";
std::array<float, 8000> tail_paddings = {0}; // 0.5 seconds
Wave wave = ReadWave(wave_filename);
if (wave.samples.empty()) {
std::cerr << "Failed to read: '" << wave_filename << "'\n";
return -1;
}
OnlineStream stream = kws.CreateStream();
if (!stream.Get()) {
std::cerr << "Failed to create stream\n";
return -1;
}
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
wave.samples.size());
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
tail_paddings.size());
stream.InputFinished();
while (kws.IsReady(&stream)) {
kws.Decode(&stream);
auto r = kws.GetResult(&stream);
if (!r.keyword.empty()) {
std::cout << "Detected keyword: " << r.json << "\n";
// Remember to reset the keyword stream right after a keyword is detected
kws.Reset(&stream);
}
}
// --------------------------------------------------------------------------
std::cout << "--Use pre-defined keywords + add a new keyword--\n";
stream = kws.CreateStream("y ǎn y uán @演员");
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
wave.samples.size());
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
tail_paddings.size());
stream.InputFinished();
while (kws.IsReady(&stream)) {
kws.Decode(&stream);
auto r = kws.GetResult(&stream);
if (!r.keyword.empty()) {
std::cout << "Detected keyword: " << r.json << "\n";
// Remember to reset the keyword stream right after a keyword is detected
kws.Reset(&stream);
}
}
// --------------------------------------------------------------------------
std::cout << "--Use pre-defined keywords + add two new keywords--\n";
stream = kws.CreateStream("y ǎn y uán @演员/zh ī m íng @知名");
stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
wave.samples.size());
stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
tail_paddings.size());
stream.InputFinished();
while (kws.IsReady(&stream)) {
kws.Decode(&stream);
auto r = kws.GetResult(&stream);
if (!r.keyword.empty()) {
std::cout << "Detected keyword: " << r.json << "\n";
// Remember to reset the keyword stream right after a keyword is detected
kws.Reset(&stream);
}
}
return 0;
}
... ...
... ... @@ -73,6 +73,8 @@ void main(List<String> arguments) async {
spotter.decode(stream);
final result = spotter.getResult(stream);
if (result.keyword != '') {
// Remember to reset the stream right after detecting a keyword
spotter.reset(stream);
print('Detected: ${result.keyword}');
}
}
... ...
... ... @@ -53,6 +53,8 @@ class KeywordSpotterDemo
var result = kws.GetResult(s);
if (result.Keyword != string.Empty)
{
// Remember to call Reset() right after detecting a keyword
kws.Reset(s);
Console.WriteLine("Detected: {0}", result.Keyword);
}
}
... ... @@ -70,6 +72,8 @@ class KeywordSpotterDemo
var result = kws.GetResult(s);
if (result.Keyword != string.Empty)
{
// Remember to call Reset() right after detecting a keyword
kws.Reset(s);
Console.WriteLine("Detected: {0}", result.Keyword);
}
}
... ... @@ -89,6 +93,8 @@ class KeywordSpotterDemo
var result = kws.GetResult(s);
if (result.Keyword != string.Empty)
{
// Remember to call Reset() right after detecting a keyword
kws.Reset(s);
Console.WriteLine("Detected: {0}", result.Keyword);
}
}
... ...
... ... @@ -107,12 +107,15 @@ class KeywordSpotterDemo
while (kws.IsReady(s))
{
kws.Decode(s);
}
var result = kws.GetResult(s);
if (result.Keyword != string.Empty)
{
Console.WriteLine("Detected: {0}", result.Keyword);
var result = kws.GetResult(s);
if (result.Keyword != string.Empty)
{
// Remember to call Reset() right after detecting a keyword
kws.Reset(s);
Console.WriteLine("Detected: {0}", result.Keyword);
}
}
Thread.Sleep(200); // ms
... ...
... ... @@ -168,6 +168,10 @@ class KeywordSpotter {
SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr);
}
/// Resets [stream] by invoking the native `resetKeywordStream` binding with
/// this spotter's pointer and the stream's pointer.
///
/// The call goes through `?.call`, so nothing happens if the binding has not
/// been looked up yet.
void reset(OnlineStream stream) {
SherpaOnnxBindings.resetKeywordStream?.call(ptr, stream.ptr);
}
Pointer<SherpaOnnxKeywordSpotter> ptr;
KeywordSpotterConfig config;
}
... ...
... ... @@ -667,6 +667,12 @@ typedef DecodeKeywordStreamNative = Void Function(
typedef DecodeKeywordStream = void Function(
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
typedef ResetKeywordStreamNative = Void Function(
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
typedef ResetKeywordStream = void Function(
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
typedef GetKeywordResultAsJsonNative = Pointer<Utf8> Function(
Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
... ... @@ -1157,6 +1163,7 @@ class SherpaOnnxBindings {
static CreateKeywordStreamWithKeywords? createKeywordStreamWithKeywords;
static IsKeywordStreamReady? isKeywordStreamReady;
static DecodeKeywordStream? decodeKeywordStream;
static ResetKeywordStream? resetKeywordStream;
static GetKeywordResultAsJson? getKeywordResultAsJson;
static FreeKeywordResultJson? freeKeywordResultJson;
... ... @@ -1459,6 +1466,11 @@ class SherpaOnnxBindings {
'SherpaOnnxDecodeKeywordStream')
.asFunction();
resetKeywordStream ??= dynamicLibrary
.lookup<NativeFunction<ResetKeywordStreamNative>>(
'SherpaOnnxResetKeywordStream')
.asFunction();
getKeywordResultAsJson ??= dynamicLibrary
.lookup<NativeFunction<GetKeywordResultAsJsonNative>>(
'SherpaOnnxGetKeywordResultAsJson')
... ...
... ... @@ -43,6 +43,8 @@ func main() {
spotter.Decode(stream)
result := spotter.GetResult(stream)
if result.Keyword != "" {
// You have to reset the stream right after detecting a keyword
spotter.Reset(stream)
log.Printf("Detected %v\n", result.Keyword)
}
}
... ...
... ... @@ -46,7 +46,7 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper(
SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf);
SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize);
SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c);
const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c);
if (c.model_config.transducer.encoder) {
delete[] c.model_config.transducer.encoder;
... ... @@ -100,7 +100,8 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper(
}
return Napi::External<SherpaOnnxKeywordSpotter>::New(
env, kws, [](Napi::Env env, SherpaOnnxKeywordSpotter *kws) {
env, const_cast<SherpaOnnxKeywordSpotter *>(kws),
[](Napi::Env env, SherpaOnnxKeywordSpotter *kws) {
SherpaOnnxDestroyKeywordSpotter(kws);
});
}
... ... @@ -125,13 +126,14 @@ static Napi::External<SherpaOnnxOnlineStream> CreateKeywordStreamWrapper(
return {};
}
SherpaOnnxKeywordSpotter *kws =
const SherpaOnnxKeywordSpotter *kws =
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
return Napi::External<SherpaOnnxOnlineStream>::New(
env, stream, [](Napi::Env env, SherpaOnnxOnlineStream *stream) {
env, const_cast<SherpaOnnxOnlineStream *>(stream),
[](Napi::Env env, SherpaOnnxOnlineStream *stream) {
SherpaOnnxDestroyOnlineStream(stream);
});
}
... ... @@ -162,10 +164,10 @@ static Napi::Boolean IsKeywordStreamReadyWrapper(
return {};
}
SherpaOnnxKeywordSpotter *kws =
const SherpaOnnxKeywordSpotter *kws =
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
SherpaOnnxOnlineStream *stream =
const SherpaOnnxOnlineStream *stream =
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
int32_t is_ready = SherpaOnnxIsKeywordStreamReady(kws, stream);
... ... @@ -198,15 +200,49 @@ static void DecodeKeywordStreamWrapper(const Napi::CallbackInfo &info) {
return;
}
SherpaOnnxKeywordSpotter *kws =
const SherpaOnnxKeywordSpotter *kws =
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
SherpaOnnxOnlineStream *stream =
const SherpaOnnxOnlineStream *stream =
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
SherpaOnnxDecodeKeywordStream(kws, stream);
}
// Node-API wrapper around SherpaOnnxResetKeywordStream().
//
// Expects exactly two arguments, both Externals: the keyword spotter and the
// online stream. If the argument count or either argument type is wrong, a
// JavaScript TypeError is thrown and the stream is left untouched.
static void ResetKeywordStreamWrapper(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  const auto num_args = info.Length();
  if (num_args != 2) {
    std::ostringstream message;
    message << "Expect only 2 arguments. Given: " << num_args;

    Napi::TypeError::New(env, message.str()).ThrowAsJavaScriptException();
    return;
  }

  // Select the error text for the first argument that is not an External;
  // argument 0 is checked before argument 1.
  const char *type_error = nullptr;
  if (!info[0].IsExternal()) {
    type_error = "Argument 0 should be a keyword spotter pointer.";
  } else if (!info[1].IsExternal()) {
    type_error = "Argument 1 should be an online stream pointer.";
  }

  if (type_error != nullptr) {
    Napi::TypeError::New(env, type_error).ThrowAsJavaScriptException();
    return;
  }

  const SherpaOnnxKeywordSpotter *spotter =
      info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();

  const SherpaOnnxOnlineStream *online_stream =
      info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();

  SherpaOnnxResetKeywordStream(spotter, online_stream);
}
static Napi::String GetKeywordResultAsJsonWrapper(
const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
... ... @@ -233,10 +269,10 @@ static Napi::String GetKeywordResultAsJsonWrapper(
return {};
}
SherpaOnnxKeywordSpotter *kws =
const SherpaOnnxKeywordSpotter *kws =
info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
SherpaOnnxOnlineStream *stream =
const SherpaOnnxOnlineStream *stream =
info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
const char *json = SherpaOnnxGetKeywordResultAsJson(kws, stream);
... ... @@ -261,6 +297,9 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports) {
exports.Set(Napi::String::New(env, "decodeKeywordStream"),
Napi::Function::New(env, DecodeKeywordStreamWrapper));
exports.Set(Napi::String::New(env, "resetKeywordStream"),
Napi::Function::New(env, ResetKeywordStreamWrapper));
exports.Set(Napi::String::New(env, "getKeywordResultAsJson"),
Napi::Function::New(env, GetKeywordResultAsJsonWrapper));
}
... ...
... ... @@ -56,6 +56,8 @@ public class KyewordSpotterFromFile {
String keyword = kws.getResult(stream).getKeyword();
if (!keyword.isEmpty()) {
// Remember to reset the stream right after detecting a keyword
kws.reset(stream);
System.out.printf("Detected keyword: %s\n", keyword);
}
}
... ...
... ... @@ -41,6 +41,9 @@ while (kws.isReady(stream)) {
const keyword = kws.getResult(stream).keyword;
if (keyword != '') {
detectedKeywords.push(keyword);
// remember to reset the stream right after detecting a keyword
kws.reset(stream);
}
}
console.log(detectedKeywords);
... ...
... ... @@ -169,6 +169,8 @@ def main():
print("Started! Please speak")
idx = 0
sample_rate = 16000
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
stream = keyword_spotter.create_stream()
... ... @@ -179,9 +181,12 @@ def main():
stream.accept_waveform(sample_rate, samples)
while keyword_spotter.is_ready(stream):
keyword_spotter.decode_stream(stream)
result = keyword_spotter.get_result(stream)
if result:
print("\r{}".format(result), end="", flush=True)
result = keyword_spotter.get_result(stream)
if result:
print(f"{idx}: {result }")
idx += 1
# Remember to reset stream right after detecting a keyword
keyword_spotter.reset_stream(stream)
if __name__ == "__main__":
... ...
... ... @@ -18,122 +18,6 @@ import numpy as np
import sherpa_onnx
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--encoder",
type=str,
help="Path to the transducer encoder model",
)
parser.add_argument(
"--decoder",
type=str,
help="Path to the transducer decoder model",
)
parser.add_argument(
"--joiner",
type=str,
help="Path to the transducer joiner model",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)
parser.add_argument(
"--max-active-paths",
type=int,
default=4,
help="""
It specifies number of active paths to keep during decoding.
""",
)
parser.add_argument(
"--num-trailing-blanks",
type=int,
default=1,
help="""The number of trailing blanks a keyword should be followed. Setting
to a larger value (e.g. 8) when your keywords has overlapping tokens
between each other.
""",
)
parser.add_argument(
"--keywords-file",
type=str,
help="""
The file containing keywords, one words/phrases per line, and for each
phrase the bpe/cjkchar/pinyin are separated by a space. For example:
▁HE LL O ▁WORLD
x iǎo ài t óng x ué
""",
)
parser.add_argument(
"--keywords-score",
type=float,
default=1.0,
help="""
The boosting score of each token for keywords. The larger the easier to
survive beam search.
""",
)
parser.add_argument(
"--keywords-threshold",
type=float,
default=0.25,
help="""
The trigger threshold (i.e. probability) of the keyword. The larger the
harder to trigger.
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to decode. Each file must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
return parser.parse_args()
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
... ... @@ -159,83 +43,74 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
return samples_float32, f.getframerate()
def main():
args = get_args()
assert_file_exists(args.tokens)
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert Path(
args.keywords_file
).is_file(), (
f"keywords_file : {args.keywords_file} not exist, please provide a valid path."
def create_keyword_spotter():
kws = sherpa_onnx.KeywordSpotter(
tokens="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt",
encoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
decoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
joiner="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
num_threads=2,
keywords_file="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt",
provider="cpu",
)
keyword_spotter = sherpa_onnx.KeywordSpotter(
tokens=args.tokens,
encoder=args.encoder,
decoder=args.decoder,
joiner=args.joiner,
num_threads=args.num_threads,
max_active_paths=args.max_active_paths,
keywords_file=args.keywords_file,
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_trailing_blanks=args.num_trailing_blanks,
provider=args.provider,
)
return kws
print("Started!")
start_time = time.time()
streams = []
total_duration = 0
for wave_filename in args.sound_files:
assert_file_exists(wave_filename)
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
total_duration += duration
s = keyword_spotter.create_stream()
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
streams.append(s)
results = [""] * len(streams)
while True:
ready_list = []
for i, s in enumerate(streams):
if keyword_spotter.is_ready(s):
ready_list.append(s)
r = keyword_spotter.get_result(s)
if r:
results[i] += f"{r}/"
print(f"{r} is detected.")
if len(ready_list) == 0:
break
keyword_spotter.decode_streams(ready_list)
end_time = time.time()
print("Done!")
for wave_filename, result in zip(args.sound_files, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
def main():
    """Run the keyword spotter on one test wave with three keyword setups.

    The same wave (followed by 0.66 s of zero padding) is decoded three times:
    with the pre-defined keywords only, with one extra keyword, and with two
    extra keywords. Every detected keyword is printed, and the stream is reset
    right after each detection.
    """
    kws = create_keyword_spotter()

    wave_filename = (
        "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
    )
    samples, sample_rate = read_wave(wave_filename)
    tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)

    # (banner to print, extra keywords). None means the stream is created
    # without arguments, i.e. with the pre-defined keywords only.
    runs = [
        ("----------Use pre-defined keywords----------", None),
        (
            "----------Use pre-defined keywords + add a new keyword----------",
            "y ǎn y uán @演员",
        ),
        (
            "----------Use pre-defined keywords + add 2 new keywords----------",
            "y ǎn y uán @演员/zh ī m íng @知名",
        ),
    ]

    for banner, extra_keywords in runs:
        print(banner)

        if extra_keywords is None:
            s = kws.create_stream()
        else:
            s = kws.create_stream(extra_keywords)

        for chunk in (samples, tail_paddings):
            s.accept_waveform(sample_rate, chunk)
        s.input_finished()

        while kws.is_ready(s):
            kws.decode_stream(s)
            r = kws.get_result(s)
            if r != "":
                # Remember to call reset right after a keyword is detected
                kws.reset_stream(s)

                print(f"Detected {r}")
if __name__ == "__main__":
... ...
... ... @@ -46,6 +46,11 @@ namespace SherpaOnnx
Decode(_handle.Handle, stream.Handle);
}
// Resets the given stream by passing this spotter's native handle and the
// stream's native handle to the SherpaOnnxResetKeywordStream entry point.
public void Reset(OnlineStream stream)
{
Reset(_handle.Handle, stream.Handle);
}
// The caller should ensure all passed streams are ready for decoding.
public void Decode(IEnumerable<OnlineStream> streams)
{
... ... @@ -110,6 +115,9 @@ namespace SherpaOnnx
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeKeywordStream")]
private static extern void Decode(IntPtr handle, IntPtr stream);
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxResetKeywordStream")]
private static extern void Reset(IntPtr handle, IntPtr stream);
[DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeMultipleKeywordStreams")]
private static extern void Decode(IntPtr handle, IntPtr[] streams, int n);
... ...
... ... @@ -1584,6 +1584,11 @@ func (spotter *KeywordSpotter) Decode(s *OnlineStream) {
C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl)
}
// Reset resets the stream s by calling C.SherpaOnnxResetKeywordStream with
// the spotter's and the stream's native handles.
//
// You MUST call it right after detecting a keyword.
func (spotter *KeywordSpotter) Reset(s *OnlineStream) {
C.SherpaOnnxResetKeywordStream(spotter.impl, s.impl)
}
// Get the current result of stream since the last invoke of Reset()
func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult {
p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl)
... ...
... ... @@ -20,6 +20,10 @@ class KeywordSpotter {
addon.decodeKeywordStream(this.handle, stream.handle);
}
// Resets `stream` by calling the native addon's resetKeywordStream() with
// this spotter's handle and the stream's handle.
reset(stream) {
addon.resetKeywordStream(this.handle, stream.handle);
}
getResult(stream) {
const jsonStr = addon.getKeywordResultAsJson(this.handle, stream.handle);
... ...
... ... @@ -678,7 +678,7 @@ struct SherpaOnnxKeywordSpotter {
std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
};
SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
const SherpaOnnxKeywordSpotterConfig *config) {
sherpa_onnx::KeywordSpotterConfig spotter_config;
... ... @@ -755,37 +755,42 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
return spotter;
}
void SherpaOnnxDestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) {
void SherpaOnnxDestroyKeywordSpotter(const SherpaOnnxKeywordSpotter *spotter) {
delete spotter;
}
SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
const SherpaOnnxKeywordSpotter *spotter) {
SherpaOnnxOnlineStream *stream =
new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
return stream;
}
SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords(
const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords(
const SherpaOnnxKeywordSpotter *spotter, const char *keywords) {
SherpaOnnxOnlineStream *stream =
new SherpaOnnxOnlineStream(spotter->impl->CreateStream(keywords));
return stream;
}
int32_t SherpaOnnxIsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,
SherpaOnnxOnlineStream *stream) {
// Returns the result of the implementation's IsReady() for `stream`,
// converted to int32_t.
int32_t SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter,
                                       const SherpaOnnxOnlineStream *stream) {
  return spotter->impl->IsReady(stream->impl.get());
}
void SherpaOnnxDecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,
SherpaOnnxOnlineStream *stream) {
return spotter->impl->DecodeStream(stream->impl.get());
// Runs the implementation's DecodeStream() on the stream wrapped by `stream`.
void SherpaOnnxDecodeKeywordStream(const SherpaOnnxKeywordSpotter *spotter,
                                   const SherpaOnnxOnlineStream *stream) {
  spotter->impl->DecodeStream(stream->impl.get());
}
void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
SherpaOnnxOnlineStream **streams,
int32_t n) {
// Resets the stream wrapped by `stream` by forwarding it to the
// implementation's Reset().
void SherpaOnnxResetKeywordStream(const SherpaOnnxKeywordSpotter *spotter,
                                  const SherpaOnnxOnlineStream *stream) {
  spotter->impl->Reset(stream->impl.get());
}
void SherpaOnnxDecodeMultipleKeywordStreams(
const SherpaOnnxKeywordSpotter *spotter,
const SherpaOnnxOnlineStream **streams, int32_t n) {
std::vector<sherpa_onnx::OnlineStream *> ss(n);
for (int32_t i = 0; i != n; ++i) {
ss[i] = streams[i]->impl.get();
... ... @@ -794,7 +799,8 @@ void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
}
const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) {
const SherpaOnnxKeywordSpotter *spotter,
const SherpaOnnxOnlineStream *stream) {
const sherpa_onnx::KeywordResult &result =
spotter->impl->GetResult(stream->impl.get());
const auto &keyword = result.keyword;
... ... @@ -869,8 +875,9 @@ void SherpaOnnxDestroyKeywordResult(const SherpaOnnxKeywordResult *r) {
}
}
const char *SherpaOnnxGetKeywordResultAsJson(SherpaOnnxKeywordSpotter *spotter,
SherpaOnnxOnlineStream *stream) {
const char *SherpaOnnxGetKeywordResultAsJson(
const SherpaOnnxKeywordSpotter *spotter,
const SherpaOnnxOnlineStream *stream) {
const sherpa_onnx::KeywordResult &result =
spotter->impl->GetResult(stream->impl.get());
... ...
... ... @@ -600,7 +600,7 @@ SHERPA_ONNX_API const char *SherpaOnnxGetOfflineStreamResultAsJson(
SHERPA_ONNX_API void SherpaOnnxDestroyOfflineStreamResultJson(const char *s);
// ============================================================
// For Keyword Spot
// For Keyword Spotter
// ============================================================
SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
/// The triggered keyword.
... ... @@ -660,21 +660,21 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
/// @param config Config for the keyword spotter.
/// @return Return a pointer to the spotter. The user has to invoke
/// SherpaOnnxDestroyKeywordSpotter() to free it to avoid memory leak.
SHERPA_ONNX_API SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
SHERPA_ONNX_API const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
const SherpaOnnxKeywordSpotterConfig *config);
/// Free a pointer returned by SherpaOnnxCreateKeywordSpotter()
///
/// @param p A pointer returned by SherpaOnnxCreateKeywordSpotter()
SHERPA_ONNX_API void SherpaOnnxDestroyKeywordSpotter(
SherpaOnnxKeywordSpotter *spotter);
const SherpaOnnxKeywordSpotter *spotter);
/// Create an online stream for accepting wave samples.
///
/// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter()
/// @return Return a pointer to an OnlineStream. The user has to invoke
/// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
SHERPA_ONNX_API const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
const SherpaOnnxKeywordSpotter *spotter);
/// Create an online stream for accepting wave samples with the specified hot
... ... @@ -684,7 +684,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
/// @param keywords A pointer points to the keywords that you set
/// @return Return a pointer to an OnlineStream. The user has to invoke
/// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
SHERPA_ONNX_API SherpaOnnxOnlineStream *
SHERPA_ONNX_API const SherpaOnnxOnlineStream *
SherpaOnnxCreateKeywordStreamWithKeywords(
const SherpaOnnxKeywordSpotter *spotter, const char *keywords);
... ... @@ -693,15 +693,22 @@ SherpaOnnxCreateKeywordStreamWithKeywords(
///
/// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter
/// @param stream A pointer returned by SherpaOnnxCreateKeywordStream
SHERPA_ONNX_API int32_t SherpaOnnxIsKeywordStreamReady(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
SHERPA_ONNX_API int32_t
SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter,
const SherpaOnnxOnlineStream *stream);
/// Call this function to run the neural network model and decoding.
///
/// Precondition for this function: SherpaOnnxIsKeywordStreamReady() MUST
/// return 1.
SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
const SherpaOnnxKeywordSpotter *spotter,
const SherpaOnnxOnlineStream *stream);
/// Reset a keyword stream.
///
/// Please call it right after a keyword is detected
///
/// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter
/// @param stream A pointer returned by SherpaOnnxCreateKeywordStream
SHERPA_ONNX_API void SherpaOnnxResetKeywordStream(
    const SherpaOnnxKeywordSpotter *spotter,
    const SherpaOnnxOnlineStream *stream);
/// This function is similar to SherpaOnnxDecodeKeywordStream(). It decodes
/// multiple OnlineStream in parallel.
... ... @@ -714,8 +721,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream(
/// SherpaOnnxCreateKeywordStream()
/// @param n Number of elements in the given streams array.
SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,
int32_t n);
const SherpaOnnxKeywordSpotter *spotter,
const SherpaOnnxOnlineStream **streams, int32_t n);
/// Get the decoding results so far for an OnlineStream.
///
... ... @@ -725,7 +732,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams(
/// SherpaOnnxDestroyKeywordResult() to free the returned pointer to
/// avoid memory leak.
SHERPA_ONNX_API const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
const SherpaOnnxKeywordSpotter *spotter,
const SherpaOnnxOnlineStream *stream);
/// Destroy the pointer returned by SherpaOnnxGetKeywordResult().
///
... ... @@ -736,7 +744,8 @@ SHERPA_ONNX_API void SherpaOnnxDestroyKeywordResult(
// the user has to call SherpaOnnxFreeKeywordResultJson() to free the returned
// pointer to avoid memory leak
SHERPA_ONNX_API const char *SherpaOnnxGetKeywordResultAsJson(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
const SherpaOnnxKeywordSpotter *spotter,
const SherpaOnnxOnlineStream *stream);
SHERPA_ONNX_API void SherpaOnnxFreeKeywordResultJson(const char *s);
... ...
... ... @@ -391,4 +391,112 @@ GeneratedAudio OfflineTts::Generate(const std::string &text,
return ans;
}
// Builds a KeywordSpotter from `config`.
//
// The fields of `config` are copied one by one into a zero-filled C-API
// SherpaOnnxKeywordSpotterConfig, which is then passed to
// SherpaOnnxCreateKeywordSpotter(). String fields are handed over as c_str()
// pointers into `config`.
KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) {
  struct SherpaOnnxKeywordSpotterConfig c;
  memset(&c, 0, sizeof(c));

  // Feature extraction settings.
  c.feat_config.sample_rate = config.feat_config.sample_rate;
  c.feat_config.feature_dim = config.feat_config.feature_dim;

  // Model settings.
  const auto &from = config.model_config;
  auto &to = c.model_config;

  to.transducer.encoder = from.transducer.encoder.c_str();
  to.transducer.decoder = from.transducer.decoder.c_str();
  to.transducer.joiner = from.transducer.joiner.c_str();

  to.paraformer.encoder = from.paraformer.encoder.c_str();
  to.paraformer.decoder = from.paraformer.decoder.c_str();

  to.zipformer2_ctc.model = from.zipformer2_ctc.model.c_str();

  to.tokens = from.tokens.c_str();
  to.num_threads = from.num_threads;
  to.provider = from.provider.c_str();
  to.debug = from.debug;
  to.model_type = from.model_type.c_str();
  to.modeling_unit = from.modeling_unit.c_str();
  to.bpe_vocab = from.bpe_vocab.c_str();
  to.tokens_buf = from.tokens_buf.c_str();
  to.tokens_buf_size = from.tokens_buf.size();

  // Keyword spotting settings.
  c.max_active_paths = config.max_active_paths;
  c.num_trailing_blanks = config.num_trailing_blanks;
  c.keywords_score = config.keywords_score;
  c.keywords_threshold = config.keywords_threshold;
  c.keywords_file = config.keywords_file.c_str();

  return KeywordSpotter(SherpaOnnxCreateKeywordSpotter(&c));
}
// Wraps the C-API handle `p` by passing it to the MoveOnly base class.
KeywordSpotter::KeywordSpotter(const SherpaOnnxKeywordSpotter *p)
    : MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter>(p) {}
// Frees the C-API handle `p` via SherpaOnnxDestroyKeywordSpotter().
// NOTE(review): presumably invoked by the MoveOnly base class - confirm.
void KeywordSpotter::Destroy(const SherpaOnnxKeywordSpotter *p) const {
  SherpaOnnxDestroyKeywordSpotter(p);
}
// Creates a stream with SherpaOnnxCreateKeywordStream() and returns it
// wrapped in an OnlineStream.
OnlineStream KeywordSpotter::CreateStream() const {
  return OnlineStream{SherpaOnnxCreateKeywordStream(p_)};
}
// Creates a stream with SherpaOnnxCreateKeywordStreamWithKeywords(), passing
// `keywords` as a C string, and returns it wrapped in an OnlineStream.
OnlineStream KeywordSpotter::CreateStream(const std::string &keywords) const {
  return OnlineStream{
      SherpaOnnxCreateKeywordStreamWithKeywords(p_, keywords.c_str())};
}
// Returns true when SherpaOnnxIsKeywordStreamReady() reports a non-zero value
// for the stream `s`.
bool KeywordSpotter::IsReady(const OnlineStream *s) const {
  const auto ready = SherpaOnnxIsKeywordStreamReady(p_, s->Get());
  return ready != 0;
}
// Decodes a single stream via SherpaOnnxDecodeKeywordStream().
void KeywordSpotter::Decode(const OnlineStream *s) const {
  SherpaOnnxDecodeKeywordStream(p_, s->Get());
}
// Decodes `n` streams in one call.
//
// @param ss Pointer to the first of `n` contiguous OnlineStream objects.
// @param n  Number of streams in `ss`. Nothing is done when n <= 0.
//
// The C-API handle of each stream is collected into a temporary array that
// is passed to SherpaOnnxDecodeMultipleKeywordStreams().
void KeywordSpotter::Decode(const OnlineStream *ss, int32_t n) const {
  if (n <= 0) {
    return;
  }

  std::vector<const SherpaOnnxOnlineStream *> streams(n);
  // Advance the index, not the count: incrementing `n` here would never
  // terminate and would leave all but the first handle unset.
  for (int32_t i = 0; i != n; ++i) {
    streams[i] = ss[i].Get();
  }

  SherpaOnnxDecodeMultipleKeywordStreams(p_, streams.data(), n);
}
// Fetches the current result for stream `s` from the C API and copies it
// into a KeywordResult. The C-API result is released with
// SherpaOnnxDestroyKeywordResult() before returning.
KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const {
  const SherpaOnnxKeywordResult *c_result =
      SherpaOnnxGetKeywordResult(p_, s->Get());

  KeywordResult result;
  result.keyword = c_result->keyword;

  // One std::string per C string in tokens_arr.
  result.tokens.reserve(c_result->count);
  for (int32_t i = 0; i < c_result->count; ++i) {
    result.tokens.emplace_back(c_result->tokens_arr[i]);
  }

  // Timestamps are copied only when the C result provides them.
  if (c_result->timestamps) {
    result.timestamps.assign(c_result->timestamps,
                             c_result->timestamps + c_result->count);
  }

  result.start_time = c_result->start_time;
  result.json = c_result->json;

  SherpaOnnxDestroyKeywordResult(c_result);

  return result;
}
// Resets the stream `s` via SherpaOnnxResetKeywordStream().
void KeywordSpotter::Reset(const OnlineStream *s) const {
  SherpaOnnxResetKeywordStream(p_, s->Get());
}
} // namespace sherpa_onnx::cxx
... ...
... ... @@ -406,6 +406,53 @@ class SHERPA_ONNX_API OfflineTts
explicit OfflineTts(const SherpaOnnxOfflineTts *p);
};
// ============================================================
// For Keyword Spotter
// ============================================================
// Result returned by KeywordSpotter::GetResult(). Every field is a copy of
// the field with the matching name in the C-API SherpaOnnxKeywordResult.
struct KeywordResult {
  // The triggered keyword.
  std::string keyword;

  // One entry per element of the C result's tokens_arr.
  std::vector<std::string> tokens;

  // Filled only when the C result provides timestamps; empty otherwise.
  std::vector<float> timestamps;

  // Zero-initialized so that a default-constructed result never carries an
  // indeterminate value.
  float start_time = 0;

  // The C result's json string.
  std::string json;
};
// Configuration consumed by KeywordSpotter::Create(), which copies each
// field into the matching field of the C-API SherpaOnnxKeywordSpotterConfig.
struct KeywordSpotterConfig {
  // Feature extraction settings.
  FeatureConfig feat_config;

  // Model files and runtime settings.
  OnlineModelConfig model_config;

  // NOTE(review): presumably the number of active paths kept during
  // decoding - confirm against the C API documentation.
  int32_t max_active_paths = 4;

  int32_t num_trailing_blanks = 1;

  float keywords_score = 1.0f;

  float keywords_threshold = 0.25f;

  // NOTE(review): presumably the path of a file listing the keywords -
  // confirm.
  std::string keywords_file;
};
// C++ wrapper around a C-API SherpaOnnxKeywordSpotter handle.
//
// Instances are obtained with Create(); the constructor taking a raw handle
// is private.
class SHERPA_ONNX_API KeywordSpotter
    : public MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter> {
 public:
  // Creates a spotter from `config`.
  static KeywordSpotter Create(const KeywordSpotterConfig &config);

  // Frees the C-API handle `p`.
  void Destroy(const SherpaOnnxKeywordSpotter *p) const;

  // Creates a new stream for this spotter.
  OnlineStream CreateStream() const;

  // Creates a new stream, passing `keywords` to the C API.
  OnlineStream CreateStream(const std::string &keywords) const;

  // Returns whether the C API reports the stream `s` as ready.
  bool IsReady(const OnlineStream *s) const;

  // Decodes a single stream.
  void Decode(const OnlineStream *s) const;

  // Decodes `n` streams stored contiguously starting at `ss`.
  void Decode(const OnlineStream *ss, int32_t n) const;

  // Resets the stream `s`.
  void Reset(const OnlineStream *s) const;

  // Returns a copy of the current result for the stream `s`.
  KeywordResult GetResult(const OnlineStream *s) const;

 private:
  explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p);
};
} // namespace sherpa_onnx::cxx
#endif // SHERPA_ONNX_C_API_CXX_API_H_
... ...
... ... @@ -38,6 +38,8 @@ class KeywordSpotterImpl {
virtual bool IsReady(OnlineStream *s) const = 0;
virtual void Reset(OnlineStream *s) const = 0;
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
... ...
... ... @@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
return s->GetNumProcessedFrames() + model_->ChunkSize() <
s->NumFramesReady();
}
// Resets `s` by re-running InitOnlineStream() on it.
void Reset(OnlineStream *s) const override { InitOnlineStream(s); }
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
for (int32_t i = 0; i < n; ++i) {
auto s = ss[i];
auto r = s->GetKeywordResult(true);
int32_t num_trailing_blanks = r.num_trailing_blanks;
// assume subsampling_factor is 4
// assume frameshift is 0.01 second
float trailing_slience = num_trailing_blanks * 4 * 0.01;
// it resets automatically after detecting 1.5 seconds of silence
float threshold = 1.5;
if (trailing_slience > threshold) {
Reset(s);
}
}
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();
... ...
... ... @@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
// Forwards to the implementation's Reset().
void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); }
// Forwards the array of `n` stream pointers to the implementation's
// DecodeStreams().
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
  impl_->DecodeStreams(ss, n);
}
... ...
... ... @@ -129,6 +129,9 @@ class KeywordSpotter {
*/
bool IsReady(OnlineStream *s) const;
// Remember to call it after detecting a keyword
void Reset(OnlineStream *s) const;
/** Decode a single stream. */
void DecodeStream(OnlineStream *s) const {
OnlineStream *ss[1] = {s};
... ...
... ... @@ -106,13 +106,15 @@ as the device_name.
while (spotter.IsReady(stream.get())) {
spotter.DecodeStream(stream.get());
}
const auto r = spotter.GetResult(stream.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
const auto r = spotter.GetResult(stream.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
spotter.Reset(stream.get());
}
}
}
... ...
... ... @@ -150,13 +150,15 @@ for a list of pre-trained models to download.
while (!stop) {
while (spotter.IsReady(s.get())) {
spotter.DecodeStream(s.get());
}
const auto r = spotter.GetResult(s.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
const auto r = spotter.GetResult(s.get());
if (!r.keyword.empty()) {
display.Print(keyword_index, r.AsJsonString());
fflush(stderr);
keyword_index++;
spotter.Reset(s.get());
}
}
Pa_Sleep(20); // sleep for 20ms
... ...
... ... @@ -27,6 +27,10 @@ public class KeywordSpotter {
decode(ptr, s.getPtr());
}
/**
 * Resets the given stream by calling the native reset() with this spotter's
 * pointer and the stream's pointer.
 *
 * @param s the stream to reset
 */
public void reset(OnlineStream s) {
    reset(ptr, s.getPtr());
}
/**
 * Returns the value of the native isReady() for the given stream.
 *
 * @param s the stream to query
 * @return whatever the native isReady() reports for this spotter and stream
 */
public boolean isReady(OnlineStream s) {
    return isReady(ptr, s.getPtr());
}
... ... @@ -60,6 +64,8 @@ public class KeywordSpotter {
private native void decode(long ptr, long streamPtr);
private native void reset(long ptr, long streamPtr);
private native boolean isReady(long ptr, long streamPtr);
private native Object[] getResult(long ptr, long streamPtr);
... ...
... ... @@ -162,6 +162,15 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode(
}
// JNI entry point for KeywordSpotter.reset(): interprets `ptr` and
// `stream_ptr` as native KeywordSpotter / OnlineStream pointers and resets
// the stream.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_reset(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  auto *spotter = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
  spotter->Reset(reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr));
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
... ...
... ... @@ -49,6 +49,7 @@ class KeywordSpotter(
}
// Each wrapper below forwards this spotter's `ptr` and the stream's `ptr`
// to the external function of the same name.

/** Decodes the given stream. */
fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)

/** Resets the given stream. */
fun reset(stream: OnlineStream) = reset(ptr, stream.ptr)

/** Returns the value of the external isReady() for the given stream. */
fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
fun getResult(stream: OnlineStream): KeywordSpotterResult {
val objArray = getResult(ptr, stream.ptr)
... ... @@ -74,6 +75,7 @@ class KeywordSpotter(
private external fun createStream(ptr: Long, keywords: String): Long
private external fun isReady(ptr: Long, streamPtr: Long): Boolean
private external fun decode(ptr: Long, streamPtr: Long)
private external fun reset(ptr: Long, streamPtr: Long)
private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>
companion object {
... ...
... ... @@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) {
py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
.def("is_ready", &PyClass::IsReady,
py::call_guard<py::gil_scoped_release>())
.def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
.def("decode_stream", &PyClass::DecodeStream,
py::call_guard<py::gil_scoped_release>())
.def(
... ...
... ... @@ -104,8 +104,8 @@ class KeywordSpotter(object):
)
provider_config = ProviderConfig(
provider=provider,
device = device,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
... ... @@ -131,6 +131,9 @@ class KeywordSpotter(object):
)
self.keyword_spotter = _KeywordSpotter(keywords_spotter_config)
def reset_stream(self, s: OnlineStream):
    """Reset the given stream by calling ``reset`` on the wrapped spotter.

    Args:
      s:
        The stream to reset.
    """
    self.keyword_spotter.reset(s)
def create_stream(self, keywords: Optional[str] = None):
if keywords is None:
return self.keyword_spotter.create_stream()
... ...
... ... @@ -98,6 +98,9 @@ class TestKeywordSpotter(unittest.TestCase):
if r:
print(f"{r} is detected.")
results[i] += f"{r}/"
keyword_spotter.reset_stream(s)
if len(ready_list) == 0:
break
keyword_spotter.decode_streams(ready_list)
... ... @@ -158,6 +161,9 @@ class TestKeywordSpotter(unittest.TestCase):
if r:
print(f"{r} is detected.")
results[i] += f"{r}/"
keyword_spotter.reset_stream(s)
if len(ready_list) == 0:
break
keyword_spotter.decode_streams(ready_list)
... ...
... ... @@ -1076,6 +1076,10 @@ class SherpaOnnxKeywordSpotterWrapper {
SherpaOnnxDecodeKeywordStream(spotter, stream)
}
/// Reset the wrapped stream by calling SherpaOnnxResetKeywordStream() with
/// this wrapper's spotter and stream.
func reset() {
  SherpaOnnxResetKeywordStream(spotter, stream)
}
func getResult() -> SherpaOnnxKeywordResultWrapper {
let result: UnsafePointer<SherpaOnnxKeywordResult>? = SherpaOnnxGetKeywordResult(
spotter, stream)
... ...
... ... @@ -70,6 +70,9 @@ func run() {
spotter.decode()
let keyword = spotter.getResult().keyword
if keyword != "" {
// Remember to call reset() right after detecting a keyword
spotter.reset()
print("Detected: \(keyword)")
}
}
... ...
... ... @@ -17,6 +17,7 @@ set(exported_functions
SherpaOnnxIsKeywordStreamReady
SherpaOnnxOnlineStreamAcceptWaveform
SherpaOnnxOnlineStreamInputFinished
SherpaOnnxResetKeywordStream
)
set(mangled_exported_functions)
foreach(x IN LISTS exported_functions)
... ...
... ... @@ -102,15 +102,17 @@ if (navigator.mediaDevices.getUserMedia) {
recognizer_stream.acceptWaveform(expectedSampleRate, samples);
while (recognizer.isReady(recognizer_stream)) {
recognizer.decode(recognizer_stream);
}
let result = recognizer.getResult(recognizer_stream);
let result = recognizer.getResult(recognizer_stream);
if (result.keyword.length > 0) {
console.log(result)
lastResult = result;
resultList.push(JSON.stringify(result));
if (result.keyword.length > 0) {
console.log(result)
lastResult = result;
resultList.push(JSON.stringify(result));
// remember to reset the stream right after detecting a keyword
recognizer.reset(recognizer_stream);
}
}
... ...
... ... @@ -296,8 +296,11 @@ class Kws {
}
decode(stream) {
return this.Module._SherpaOnnxDecodeKeywordStream(
this.handle, stream.handle);
this.Module._SherpaOnnxDecodeKeywordStream(this.handle, stream.handle);
}
/**
 * Reset the given stream by calling the module's
 * _SherpaOnnxResetKeywordStream() with this object's handle and the
 * stream's handle.
 *
 * @param stream The stream to reset.
 */
reset(stream) {
  this.Module._SherpaOnnxResetKeywordStream(this.handle, stream.handle);
}
getResult(stream) {
... ...