Fangjun Kuang
Committed by GitHub

Fix keyword spotting. (#1689)

Reset the stream right after detecting a keyword
正在显示 43 个修改的文件 包含 823 行增加303 行删除
@@ -574,29 +574,6 @@ echo "sherpa_onnx version: $sherpa_onnx_version" @@ -574,29 +574,6 @@ echo "sherpa_onnx version: $sherpa_onnx_version"
574 pwd 574 pwd
575 ls -lh 575 ls -lh
576 576
577 -repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01  
578 -log "Start testing ${repo}"  
579 -  
580 -pushd $dir  
581 -curl -LS -O https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz  
582 -tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz  
583 -rm sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz  
584 -popd  
585 -  
586 -repo=$dir/$repo  
587 -ls -lh $repo  
588 -  
589 -python3 ./python-api-examples/keyword-spotter.py \  
590 - --tokens=$repo/tokens.txt \  
591 - --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \  
592 - --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \  
593 - --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \  
594 - --keywords-file=$repo/test_wavs/test_keywords.txt \  
595 - $repo/test_wavs/0.wav \  
596 - $repo/test_wavs/1.wav  
597 -  
598 -rm -rf $repo  
599 -  
600 if [[ x$OS != x'windows-latest' ]]; then 577 if [[ x$OS != x'windows-latest' ]]; then
601 echo "OS: $OS" 578 echo "OS: $OS"
602 579
@@ -612,15 +589,7 @@ if [[ x$OS != x'windows-latest' ]]; then @@ -612,15 +589,7 @@ if [[ x$OS != x'windows-latest' ]]; then
612 repo=$dir/$repo 589 repo=$dir/$repo
613 ls -lh $repo 590 ls -lh $repo
614 591
615 - python3 ./python-api-examples/keyword-spotter.py \  
616 - --tokens=$repo/tokens.txt \  
617 - --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \  
618 - --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \  
619 - --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \  
620 - --keywords-file=$repo/test_wavs/test_keywords.txt \  
621 - $repo/test_wavs/3.wav \  
622 - $repo/test_wavs/4.wav \  
623 - $repo/test_wavs/5.wav 592 + python3 ./python-api-examples/keyword-spotter.py
624 593
625 python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose 594 python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose
626 595
@@ -79,6 +79,27 @@ jobs: @@ -79,6 +79,27 @@ jobs:
79 otool -L ./install/lib/libsherpa-onnx-c-api.dylib 79 otool -L ./install/lib/libsherpa-onnx-c-api.dylib
80 fi 80 fi
81 81
  82 + - name: Test kws (zh)
  83 + shell: bash
  84 + run: |
  85 + gcc -o kws-c-api ./c-api-examples/kws-c-api.c \
  86 + -I ./build/install/include \
  87 + -L ./build/install/lib/ \
  88 + -l sherpa-onnx-c-api \
  89 + -l onnxruntime
  90 +
  91 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  92 + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  93 + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  94 +
  95 + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
  96 + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
  97 +
  98 + ./kws-c-api
  99 +
  100 + rm ./kws-c-api
  101 + rm -rf sherpa-onnx-kws-*
  102 +
82 - name: Test Kokoro TTS (en) 103 - name: Test Kokoro TTS (en)
83 shell: bash 104 shell: bash
84 run: | 105 run: |
@@ -81,6 +81,28 @@ jobs: @@ -81,6 +81,28 @@ jobs:
81 otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib 81 otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib
82 fi 82 fi
83 83
  84 + - name: Test KWS (zh)
  85 + shell: bash
  86 + run: |
  87 + g++ -std=c++17 -o kws-cxx-api ./cxx-api-examples/kws-cxx-api.cc \
  88 + -I ./build/install/include \
  89 + -L ./build/install/lib/ \
  90 + -l sherpa-onnx-cxx-api \
  91 + -l sherpa-onnx-c-api \
  92 + -l onnxruntime
  93 +
  94 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  95 + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  96 + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  97 +
  98 + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
  99 + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
  100 +
  101 + ./kws-cxx-api
  102 +
  103 + rm kws-cxx-api
  104 + rm -rf sherpa-onnx-kws-*
  105 +
84 - name: Test Kokoro TTS (en) 106 - name: Test Kokoro TTS (en)
85 shell: bash 107 shell: bash
86 run: | 108 run: |
@@ -151,24 +151,27 @@ class MainActivity : AppCompatActivity() { @@ -151,24 +151,27 @@ class MainActivity : AppCompatActivity() {
151 stream.acceptWaveform(samples, sampleRate = sampleRateInHz) 151 stream.acceptWaveform(samples, sampleRate = sampleRateInHz)
152 while (kws.isReady(stream)) { 152 while (kws.isReady(stream)) {
153 kws.decode(stream) 153 kws.decode(stream)
154 - }  
155 154
156 - val text = kws.getResult(stream).keyword 155 + val text = kws.getResult(stream).keyword
  156 +
  157 + var textToDisplay = lastText
157 158
158 - var textToDisplay = lastText 159 + if (text.isNotBlank()) {
  160 + // Remember to reset the stream right after detecting a keyword
159 161
160 - if (text.isNotBlank()) {  
161 - if (lastText.isBlank()) {  
162 - textToDisplay = "$idx: $text"  
163 - } else {  
164 - textToDisplay = "$idx: $text\n$lastText" 162 + kws.reset(stream)
  163 + if (lastText.isBlank()) {
  164 + textToDisplay = "$idx: $text"
  165 + } else {
  166 + textToDisplay = "$idx: $text\n$lastText"
  167 + }
  168 + lastText = "$idx: $text\n$lastText"
  169 + idx += 1
165 } 170 }
166 - lastText = "$idx: $text\n$lastText"  
167 - idx += 1  
168 - }  
169 171
170 - runOnUiThread {  
171 - textView.text = textToDisplay 172 + runOnUiThread {
  173 + textView.text = textToDisplay
  174 + }
172 } 175 }
173 } 176 }
174 } 177 }
@@ -4,6 +4,9 @@ include_directories(${CMAKE_SOURCE_DIR}) @@ -4,6 +4,9 @@ include_directories(${CMAKE_SOURCE_DIR})
4 add_executable(decode-file-c-api decode-file-c-api.c) 4 add_executable(decode-file-c-api decode-file-c-api.c)
5 target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs) 5 target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs)
6 6
  7 +add_executable(kws-c-api kws-c-api.c)
  8 +target_link_libraries(kws-c-api sherpa-onnx-c-api)
  9 +
7 if(SHERPA_ONNX_ENABLE_TTS) 10 if(SHERPA_ONNX_ENABLE_TTS)
8 add_executable(offline-tts-c-api offline-tts-c-api.c) 11 add_executable(offline-tts-c-api offline-tts-c-api.c)
9 target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs) 12 target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs)
  1 +// c-api-examples/kws-c-api.c
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +//
  5 +// This file demonstrates how to use keywords spotter with sherpa-onnx's C
  6 +// clang-format off
  7 +//
  8 +// Usage
  9 +//
  10 +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  11 +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  12 +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  13 +//
  14 +// ./kws-c-api
  15 +//
  16 +// clang-format on
  17 +#include <stdio.h>
  18 +#include <stdlib.h> // exit
  19 +#include <string.h> // memset
  20 +
  21 +#include "sherpa-onnx/c-api/c-api.h"
  22 +
  23 +int32_t main() {
  24 + SherpaOnnxKeywordSpotterConfig config;
  25 +
  26 + memset(&config, 0, sizeof(config));
  27 + config.model_config.transducer.encoder =
  28 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
  29 + "encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
  30 +
  31 + config.model_config.transducer.decoder =
  32 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
  33 + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
  34 +
  35 + config.model_config.transducer.joiner =
  36 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
  37 + "joiner-epoch-12-avg-2-chunk-16-left-64.onnx";
  38 +
  39 + config.model_config.tokens =
  40 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";
  41 +
  42 + config.model_config.provider = "cpu";
  43 + config.model_config.num_threads = 1;
  44 + config.model_config.debug = 1;
  45 +
  46 + config.keywords_file =
  47 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
  48 + "test_keywords.txt";
  49 +
  50 + const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&config);
  51 + if (!kws) {
  52 + fprintf(stderr, "Please check your config");
  53 + exit(-1);
  54 + }
  55 +
  56 + fprintf(stderr,
  57 + "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n");
  58 +
  59 + const char *wav_filename =
  60 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";
  61 +
  62 + float tail_paddings[8000] = {0}; // 0.5 seconds
  63 +
  64 + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
  65 + if (wave == NULL) {
  66 + fprintf(stderr, "Failed to read %s\n", wav_filename);
  67 + exit(-1);
  68 + }
  69 +
  70 + const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
  71 + if (!stream) {
  72 + fprintf(stderr, "Failed to create stream\n");
  73 + exit(-1);
  74 + }
  75 +
  76 + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
  77 + wave->num_samples);
  78 +
  79 + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
  80 + sizeof(tail_paddings) / sizeof(float));
  81 + SherpaOnnxOnlineStreamInputFinished(stream);
  82 + while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
  83 + SherpaOnnxDecodeKeywordStream(kws, stream);
  84 + const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
  85 + if (r && r->json && strlen(r->keyword)) {
  86 + fprintf(stderr, "Detected keyword: %s\n", r->json);
  87 +
  88 + // Remember to reset the keyword stream right after a keyword is detected
  89 + SherpaOnnxResetKeywordStream(kws, stream);
  90 + }
  91 + SherpaOnnxDestroyKeywordResult(r);
  92 + }
  93 + SherpaOnnxDestroyOnlineStream(stream);
  94 +
  95 + // --------------------------------------------------------------------------
  96 +
  97 + fprintf(stderr, "--Use pre-defined keywords + add a new keyword--\n");
  98 +
  99 + stream = SherpaOnnxCreateKeywordStreamWithKeywords(kws, "y ǎn y uán @演员");
  100 +
  101 + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
  102 + wave->num_samples);
  103 +
  104 + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
  105 + sizeof(tail_paddings) / sizeof(float));
  106 + SherpaOnnxOnlineStreamInputFinished(stream);
  107 + while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
  108 + SherpaOnnxDecodeKeywordStream(kws, stream);
  109 + const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
  110 + if (r && r->json && strlen(r->keyword)) {
  111 + fprintf(stderr, "Detected keyword: %s\n", r->json);
  112 +
  113 + // Remember to reset the keyword stream
  114 + SherpaOnnxResetKeywordStream(kws, stream);
  115 + }
  116 + SherpaOnnxDestroyKeywordResult(r);
  117 + }
  118 + SherpaOnnxDestroyOnlineStream(stream);
  119 +
  120 + // --------------------------------------------------------------------------
  121 +
  122 + fprintf(stderr, "--Use pre-defined keywords + add two new keywords--\n");
  123 +
  124 + stream = SherpaOnnxCreateKeywordStreamWithKeywords(
  125 + kws, "y ǎn y uán @演员/zh ī m íng @知名");
  126 +
  127 + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples,
  128 + wave->num_samples);
  129 +
  130 + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
  131 + sizeof(tail_paddings) / sizeof(float));
  132 + SherpaOnnxOnlineStreamInputFinished(stream);
  133 + while (SherpaOnnxIsKeywordStreamReady(kws, stream)) {
  134 + SherpaOnnxDecodeKeywordStream(kws, stream);
  135 + const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream);
  136 + if (r && r->json && strlen(r->keyword)) {
  137 + fprintf(stderr, "Detected keyword: %s\n", r->json);
  138 +
  139 + // Remember to reset the keyword stream
  140 + SherpaOnnxResetKeywordStream(kws, stream);
  141 + }
  142 + SherpaOnnxDestroyKeywordResult(r);
  143 + }
  144 + SherpaOnnxDestroyOnlineStream(stream);
  145 +
  146 + SherpaOnnxFreeWave(wave);
  147 + SherpaOnnxDestroyKeywordSpotter(kws);
  148 +
  149 + return 0;
  150 +}
@@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR}) @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
3 add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc) 3 add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
4 target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api) 4 target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)
5 5
  6 +add_executable(kws-cxx-api ./kws-cxx-api.cc)
  7 +target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api)
  8 +
6 add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc) 9 add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc)
7 target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api) 10 target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api)
8 11
  1 +// cxx-api-examples/kws-cxx-api.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +//
  5 +// This file demonstrates how to use keywords spotter with sherpa-onnx's C
  6 +// clang-format off
  7 +//
  8 +// Usage
  9 +//
  10 +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  11 +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  12 +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
  13 +//
  14 +// ./kws-cxx-api
  15 +//
  16 +// clang-format on
  17 +#include <array>
  18 +#include <iostream>
  19 +
  20 +#include "sherpa-onnx/c-api/cxx-api.h"
  21 +
  22 +int32_t main() {
  23 + using namespace sherpa_onnx::cxx; // NOLINT
  24 +
  25 + KeywordSpotterConfig config;
  26 + config.model_config.transducer.encoder =
  27 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
  28 + "encoder-epoch-12-avg-2-chunk-16-left-64.onnx";
  29 +
  30 + config.model_config.transducer.decoder =
  31 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
  32 + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
  33 +
  34 + config.model_config.transducer.joiner =
  35 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/"
  36 + "joiner-epoch-12-avg-2-chunk-16-left-64.onnx";
  37 +
  38 + config.model_config.tokens =
  39 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt";
  40 +
  41 + config.model_config.provider = "cpu";
  42 + config.model_config.num_threads = 1;
  43 + config.model_config.debug = 1;
  44 +
  45 + config.keywords_file =
  46 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/"
  47 + "test_keywords.txt";
  48 +
  49 + KeywordSpotter kws = KeywordSpotter::Create(config);
  50 + if (!kws.Get()) {
  51 + std::cerr << "Please check your config\n";
  52 + return -1;
  53 + }
  54 +
  55 + std::cout
  56 + << "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n";
  57 +
  58 + std::string wave_filename =
  59 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav";
  60 +
  61 + std::array<float, 8000> tail_paddings = {0}; // 0.5 seconds
  62 +
  63 + Wave wave = ReadWave(wave_filename);
  64 + if (wave.samples.empty()) {
  65 + std::cerr << "Failed to read: '" << wave_filename << "'\n";
  66 + return -1;
  67 + }
  68 +
  69 + OnlineStream stream = kws.CreateStream();
  70 + if (!stream.Get()) {
  71 + std::cerr << "Failed to create stream\n";
  72 + return -1;
  73 + }
  74 +
  75 + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
  76 + wave.samples.size());
  77 +
  78 + stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
  79 + tail_paddings.size());
  80 + stream.InputFinished();
  81 +
  82 + while (kws.IsReady(&stream)) {
  83 + kws.Decode(&stream);
  84 + auto r = kws.GetResult(&stream);
  85 + if (!r.keyword.empty()) {
  86 + std::cout << "Detected keyword: " << r.json << "\n";
  87 +
  88 + // Remember to reset the keyword stream right after a keyword is detected
  89 + kws.Reset(&stream);
  90 + }
  91 + }
  92 +
  93 + // --------------------------------------------------------------------------
  94 +
  95 + std::cout << "--Use pre-defined keywords + add a new keyword--\n";
  96 +
  97 + stream = kws.CreateStream("y ǎn y uán @演员");
  98 +
  99 + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
  100 + wave.samples.size());
  101 +
  102 + stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
  103 + tail_paddings.size());
  104 + stream.InputFinished();
  105 +
  106 + while (kws.IsReady(&stream)) {
  107 + kws.Decode(&stream);
  108 + auto r = kws.GetResult(&stream);
  109 + if (!r.keyword.empty()) {
  110 + std::cout << "Detected keyword: " << r.json << "\n";
  111 +
  112 + // Remember to reset the keyword stream right after a keyword is detected
  113 + kws.Reset(&stream);
  114 + }
  115 + }
  116 +
  117 + // --------------------------------------------------------------------------
  118 +
  119 + std::cout << "--Use pre-defined keywords + add two new keywords--\n";
  120 +
  121 + stream = kws.CreateStream("y ǎn y uán @演员/zh ī m íng @知名");
  122 +
  123 + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
  124 + wave.samples.size());
  125 +
  126 + stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
  127 + tail_paddings.size());
  128 + stream.InputFinished();
  129 +
  130 + while (kws.IsReady(&stream)) {
  131 + kws.Decode(&stream);
  132 + auto r = kws.GetResult(&stream);
  133 + if (!r.keyword.empty()) {
  134 + std::cout << "Detected keyword: " << r.json << "\n";
  135 +
  136 + // Remember to reset the keyword stream right after a keyword is detected
  137 + kws.Reset(&stream);
  138 + }
  139 + }
  140 + return 0;
  141 +}
@@ -73,6 +73,8 @@ void main(List<String> arguments) async { @@ -73,6 +73,8 @@ void main(List<String> arguments) async {
73 spotter.decode(stream); 73 spotter.decode(stream);
74 final result = spotter.getResult(stream); 74 final result = spotter.getResult(stream);
75 if (result.keyword != '') { 75 if (result.keyword != '') {
  76 + // Remember to reset the stream right after detecting a keyword
  77 + spotter.reset(stream);
76 print('Detected: ${result.keyword}'); 78 print('Detected: ${result.keyword}');
77 } 79 }
78 } 80 }
@@ -53,6 +53,8 @@ class KeywordSpotterDemo @@ -53,6 +53,8 @@ class KeywordSpotterDemo
53 var result = kws.GetResult(s); 53 var result = kws.GetResult(s);
54 if (result.Keyword != string.Empty) 54 if (result.Keyword != string.Empty)
55 { 55 {
  56 + // Remember to call Reset() right after detecting a keyword
  57 + kws.Reset(s);
56 Console.WriteLine("Detected: {0}", result.Keyword); 58 Console.WriteLine("Detected: {0}", result.Keyword);
57 } 59 }
58 } 60 }
@@ -70,6 +72,8 @@ class KeywordSpotterDemo @@ -70,6 +72,8 @@ class KeywordSpotterDemo
70 var result = kws.GetResult(s); 72 var result = kws.GetResult(s);
71 if (result.Keyword != string.Empty) 73 if (result.Keyword != string.Empty)
72 { 74 {
  75 + // Remember to call Reset() right after detecting a keyword
  76 + kws.Reset(s);
73 Console.WriteLine("Detected: {0}", result.Keyword); 77 Console.WriteLine("Detected: {0}", result.Keyword);
74 } 78 }
75 } 79 }
@@ -89,6 +93,8 @@ class KeywordSpotterDemo @@ -89,6 +93,8 @@ class KeywordSpotterDemo
89 var result = kws.GetResult(s); 93 var result = kws.GetResult(s);
90 if (result.Keyword != string.Empty) 94 if (result.Keyword != string.Empty)
91 { 95 {
  96 + // Remember to call Reset() right after detecting a keyword
  97 + kws.Reset(s);
92 Console.WriteLine("Detected: {0}", result.Keyword); 98 Console.WriteLine("Detected: {0}", result.Keyword);
93 } 99 }
94 } 100 }
@@ -107,12 +107,15 @@ class KeywordSpotterDemo @@ -107,12 +107,15 @@ class KeywordSpotterDemo
107 while (kws.IsReady(s)) 107 while (kws.IsReady(s))
108 { 108 {
109 kws.Decode(s); 109 kws.Decode(s);
110 - }  
111 110
112 - var result = kws.GetResult(s);  
113 - if (result.Keyword != string.Empty)  
114 - {  
115 - Console.WriteLine("Detected: {0}", result.Keyword); 111 + var result = kws.GetResult(s);
  112 + if (result.Keyword != string.Empty)
  113 + {
  114 + // Remember to call Reset() right after detecting a keyword
  115 + kws.Reset(s);
  116 +
  117 + Console.WriteLine("Detected: {0}", result.Keyword);
  118 + }
116 } 119 }
117 120
118 Thread.Sleep(200); // ms 121 Thread.Sleep(200); // ms
@@ -168,6 +168,10 @@ class KeywordSpotter { @@ -168,6 +168,10 @@ class KeywordSpotter {
168 SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr); 168 SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr);
169 } 169 }
170 170
  171 + void reset(OnlineStream stream) {
  172 + SherpaOnnxBindings.resetKeywordStream?.call(ptr, stream.ptr);
  173 + }
  174 +
171 Pointer<SherpaOnnxKeywordSpotter> ptr; 175 Pointer<SherpaOnnxKeywordSpotter> ptr;
172 KeywordSpotterConfig config; 176 KeywordSpotterConfig config;
173 } 177 }
@@ -667,6 +667,12 @@ typedef DecodeKeywordStreamNative = Void Function( @@ -667,6 +667,12 @@ typedef DecodeKeywordStreamNative = Void Function(
667 typedef DecodeKeywordStream = void Function( 667 typedef DecodeKeywordStream = void Function(
668 Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>); 668 Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
669 669
  670 +typedef ResetKeywordStreamNative = Void Function(
  671 + Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
  672 +
  673 +typedef ResetKeywordStream = void Function(
  674 + Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
  675 +
670 typedef GetKeywordResultAsJsonNative = Pointer<Utf8> Function( 676 typedef GetKeywordResultAsJsonNative = Pointer<Utf8> Function(
671 Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>); 677 Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>);
672 678
@@ -1157,6 +1163,7 @@ class SherpaOnnxBindings { @@ -1157,6 +1163,7 @@ class SherpaOnnxBindings {
1157 static CreateKeywordStreamWithKeywords? createKeywordStreamWithKeywords; 1163 static CreateKeywordStreamWithKeywords? createKeywordStreamWithKeywords;
1158 static IsKeywordStreamReady? isKeywordStreamReady; 1164 static IsKeywordStreamReady? isKeywordStreamReady;
1159 static DecodeKeywordStream? decodeKeywordStream; 1165 static DecodeKeywordStream? decodeKeywordStream;
  1166 + static ResetKeywordStream? resetKeywordStream;
1160 static GetKeywordResultAsJson? getKeywordResultAsJson; 1167 static GetKeywordResultAsJson? getKeywordResultAsJson;
1161 static FreeKeywordResultJson? freeKeywordResultJson; 1168 static FreeKeywordResultJson? freeKeywordResultJson;
1162 1169
@@ -1459,6 +1466,11 @@ class SherpaOnnxBindings { @@ -1459,6 +1466,11 @@ class SherpaOnnxBindings {
1459 'SherpaOnnxDecodeKeywordStream') 1466 'SherpaOnnxDecodeKeywordStream')
1460 .asFunction(); 1467 .asFunction();
1461 1468
  1469 + resetKeywordStream ??= dynamicLibrary
  1470 + .lookup<NativeFunction<ResetKeywordStreamNative>>(
  1471 + 'SherpaOnnxResetKeywordStream')
  1472 + .asFunction();
  1473 +
1462 getKeywordResultAsJson ??= dynamicLibrary 1474 getKeywordResultAsJson ??= dynamicLibrary
1463 .lookup<NativeFunction<GetKeywordResultAsJsonNative>>( 1475 .lookup<NativeFunction<GetKeywordResultAsJsonNative>>(
1464 'SherpaOnnxGetKeywordResultAsJson') 1476 'SherpaOnnxGetKeywordResultAsJson')
@@ -43,6 +43,8 @@ func main() { @@ -43,6 +43,8 @@ func main() {
43 spotter.Decode(stream) 43 spotter.Decode(stream)
44 result := spotter.GetResult(stream) 44 result := spotter.GetResult(stream)
45 if result.Keyword != "" { 45 if result.Keyword != "" {
  46 + // You have to reset the stream right after detecting a keyword
  47 + spotter.Reset(stream)
46 log.Printf("Detected %v\n", result.Keyword) 48 log.Printf("Detected %v\n", result.Keyword)
47 } 49 }
48 } 50 }
@@ -46,7 +46,7 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper( @@ -46,7 +46,7 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper(
46 SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf); 46 SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf);
47 SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize); 47 SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize);
48 48
49 - SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c); 49 + const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c);
50 50
51 if (c.model_config.transducer.encoder) { 51 if (c.model_config.transducer.encoder) {
52 delete[] c.model_config.transducer.encoder; 52 delete[] c.model_config.transducer.encoder;
@@ -100,7 +100,8 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper( @@ -100,7 +100,8 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper(
100 } 100 }
101 101
102 return Napi::External<SherpaOnnxKeywordSpotter>::New( 102 return Napi::External<SherpaOnnxKeywordSpotter>::New(
103 - env, kws, [](Napi::Env env, SherpaOnnxKeywordSpotter *kws) { 103 + env, const_cast<SherpaOnnxKeywordSpotter *>(kws),
  104 + [](Napi::Env env, SherpaOnnxKeywordSpotter *kws) {
104 SherpaOnnxDestroyKeywordSpotter(kws); 105 SherpaOnnxDestroyKeywordSpotter(kws);
105 }); 106 });
106 } 107 }
@@ -125,13 +126,14 @@ static Napi::External<SherpaOnnxOnlineStream> CreateKeywordStreamWrapper( @@ -125,13 +126,14 @@ static Napi::External<SherpaOnnxOnlineStream> CreateKeywordStreamWrapper(
125 return {}; 126 return {};
126 } 127 }
127 128
128 - SherpaOnnxKeywordSpotter *kws = 129 + const SherpaOnnxKeywordSpotter *kws =
129 info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); 130 info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
130 131
131 - SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws); 132 + const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws);
132 133
133 return Napi::External<SherpaOnnxOnlineStream>::New( 134 return Napi::External<SherpaOnnxOnlineStream>::New(
134 - env, stream, [](Napi::Env env, SherpaOnnxOnlineStream *stream) { 135 + env, const_cast<SherpaOnnxOnlineStream *>(stream),
  136 + [](Napi::Env env, SherpaOnnxOnlineStream *stream) {
135 SherpaOnnxDestroyOnlineStream(stream); 137 SherpaOnnxDestroyOnlineStream(stream);
136 }); 138 });
137 } 139 }
@@ -162,10 +164,10 @@ static Napi::Boolean IsKeywordStreamReadyWrapper( @@ -162,10 +164,10 @@ static Napi::Boolean IsKeywordStreamReadyWrapper(
162 return {}; 164 return {};
163 } 165 }
164 166
165 - SherpaOnnxKeywordSpotter *kws = 167 + const SherpaOnnxKeywordSpotter *kws =
166 info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); 168 info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
167 169
168 - SherpaOnnxOnlineStream *stream = 170 + const SherpaOnnxOnlineStream *stream =
169 info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); 171 info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
170 172
171 int32_t is_ready = SherpaOnnxIsKeywordStreamReady(kws, stream); 173 int32_t is_ready = SherpaOnnxIsKeywordStreamReady(kws, stream);
@@ -198,15 +200,49 @@ static void DecodeKeywordStreamWrapper(const Napi::CallbackInfo &info) { @@ -198,15 +200,49 @@ static void DecodeKeywordStreamWrapper(const Napi::CallbackInfo &info) {
198 return; 200 return;
199 } 201 }
200 202
201 - SherpaOnnxKeywordSpotter *kws = 203 + const SherpaOnnxKeywordSpotter *kws =
202 info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); 204 info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
203 205
204 - SherpaOnnxOnlineStream *stream = 206 + const SherpaOnnxOnlineStream *stream =
205 info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); 207 info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
206 208
207 SherpaOnnxDecodeKeywordStream(kws, stream); 209 SherpaOnnxDecodeKeywordStream(kws, stream);
208 } 210 }
209 211
  212 +static void ResetKeywordStreamWrapper(const Napi::CallbackInfo &info) {
  213 + Napi::Env env = info.Env();
  214 + if (info.Length() != 2) {
  215 + std::ostringstream os;
  216 + os << "Expect only 2 arguments. Given: " << info.Length();
  217 +
  218 + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();
  219 +
  220 + return;
  221 + }
  222 +
  223 + if (!info[0].IsExternal()) {
  224 + Napi::TypeError::New(env, "Argument 0 should be a keyword spotter pointer.")
  225 + .ThrowAsJavaScriptException();
  226 +
  227 + return;
  228 + }
  229 +
  230 + if (!info[1].IsExternal()) {
  231 + Napi::TypeError::New(env, "Argument 1 should be an online stream pointer.")
  232 + .ThrowAsJavaScriptException();
  233 +
  234 + return;
  235 + }
  236 +
  237 + const SherpaOnnxKeywordSpotter *kws =
  238 + info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
  239 +
  240 + const SherpaOnnxOnlineStream *stream =
  241 + info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
  242 +
  243 + SherpaOnnxResetKeywordStream(kws, stream);
  244 +}
  245 +
210 static Napi::String GetKeywordResultAsJsonWrapper( 246 static Napi::String GetKeywordResultAsJsonWrapper(
211 const Napi::CallbackInfo &info) { 247 const Napi::CallbackInfo &info) {
212 Napi::Env env = info.Env(); 248 Napi::Env env = info.Env();
@@ -233,10 +269,10 @@ static Napi::String GetKeywordResultAsJsonWrapper( @@ -233,10 +269,10 @@ static Napi::String GetKeywordResultAsJsonWrapper(
233 return {}; 269 return {};
234 } 270 }
235 271
236 - SherpaOnnxKeywordSpotter *kws = 272 + const SherpaOnnxKeywordSpotter *kws =
237 info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); 273 info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data();
238 274
239 - SherpaOnnxOnlineStream *stream = 275 + const SherpaOnnxOnlineStream *stream =
240 info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); 276 info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data();
241 277
242 const char *json = SherpaOnnxGetKeywordResultAsJson(kws, stream); 278 const char *json = SherpaOnnxGetKeywordResultAsJson(kws, stream);
@@ -261,6 +297,9 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports) { @@ -261,6 +297,9 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports) {
261 exports.Set(Napi::String::New(env, "decodeKeywordStream"), 297 exports.Set(Napi::String::New(env, "decodeKeywordStream"),
262 Napi::Function::New(env, DecodeKeywordStreamWrapper)); 298 Napi::Function::New(env, DecodeKeywordStreamWrapper));
263 299
  300 + exports.Set(Napi::String::New(env, "resetKeywordStream"),
  301 + Napi::Function::New(env, ResetKeywordStreamWrapper));
  302 +
264 exports.Set(Napi::String::New(env, "getKeywordResultAsJson"), 303 exports.Set(Napi::String::New(env, "getKeywordResultAsJson"),
265 Napi::Function::New(env, GetKeywordResultAsJsonWrapper)); 304 Napi::Function::New(env, GetKeywordResultAsJsonWrapper));
266 } 305 }
@@ -56,6 +56,8 @@ public class KyewordSpotterFromFile { @@ -56,6 +56,8 @@ public class KyewordSpotterFromFile {
56 56
57 String keyword = kws.getResult(stream).getKeyword(); 57 String keyword = kws.getResult(stream).getKeyword();
58 if (!keyword.isEmpty()) { 58 if (!keyword.isEmpty()) {
  59 + // Remember to reset the stream right after detecting a keyword
  60 + kws.reset(stream);
59 System.out.printf("Detected keyword: %s\n", keyword); 61 System.out.printf("Detected keyword: %s\n", keyword);
60 } 62 }
61 } 63 }
@@ -41,6 +41,9 @@ while (kws.isReady(stream)) { @@ -41,6 +41,9 @@ while (kws.isReady(stream)) {
41 const keyword = kws.getResult(stream).keyword; 41 const keyword = kws.getResult(stream).keyword;
42 if (keyword != '') { 42 if (keyword != '') {
43 detectedKeywords.push(keyword); 43 detectedKeywords.push(keyword);
  44 +
  45 + // remember to reset the stream right after detecting a keyword
  46 + kws.reset(stream);
44 } 47 }
45 } 48 }
46 console.log(detectedKeywords); 49 console.log(detectedKeywords);
@@ -169,6 +169,8 @@ def main(): @@ -169,6 +169,8 @@ def main():
169 169
170 print("Started! Please speak") 170 print("Started! Please speak")
171 171
  172 + idx = 0
  173 +
172 sample_rate = 16000 174 sample_rate = 16000
173 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms 175 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
174 stream = keyword_spotter.create_stream() 176 stream = keyword_spotter.create_stream()
@@ -179,9 +181,12 @@ def main(): @@ -179,9 +181,12 @@ def main():
179 stream.accept_waveform(sample_rate, samples) 181 stream.accept_waveform(sample_rate, samples)
180 while keyword_spotter.is_ready(stream): 182 while keyword_spotter.is_ready(stream):
181 keyword_spotter.decode_stream(stream) 183 keyword_spotter.decode_stream(stream)
182 - result = keyword_spotter.get_result(stream)  
183 - if result:  
184 - print("\r{}".format(result), end="", flush=True) 184 + result = keyword_spotter.get_result(stream)
  185 + if result:
  186 + print(f"{idx}: {result }")
  187 + idx += 1
  188 + # Remember to reset stream right after detecting a keyword
  189 + keyword_spotter.reset_stream(stream)
185 190
186 191
187 if __name__ == "__main__": 192 if __name__ == "__main__":
@@ -18,122 +18,6 @@ import numpy as np @@ -18,122 +18,6 @@ import numpy as np
18 import sherpa_onnx 18 import sherpa_onnx
19 19
20 20
21 -def get_args():  
22 - parser = argparse.ArgumentParser(  
23 - formatter_class=argparse.ArgumentDefaultsHelpFormatter  
24 - )  
25 -  
26 - parser.add_argument(  
27 - "--tokens",  
28 - type=str,  
29 - help="Path to tokens.txt",  
30 - )  
31 -  
32 - parser.add_argument(  
33 - "--encoder",  
34 - type=str,  
35 - help="Path to the transducer encoder model",  
36 - )  
37 -  
38 - parser.add_argument(  
39 - "--decoder",  
40 - type=str,  
41 - help="Path to the transducer decoder model",  
42 - )  
43 -  
44 - parser.add_argument(  
45 - "--joiner",  
46 - type=str,  
47 - help="Path to the transducer joiner model",  
48 - )  
49 -  
50 - parser.add_argument(  
51 - "--num-threads",  
52 - type=int,  
53 - default=1,  
54 - help="Number of threads for neural network computation",  
55 - )  
56 -  
57 - parser.add_argument(  
58 - "--provider",  
59 - type=str,  
60 - default="cpu",  
61 - help="Valid values: cpu, cuda, coreml",  
62 - )  
63 -  
64 - parser.add_argument(  
65 - "--max-active-paths",  
66 - type=int,  
67 - default=4,  
68 - help="""  
69 - It specifies number of active paths to keep during decoding.  
70 - """,  
71 - )  
72 -  
73 - parser.add_argument(  
74 - "--num-trailing-blanks",  
75 - type=int,  
76 - default=1,  
77 - help="""The number of trailing blanks a keyword should be followed. Setting  
78 - to a larger value (e.g. 8) when your keywords has overlapping tokens  
79 - between each other.  
80 - """,  
81 - )  
82 -  
83 - parser.add_argument(  
84 - "--keywords-file",  
85 - type=str,  
86 - help="""  
87 - The file containing keywords, one words/phrases per line, and for each  
88 - phrase the bpe/cjkchar/pinyin are separated by a space. For example:  
89 -  
90 - ▁HE LL O ▁WORLD  
91 - x iǎo ài t óng x ué  
92 - """,  
93 - )  
94 -  
95 - parser.add_argument(  
96 - "--keywords-score",  
97 - type=float,  
98 - default=1.0,  
99 - help="""  
100 - The boosting score of each token for keywords. The larger the easier to  
101 - survive beam search.  
102 - """,  
103 - )  
104 -  
105 - parser.add_argument(  
106 - "--keywords-threshold",  
107 - type=float,  
108 - default=0.25,  
109 - help="""  
110 - The trigger threshold (i.e. probability) of the keyword. The larger the  
111 - harder to trigger.  
112 - """,  
113 - )  
114 -  
115 - parser.add_argument(  
116 - "sound_files",  
117 - type=str,  
118 - nargs="+",  
119 - help="The input sound file(s) to decode. Each file must be of WAVE"  
120 - "format with a single channel, and each sample has 16-bit, "  
121 - "i.e., int16_t. "  
122 - "The sample rate of the file can be arbitrary and does not need to "  
123 - "be 16 kHz",  
124 - )  
125 -  
126 - return parser.parse_args()  
127 -  
128 -  
129 -def assert_file_exists(filename: str):  
130 - assert Path(filename).is_file(), (  
131 - f"{filename} does not exist!\n"  
132 - "Please refer to "  
133 - "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it"  
134 - )  
135 -  
136 -  
137 def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: 21 def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
138 """ 22 """
139 Args: 23 Args:
@@ -159,83 +43,74 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: @@ -159,83 +43,74 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
159 return samples_float32, f.getframerate() 43 return samples_float32, f.getframerate()
160 44
161 45
162 -def main():  
163 - args = get_args()  
164 - assert_file_exists(args.tokens)  
165 - assert_file_exists(args.encoder)  
166 - assert_file_exists(args.decoder)  
167 - assert_file_exists(args.joiner)  
168 -  
169 - assert Path(  
170 - args.keywords_file  
171 - ).is_file(), (  
172 - f"keywords_file : {args.keywords_file} not exist, please provide a valid path." 46 +def create_keyword_spotter():
  47 + kws = sherpa_onnx.KeywordSpotter(
  48 + tokens="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt",
  49 + encoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
  50 + decoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
  51 + joiner="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
  52 + num_threads=2,
  53 + keywords_file="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt",
  54 + provider="cpu",
173 ) 55 )
174 56
175 - keyword_spotter = sherpa_onnx.KeywordSpotter(  
176 - tokens=args.tokens,  
177 - encoder=args.encoder,  
178 - decoder=args.decoder,  
179 - joiner=args.joiner,  
180 - num_threads=args.num_threads,  
181 - max_active_paths=args.max_active_paths,  
182 - keywords_file=args.keywords_file,  
183 - keywords_score=args.keywords_score,  
184 - keywords_threshold=args.keywords_threshold,  
185 - num_trailing_blanks=args.num_trailing_blanks,  
186 - provider=args.provider,  
187 - ) 57 + return kws
188 58
189 - print("Started!")  
190 - start_time = time.time()  
191 -  
192 - streams = []  
193 - total_duration = 0  
194 - for wave_filename in args.sound_files:  
195 - assert_file_exists(wave_filename)  
196 - samples, sample_rate = read_wave(wave_filename)  
197 - duration = len(samples) / sample_rate  
198 - total_duration += duration  
199 -  
200 - s = keyword_spotter.create_stream()  
201 -  
202 - s.accept_waveform(sample_rate, samples)  
203 -  
204 - tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)  
205 - s.accept_waveform(sample_rate, tail_paddings)  
206 -  
207 - s.input_finished()  
208 -  
209 - streams.append(s)  
210 -  
211 - results = [""] * len(streams)  
212 - while True:  
213 - ready_list = []  
214 - for i, s in enumerate(streams):  
215 - if keyword_spotter.is_ready(s):  
216 - ready_list.append(s)  
217 - r = keyword_spotter.get_result(s)  
218 - if r:  
219 - results[i] += f"{r}/"  
220 - print(f"{r} is detected.")  
221 - if len(ready_list) == 0:  
222 - break  
223 - keyword_spotter.decode_streams(ready_list)  
224 - end_time = time.time()  
225 - print("Done!")  
226 -  
227 - for wave_filename, result in zip(args.sound_files, results):  
228 - print(f"{wave_filename}\n{result}")  
229 - print("-" * 10)  
230 -  
231 - elapsed_seconds = end_time - start_time  
232 - rtf = elapsed_seconds / total_duration  
233 - print(f"num_threads: {args.num_threads}")  
234 - print(f"Wave duration: {total_duration:.3f} s")  
235 - print(f"Elapsed time: {elapsed_seconds:.3f} s")  
236 - print(  
237 - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"  
238 - ) 59 +
  60 +def main():
  61 + kws = create_keyword_spotter()
  62 +
  63 + wave_filename = (
  64 + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
  65 + )
  66 +
  67 + samples, sample_rate = read_wave(wave_filename)
  68 +
  69 + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
  70 +
  71 + print("----------Use pre-defined keywords----------")
  72 + s = kws.create_stream()
  73 + s.accept_waveform(sample_rate, samples)
  74 + s.accept_waveform(sample_rate, tail_paddings)
  75 + s.input_finished()
  76 + while kws.is_ready(s):
  77 + kws.decode_stream(s)
  78 + r = kws.get_result(s)
  79 + if r != "":
  80 + # Remember to call reset right after detected a keyword
  81 + kws.reset_stream(s)
  82 +
  83 + print(f"Detected {r}")
  84 +
  85 + print("----------Use pre-defined keywords + add a new keyword----------")
  86 +
  87 + s = kws.create_stream("y ǎn y uán @演员")
  88 + s.accept_waveform(sample_rate, samples)
  89 + s.accept_waveform(sample_rate, tail_paddings)
  90 + s.input_finished()
  91 + while kws.is_ready(s):
  92 + kws.decode_stream(s)
  93 + r = kws.get_result(s)
  94 + if r != "":
  95 + # Remember to call reset right after detected a keyword
  96 + kws.reset_stream(s)
  97 +
  98 + print(f"Detected {r}")
  99 +
  100 + print("----------Use pre-defined keywords + add 2 new keywords----------")
  101 +
  102 + s = kws.create_stream("y ǎn y uán @演员/zh ī m íng @知名")
  103 + s.accept_waveform(sample_rate, samples)
  104 + s.accept_waveform(sample_rate, tail_paddings)
  105 + s.input_finished()
  106 + while kws.is_ready(s):
  107 + kws.decode_stream(s)
  108 + r = kws.get_result(s)
  109 + if r != "":
  110 + # Remember to call reset right after detected a keyword
  111 + kws.reset_stream(s)
  112 +
  113 + print(f"Detected {r}")
239 114
240 115
241 if __name__ == "__main__": 116 if __name__ == "__main__":
@@ -46,6 +46,11 @@ namespace SherpaOnnx @@ -46,6 +46,11 @@ namespace SherpaOnnx
46 Decode(_handle.Handle, stream.Handle); 46 Decode(_handle.Handle, stream.Handle);
47 } 47 }
48 48
  49 + public void Reset(OnlineStream stream)
  50 + {
  51 + Reset(_handle.Handle, stream.Handle);
  52 + }
  53 +
49 // The caller should ensure all passed streams are ready for decoding. 54 // The caller should ensure all passed streams are ready for decoding.
50 public void Decode(IEnumerable<OnlineStream> streams) 55 public void Decode(IEnumerable<OnlineStream> streams)
51 { 56 {
@@ -110,6 +115,9 @@ namespace SherpaOnnx @@ -110,6 +115,9 @@ namespace SherpaOnnx
110 [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeKeywordStream")] 115 [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeKeywordStream")]
111 private static extern void Decode(IntPtr handle, IntPtr stream); 116 private static extern void Decode(IntPtr handle, IntPtr stream);
112 117
  118 + [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxResetKeywordStream")]
  119 + private static extern void Reset(IntPtr handle, IntPtr stream);
  120 +
113 [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeMultipleKeywordStreams")] 121 [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeMultipleKeywordStreams")]
114 private static extern void Decode(IntPtr handle, IntPtr[] streams, int n); 122 private static extern void Decode(IntPtr handle, IntPtr[] streams, int n);
115 123
@@ -1584,6 +1584,11 @@ func (spotter *KeywordSpotter) Decode(s *OnlineStream) { @@ -1584,6 +1584,11 @@ func (spotter *KeywordSpotter) Decode(s *OnlineStream) {
1584 C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl) 1584 C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl)
1585 } 1585 }
1586 1586
  1587 +// You MUST call it right after detecting a keyword
  1588 +func (spotter *KeywordSpotter) Reset(s *OnlineStream) {
  1589 + C.SherpaOnnxResetKeywordStream(spotter.impl, s.impl)
  1590 +}
  1591 +
1587 // Get the current result of stream since the last invoke of Reset() 1592 // Get the current result of stream since the last invoke of Reset()
1588 func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult { 1593 func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult {
1589 p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl) 1594 p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl)
@@ -20,6 +20,10 @@ class KeywordSpotter { @@ -20,6 +20,10 @@ class KeywordSpotter {
20 addon.decodeKeywordStream(this.handle, stream.handle); 20 addon.decodeKeywordStream(this.handle, stream.handle);
21 } 21 }
22 22
  23 + reset(stream) {
  24 + addon.resetKeywordStream(this.handle, stream.handle);
  25 + }
  26 +
23 getResult(stream) { 27 getResult(stream) {
24 const jsonStr = addon.getKeywordResultAsJson(this.handle, stream.handle); 28 const jsonStr = addon.getKeywordResultAsJson(this.handle, stream.handle);
25 29
@@ -678,7 +678,7 @@ struct SherpaOnnxKeywordSpotter { @@ -678,7 +678,7 @@ struct SherpaOnnxKeywordSpotter {
678 std::unique_ptr<sherpa_onnx::KeywordSpotter> impl; 678 std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
679 }; 679 };
680 680
681 -SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( 681 +const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
682 const SherpaOnnxKeywordSpotterConfig *config) { 682 const SherpaOnnxKeywordSpotterConfig *config) {
683 sherpa_onnx::KeywordSpotterConfig spotter_config; 683 sherpa_onnx::KeywordSpotterConfig spotter_config;
684 684
@@ -755,37 +755,42 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( @@ -755,37 +755,42 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
755 return spotter; 755 return spotter;
756 } 756 }
757 757
758 -void SherpaOnnxDestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) { 758 +void SherpaOnnxDestroyKeywordSpotter(const SherpaOnnxKeywordSpotter *spotter) {
759 delete spotter; 759 delete spotter;
760 } 760 }
761 761
762 -SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( 762 +const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
763 const SherpaOnnxKeywordSpotter *spotter) { 763 const SherpaOnnxKeywordSpotter *spotter) {
764 SherpaOnnxOnlineStream *stream = 764 SherpaOnnxOnlineStream *stream =
765 new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); 765 new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
766 return stream; 766 return stream;
767 } 767 }
768 768
769 -SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords( 769 +const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords(
770 const SherpaOnnxKeywordSpotter *spotter, const char *keywords) { 770 const SherpaOnnxKeywordSpotter *spotter, const char *keywords) {
771 SherpaOnnxOnlineStream *stream = 771 SherpaOnnxOnlineStream *stream =
772 new SherpaOnnxOnlineStream(spotter->impl->CreateStream(keywords)); 772 new SherpaOnnxOnlineStream(spotter->impl->CreateStream(keywords));
773 return stream; 773 return stream;
774 } 774 }
775 775
776 -int32_t SherpaOnnxIsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,  
777 - SherpaOnnxOnlineStream *stream) { 776 +int32_t SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter,
  777 + const SherpaOnnxOnlineStream *stream) {
778 return spotter->impl->IsReady(stream->impl.get()); 778 return spotter->impl->IsReady(stream->impl.get());
779 } 779 }
780 780
781 -void SherpaOnnxDecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,  
782 - SherpaOnnxOnlineStream *stream) {  
783 - return spotter->impl->DecodeStream(stream->impl.get()); 781 +void SherpaOnnxDecodeKeywordStream(const SherpaOnnxKeywordSpotter *spotter,
  782 + const SherpaOnnxOnlineStream *stream) {
  783 + spotter->impl->DecodeStream(stream->impl.get());
784 } 784 }
785 785
786 -void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,  
787 - SherpaOnnxOnlineStream **streams,  
788 - int32_t n) { 786 +void SherpaOnnxResetKeywordStream(const SherpaOnnxKeywordSpotter *spotter,
  787 + const SherpaOnnxOnlineStream *stream) {
  788 + spotter->impl->Reset(stream->impl.get());
  789 +}
  790 +
  791 +void SherpaOnnxDecodeMultipleKeywordStreams(
  792 + const SherpaOnnxKeywordSpotter *spotter,
  793 + const SherpaOnnxOnlineStream **streams, int32_t n) {
789 std::vector<sherpa_onnx::OnlineStream *> ss(n); 794 std::vector<sherpa_onnx::OnlineStream *> ss(n);
790 for (int32_t i = 0; i != n; ++i) { 795 for (int32_t i = 0; i != n; ++i) {
791 ss[i] = streams[i]->impl.get(); 796 ss[i] = streams[i]->impl.get();
@@ -794,7 +799,8 @@ void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter, @@ -794,7 +799,8 @@ void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
794 } 799 }
795 800
796 const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult( 801 const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
797 - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { 802 + const SherpaOnnxKeywordSpotter *spotter,
  803 + const SherpaOnnxOnlineStream *stream) {
798 const sherpa_onnx::KeywordResult &result = 804 const sherpa_onnx::KeywordResult &result =
799 spotter->impl->GetResult(stream->impl.get()); 805 spotter->impl->GetResult(stream->impl.get());
800 const auto &keyword = result.keyword; 806 const auto &keyword = result.keyword;
@@ -869,8 +875,9 @@ void SherpaOnnxDestroyKeywordResult(const SherpaOnnxKeywordResult *r) { @@ -869,8 +875,9 @@ void SherpaOnnxDestroyKeywordResult(const SherpaOnnxKeywordResult *r) {
869 } 875 }
870 } 876 }
871 877
872 -const char *SherpaOnnxGetKeywordResultAsJson(SherpaOnnxKeywordSpotter *spotter,  
873 - SherpaOnnxOnlineStream *stream) { 878 +const char *SherpaOnnxGetKeywordResultAsJson(
  879 + const SherpaOnnxKeywordSpotter *spotter,
  880 + const SherpaOnnxOnlineStream *stream) {
874 const sherpa_onnx::KeywordResult &result = 881 const sherpa_onnx::KeywordResult &result =
875 spotter->impl->GetResult(stream->impl.get()); 882 spotter->impl->GetResult(stream->impl.get());
876 883
@@ -600,7 +600,7 @@ SHERPA_ONNX_API const char *SherpaOnnxGetOfflineStreamResultAsJson( @@ -600,7 +600,7 @@ SHERPA_ONNX_API const char *SherpaOnnxGetOfflineStreamResultAsJson(
600 SHERPA_ONNX_API void SherpaOnnxDestroyOfflineStreamResultJson(const char *s); 600 SHERPA_ONNX_API void SherpaOnnxDestroyOfflineStreamResultJson(const char *s);
601 601
602 // ============================================================ 602 // ============================================================
603 -// For Keyword Spot 603 +// For Keyword Spotter
604 // ============================================================ 604 // ============================================================
605 SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { 605 SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
606 /// The triggered keyword. 606 /// The triggered keyword.
@@ -660,21 +660,21 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter @@ -660,21 +660,21 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
660 /// @param config Config for the keyword spotter. 660 /// @param config Config for the keyword spotter.
661 /// @return Return a pointer to the spotter. The user has to invoke 661 /// @return Return a pointer to the spotter. The user has to invoke
662 /// SherpaOnnxDestroyKeywordSpotter() to free it to avoid memory leak. 662 /// SherpaOnnxDestroyKeywordSpotter() to free it to avoid memory leak.
663 -SHERPA_ONNX_API SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( 663 +SHERPA_ONNX_API const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter(
664 const SherpaOnnxKeywordSpotterConfig *config); 664 const SherpaOnnxKeywordSpotterConfig *config);
665 665
666 /// Free a pointer returned by SherpaOnnxCreateKeywordSpotter() 666 /// Free a pointer returned by SherpaOnnxCreateKeywordSpotter()
667 /// 667 ///
668 /// @param p A pointer returned by SherpaOnnxCreateKeywordSpotter() 668 /// @param p A pointer returned by SherpaOnnxCreateKeywordSpotter()
669 SHERPA_ONNX_API void SherpaOnnxDestroyKeywordSpotter( 669 SHERPA_ONNX_API void SherpaOnnxDestroyKeywordSpotter(
670 - SherpaOnnxKeywordSpotter *spotter); 670 + const SherpaOnnxKeywordSpotter *spotter);
671 671
672 /// Create an online stream for accepting wave samples. 672 /// Create an online stream for accepting wave samples.
673 /// 673 ///
674 /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter() 674 /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter()
675 /// @return Return a pointer to an OnlineStream. The user has to invoke 675 /// @return Return a pointer to an OnlineStream. The user has to invoke
676 /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. 676 /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
677 -SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( 677 +SHERPA_ONNX_API const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
678 const SherpaOnnxKeywordSpotter *spotter); 678 const SherpaOnnxKeywordSpotter *spotter);
679 679
680 /// Create an online stream for accepting wave samples with the specified hot 680 /// Create an online stream for accepting wave samples with the specified hot
@@ -684,7 +684,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( @@ -684,7 +684,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream(
684 /// @param keywords A pointer points to the keywords that you set 684 /// @param keywords A pointer points to the keywords that you set
685 /// @return Return a pointer to an OnlineStream. The user has to invoke 685 /// @return Return a pointer to an OnlineStream. The user has to invoke
686 /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. 686 /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak.
687 -SHERPA_ONNX_API SherpaOnnxOnlineStream * 687 +SHERPA_ONNX_API const SherpaOnnxOnlineStream *
688 SherpaOnnxCreateKeywordStreamWithKeywords( 688 SherpaOnnxCreateKeywordStreamWithKeywords(
689 const SherpaOnnxKeywordSpotter *spotter, const char *keywords); 689 const SherpaOnnxKeywordSpotter *spotter, const char *keywords);
690 690
@@ -693,15 +693,22 @@ SherpaOnnxCreateKeywordStreamWithKeywords( @@ -693,15 +693,22 @@ SherpaOnnxCreateKeywordStreamWithKeywords(
693 /// 693 ///
694 /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter 694 /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter
695 /// @param stream A pointer returned by SherpaOnnxCreateKeywordStream 695 /// @param stream A pointer returned by SherpaOnnxCreateKeywordStream
696 -SHERPA_ONNX_API int32_t SherpaOnnxIsKeywordStreamReady(  
697 - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); 696 +SHERPA_ONNX_API int32_t
  697 +SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter,
  698 + const SherpaOnnxOnlineStream *stream);
698 699
699 /// Call this function to run the neural network model and decoding. 700 /// Call this function to run the neural network model and decoding.
700 // 701 //
701 /// Precondition for this function: SherpaOnnxIsKeywordStreamReady() MUST 702 /// Precondition for this function: SherpaOnnxIsKeywordStreamReady() MUST
702 /// return 1. 703 /// return 1.
703 SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream( 704 SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream(
704 - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); 705 + const SherpaOnnxKeywordSpotter *spotter,
  706 + const SherpaOnnxOnlineStream *stream);
  707 +
  708 +/// Please call it right after a keyword is detected
  709 +SHERPA_ONNX_API void SherpaOnnxResetKeywordStream(
  710 + const SherpaOnnxKeywordSpotter *spotter,
  711 + const SherpaOnnxOnlineStream *stream);
705 712
706 /// This function is similar to SherpaOnnxDecodeKeywordStream(). It decodes 713 /// This function is similar to SherpaOnnxDecodeKeywordStream(). It decodes
707 /// multiple OnlineStream in parallel. 714 /// multiple OnlineStream in parallel.
@@ -714,8 +721,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream( @@ -714,8 +721,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream(
714 /// SherpaOnnxCreateKeywordStream() 721 /// SherpaOnnxCreateKeywordStream()
715 /// @param n Number of elements in the given streams array. 722 /// @param n Number of elements in the given streams array.
716 SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams( 723 SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams(
717 - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,  
718 - int32_t n); 724 + const SherpaOnnxKeywordSpotter *spotter,
  725 + const SherpaOnnxOnlineStream **streams, int32_t n);
719 726
720 /// Get the decoding results so far for an OnlineStream. 727 /// Get the decoding results so far for an OnlineStream.
721 /// 728 ///
@@ -725,7 +732,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams( @@ -725,7 +732,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams(
725 /// SherpaOnnxDestroyKeywordResult() to free the returned pointer to 732 /// SherpaOnnxDestroyKeywordResult() to free the returned pointer to
726 /// avoid memory leak. 733 /// avoid memory leak.
727 SHERPA_ONNX_API const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult( 734 SHERPA_ONNX_API const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult(
728 - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); 735 + const SherpaOnnxKeywordSpotter *spotter,
  736 + const SherpaOnnxOnlineStream *stream);
729 737
730 /// Destroy the pointer returned by SherpaOnnxGetKeywordResult(). 738 /// Destroy the pointer returned by SherpaOnnxGetKeywordResult().
731 /// 739 ///
@@ -736,7 +744,8 @@ SHERPA_ONNX_API void SherpaOnnxDestroyKeywordResult( @@ -736,7 +744,8 @@ SHERPA_ONNX_API void SherpaOnnxDestroyKeywordResult(
736 // the user has to call SherpaOnnxFreeKeywordResultJson() to free the returned 744 // the user has to call SherpaOnnxFreeKeywordResultJson() to free the returned
737 // pointer to avoid memory leak 745 // pointer to avoid memory leak
738 SHERPA_ONNX_API const char *SherpaOnnxGetKeywordResultAsJson( 746 SHERPA_ONNX_API const char *SherpaOnnxGetKeywordResultAsJson(
739 - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); 747 + const SherpaOnnxKeywordSpotter *spotter,
  748 + const SherpaOnnxOnlineStream *stream);
740 749
741 SHERPA_ONNX_API void SherpaOnnxFreeKeywordResultJson(const char *s); 750 SHERPA_ONNX_API void SherpaOnnxFreeKeywordResultJson(const char *s);
742 751
@@ -391,4 +391,112 @@ GeneratedAudio OfflineTts::Generate(const std::string &text, @@ -391,4 +391,112 @@ GeneratedAudio OfflineTts::Generate(const std::string &text,
391 return ans; 391 return ans;
392 } 392 }
393 393
  394 +KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) {
  395 + struct SherpaOnnxKeywordSpotterConfig c;
  396 + memset(&c, 0, sizeof(c));
  397 +
  398 + c.feat_config.sample_rate = config.feat_config.sample_rate;
  399 +
  400 + c.model_config.transducer.encoder =
  401 + config.model_config.transducer.encoder.c_str();
  402 + c.model_config.transducer.decoder =
  403 + config.model_config.transducer.decoder.c_str();
  404 + c.model_config.transducer.joiner =
  405 + config.model_config.transducer.joiner.c_str();
  406 + c.feat_config.feature_dim = config.feat_config.feature_dim;
  407 +
  408 + c.model_config.paraformer.encoder =
  409 + config.model_config.paraformer.encoder.c_str();
  410 + c.model_config.paraformer.decoder =
  411 + config.model_config.paraformer.decoder.c_str();
  412 +
  413 + c.model_config.zipformer2_ctc.model =
  414 + config.model_config.zipformer2_ctc.model.c_str();
  415 +
  416 + c.model_config.tokens = config.model_config.tokens.c_str();
  417 + c.model_config.num_threads = config.model_config.num_threads;
  418 + c.model_config.provider = config.model_config.provider.c_str();
  419 + c.model_config.debug = config.model_config.debug;
  420 + c.model_config.model_type = config.model_config.model_type.c_str();
  421 + c.model_config.modeling_unit = config.model_config.modeling_unit.c_str();
  422 + c.model_config.bpe_vocab = config.model_config.bpe_vocab.c_str();
  423 + c.model_config.tokens_buf = config.model_config.tokens_buf.c_str();
  424 + c.model_config.tokens_buf_size = config.model_config.tokens_buf.size();
  425 +
  426 + c.max_active_paths = config.max_active_paths;
  427 + c.num_trailing_blanks = config.num_trailing_blanks;
  428 + c.keywords_score = config.keywords_score;
  429 + c.keywords_threshold = config.keywords_threshold;
  430 + c.keywords_file = config.keywords_file.c_str();
  431 +
  432 + auto p = SherpaOnnxCreateKeywordSpotter(&c);
  433 + return KeywordSpotter(p);
  434 +}
  435 +
  436 +KeywordSpotter::KeywordSpotter(const SherpaOnnxKeywordSpotter *p)
  437 + : MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter>(p) {}
  438 +
  439 +void KeywordSpotter::Destroy(const SherpaOnnxKeywordSpotter *p) const {
  440 + SherpaOnnxDestroyKeywordSpotter(p);
  441 +}
  442 +
  443 +OnlineStream KeywordSpotter::CreateStream() const {
  444 + auto s = SherpaOnnxCreateKeywordStream(p_);
  445 + return OnlineStream{s};
  446 +}
  447 +
  448 +OnlineStream KeywordSpotter::CreateStream(const std::string &keywords) const {
  449 + auto s = SherpaOnnxCreateKeywordStreamWithKeywords(p_, keywords.c_str());
  450 + return OnlineStream{s};
  451 +}
  452 +
  453 +bool KeywordSpotter::IsReady(const OnlineStream *s) const {
  454 + return SherpaOnnxIsKeywordStreamReady(p_, s->Get());
  455 +}
  456 +
  457 +void KeywordSpotter::Decode(const OnlineStream *s) const {
  458 + return SherpaOnnxDecodeKeywordStream(p_, s->Get());
  459 +}
  460 +
  461 +void KeywordSpotter::Decode(const OnlineStream *ss, int32_t n) const {
  462 + if (n <= 0) {
  463 + return;
  464 + }
  465 +
  466 + std::vector<const SherpaOnnxOnlineStream *> streams(n);
  467 + for (int32_t i = 0; i != n; ++n) {
  468 + streams[i] = ss[i].Get();
  469 + }
  470 +
  471 + SherpaOnnxDecodeMultipleKeywordStreams(p_, streams.data(), n);
  472 +}
  473 +
  474 +KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const {
  475 + auto r = SherpaOnnxGetKeywordResult(p_, s->Get());
  476 +
  477 + KeywordResult ans;
  478 + ans.keyword = r->keyword;
  479 +
  480 + ans.tokens.resize(r->count);
  481 + for (int32_t i = 0; i < r->count; ++i) {
  482 + ans.tokens[i] = r->tokens_arr[i];
  483 + }
  484 +
  485 + if (r->timestamps) {
  486 + ans.timestamps.resize(r->count);
  487 + std::copy(r->timestamps, r->timestamps + r->count, ans.timestamps.data());
  488 + }
  489 +
  490 + ans.start_time = r->start_time;
  491 + ans.json = r->json;
  492 +
  493 + SherpaOnnxDestroyKeywordResult(r);
  494 +
  495 + return ans;
  496 +}
  497 +
  498 +void KeywordSpotter::Reset(const OnlineStream *s) const {
  499 + SherpaOnnxResetKeywordStream(p_, s->Get());
  500 +}
  501 +
394 } // namespace sherpa_onnx::cxx 502 } // namespace sherpa_onnx::cxx
@@ -406,6 +406,53 @@ class SHERPA_ONNX_API OfflineTts @@ -406,6 +406,53 @@ class SHERPA_ONNX_API OfflineTts
406 explicit OfflineTts(const SherpaOnnxOfflineTts *p); 406 explicit OfflineTts(const SherpaOnnxOfflineTts *p);
407 }; 407 };
408 408
  409 +// ============================================================
  410 +// For Keyword Spotter
  411 +// ============================================================
  412 +
  413 +struct KeywordResult {
  414 + std::string keyword;
  415 + std::vector<std::string> tokens;
  416 + std::vector<float> timestamps;
  417 + float start_time;
  418 + std::string json;
  419 +};
  420 +
  421 +struct KeywordSpotterConfig {
  422 + FeatureConfig feat_config;
  423 + OnlineModelConfig model_config;
  424 + int32_t max_active_paths = 4;
  425 + int32_t num_trailing_blanks = 1;
  426 + float keywords_score = 1.0f;
  427 + float keywords_threshold = 0.25f;
  428 + std::string keywords_file;
  429 +};
  430 +
  431 +class SHERPA_ONNX_API KeywordSpotter
  432 + : public MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter> {
  433 + public:
  434 + static KeywordSpotter Create(const KeywordSpotterConfig &config);
  435 +
  436 + void Destroy(const SherpaOnnxKeywordSpotter *p) const;
  437 +
  438 + OnlineStream CreateStream() const;
  439 +
  440 + OnlineStream CreateStream(const std::string &keywords) const;
  441 +
  442 + bool IsReady(const OnlineStream *s) const;
  443 +
  444 + void Decode(const OnlineStream *s) const;
  445 +
  446 + void Decode(const OnlineStream *ss, int32_t n) const;
  447 +
  448 + void Reset(const OnlineStream *s) const;
  449 +
  450 + KeywordResult GetResult(const OnlineStream *s) const;
  451 +
  452 + private:
  453 + explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p);
  454 +};
  455 +
409 } // namespace sherpa_onnx::cxx 456 } // namespace sherpa_onnx::cxx
410 457
411 #endif // SHERPA_ONNX_C_API_CXX_API_H_ 458 #endif // SHERPA_ONNX_C_API_CXX_API_H_
@@ -38,6 +38,8 @@ class KeywordSpotterImpl { @@ -38,6 +38,8 @@ class KeywordSpotterImpl {
38 38
39 virtual bool IsReady(OnlineStream *s) const = 0; 39 virtual bool IsReady(OnlineStream *s) const = 0;
40 40
  41 + virtual void Reset(OnlineStream *s) const = 0;
  42 +
41 virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0; 43 virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
42 44
43 virtual KeywordResult GetResult(OnlineStream *s) const = 0; 45 virtual KeywordResult GetResult(OnlineStream *s) const = 0;
@@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { @@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
195 return s->GetNumProcessedFrames() + model_->ChunkSize() < 195 return s->GetNumProcessedFrames() + model_->ChunkSize() <
196 s->NumFramesReady(); 196 s->NumFramesReady();
197 } 197 }
  198 + void Reset(OnlineStream *s) const override { InitOnlineStream(s); }
198 199
199 void DecodeStreams(OnlineStream **ss, int32_t n) const override { 200 void DecodeStreams(OnlineStream **ss, int32_t n) const override {
  201 + for (int32_t i = 0; i < n; ++i) {
  202 + auto s = ss[i];
  203 + auto r = s->GetKeywordResult(true);
  204 + int32_t num_trailing_blanks = r.num_trailing_blanks;
  205 + // assume subsampling_factor is 4
  206 + // assume frameshift is 0.01 second
  207 + float trailing_slience = num_trailing_blanks * 4 * 0.01;
  208 +
  209 + // it resets automatically after detecting 1.5 seconds of silence
  210 + float threshold = 1.5;
  211 + if (trailing_slience > threshold) {
  212 + Reset(s);
  213 + }
  214 + }
  215 +
200 int32_t chunk_size = model_->ChunkSize(); 216 int32_t chunk_size = model_->ChunkSize();
201 int32_t chunk_shift = model_->ChunkShift(); 217 int32_t chunk_shift = model_->ChunkShift();
202 218
@@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const { @@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const {
157 return impl_->IsReady(s); 157 return impl_->IsReady(s);
158 } 158 }
159 159
  160 +void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); }
  161 +
160 void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const { 162 void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
161 impl_->DecodeStreams(ss, n); 163 impl_->DecodeStreams(ss, n);
162 } 164 }
@@ -129,6 +129,9 @@ class KeywordSpotter { @@ -129,6 +129,9 @@ class KeywordSpotter {
129 */ 129 */
130 bool IsReady(OnlineStream *s) const; 130 bool IsReady(OnlineStream *s) const;
131 131
  132 + // Remember to call it after detecting a keyword
  133 + void Reset(OnlineStream *s) const;
  134 +
132 /** Decode a single stream. */ 135 /** Decode a single stream. */
133 void DecodeStream(OnlineStream *s) const { 136 void DecodeStream(OnlineStream *s) const {
134 OnlineStream *ss[1] = {s}; 137 OnlineStream *ss[1] = {s};
@@ -106,13 +106,15 @@ as the device_name. @@ -106,13 +106,15 @@ as the device_name.
106 106
107 while (spotter.IsReady(stream.get())) { 107 while (spotter.IsReady(stream.get())) {
108 spotter.DecodeStream(stream.get()); 108 spotter.DecodeStream(stream.get());
109 - }  
110 109
111 - const auto r = spotter.GetResult(stream.get());  
112 - if (!r.keyword.empty()) {  
113 - display.Print(keyword_index, r.AsJsonString());  
114 - fflush(stderr);  
115 - keyword_index++; 110 + const auto r = spotter.GetResult(stream.get());
  111 + if (!r.keyword.empty()) {
  112 + display.Print(keyword_index, r.AsJsonString());
  113 + fflush(stderr);
  114 + keyword_index++;
  115 +
  116 + spotter.Reset(stream.get());
  117 + }
116 } 118 }
117 } 119 }
118 120
@@ -150,13 +150,15 @@ for a list of pre-trained models to download. @@ -150,13 +150,15 @@ for a list of pre-trained models to download.
150 while (!stop) { 150 while (!stop) {
151 while (spotter.IsReady(s.get())) { 151 while (spotter.IsReady(s.get())) {
152 spotter.DecodeStream(s.get()); 152 spotter.DecodeStream(s.get());
153 - }  
154 153
155 - const auto r = spotter.GetResult(s.get());  
156 - if (!r.keyword.empty()) {  
157 - display.Print(keyword_index, r.AsJsonString());  
158 - fflush(stderr);  
159 - keyword_index++; 154 + const auto r = spotter.GetResult(s.get());
  155 + if (!r.keyword.empty()) {
  156 + display.Print(keyword_index, r.AsJsonString());
  157 + fflush(stderr);
  158 + keyword_index++;
  159 +
  160 + spotter.Reset(s.get());
  161 + }
160 } 162 }
161 163
162 Pa_Sleep(20); // sleep for 20ms 164 Pa_Sleep(20); // sleep for 20ms
@@ -27,6 +27,10 @@ public class KeywordSpotter { @@ -27,6 +27,10 @@ public class KeywordSpotter {
27 decode(ptr, s.getPtr()); 27 decode(ptr, s.getPtr());
28 } 28 }
29 29
  30 + public void reset(OnlineStream s) {
  31 + reset(ptr, s.getPtr());
  32 + }
  33 +
30 public boolean isReady(OnlineStream s) { 34 public boolean isReady(OnlineStream s) {
31 return isReady(ptr, s.getPtr()); 35 return isReady(ptr, s.getPtr());
32 } 36 }
@@ -60,6 +64,8 @@ public class KeywordSpotter { @@ -60,6 +64,8 @@ public class KeywordSpotter {
60 64
61 private native void decode(long ptr, long streamPtr); 65 private native void decode(long ptr, long streamPtr);
62 66
  67 + private native void reset(long ptr, long streamPtr);
  68 +
63 private native boolean isReady(long ptr, long streamPtr); 69 private native boolean isReady(long ptr, long streamPtr);
64 70
65 private native Object[] getResult(long ptr, long streamPtr); 71 private native Object[] getResult(long ptr, long streamPtr);
@@ -162,6 +162,15 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode( @@ -162,6 +162,15 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode(
162 } 162 }
163 163
164 SHERPA_ONNX_EXTERN_C 164 SHERPA_ONNX_EXTERN_C
  165 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_reset(
  166 + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  167 + auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
  168 + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
  169 +
  170 + kws->Reset(stream);
  171 +}
  172 +
  173 +SHERPA_ONNX_EXTERN_C
165 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream( 174 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream(
166 JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) { 175 JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
167 auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr); 176 auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
@@ -49,6 +49,7 @@ class KeywordSpotter( @@ -49,6 +49,7 @@ class KeywordSpotter(
49 } 49 }
50 50
51 fun decode(stream: OnlineStream) = decode(ptr, stream.ptr) 51 fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
  52 + fun reset(stream: OnlineStream) = reset(ptr, stream.ptr)
52 fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) 53 fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)
53 fun getResult(stream: OnlineStream): KeywordSpotterResult { 54 fun getResult(stream: OnlineStream): KeywordSpotterResult {
54 val objArray = getResult(ptr, stream.ptr) 55 val objArray = getResult(ptr, stream.ptr)
@@ -74,6 +75,7 @@ class KeywordSpotter( @@ -74,6 +75,7 @@ class KeywordSpotter(
74 private external fun createStream(ptr: Long, keywords: String): Long 75 private external fun createStream(ptr: Long, keywords: String): Long
75 private external fun isReady(ptr: Long, streamPtr: Long): Boolean 76 private external fun isReady(ptr: Long, streamPtr: Long): Boolean
76 private external fun decode(ptr: Long, streamPtr: Long) 77 private external fun decode(ptr: Long, streamPtr: Long)
  78 + private external fun reset(ptr: Long, streamPtr: Long)
77 private external fun getResult(ptr: Long, streamPtr: Long): Array<Any> 79 private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>
78 80
79 companion object { 81 companion object {
@@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) { @@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) {
67 py::arg("keywords"), py::call_guard<py::gil_scoped_release>()) 67 py::arg("keywords"), py::call_guard<py::gil_scoped_release>())
68 .def("is_ready", &PyClass::IsReady, 68 .def("is_ready", &PyClass::IsReady,
69 py::call_guard<py::gil_scoped_release>()) 69 py::call_guard<py::gil_scoped_release>())
  70 + .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>())
70 .def("decode_stream", &PyClass::DecodeStream, 71 .def("decode_stream", &PyClass::DecodeStream,
71 py::call_guard<py::gil_scoped_release>()) 72 py::call_guard<py::gil_scoped_release>())
72 .def( 73 .def(
@@ -104,8 +104,8 @@ class KeywordSpotter(object): @@ -104,8 +104,8 @@ class KeywordSpotter(object):
104 ) 104 )
105 105
106 provider_config = ProviderConfig( 106 provider_config = ProviderConfig(
107 - provider=provider,  
108 - device = device, 107 + provider=provider,
  108 + device=device,
109 ) 109 )
110 110
111 model_config = OnlineModelConfig( 111 model_config = OnlineModelConfig(
@@ -131,6 +131,9 @@ class KeywordSpotter(object): @@ -131,6 +131,9 @@ class KeywordSpotter(object):
131 ) 131 )
132 self.keyword_spotter = _KeywordSpotter(keywords_spotter_config) 132 self.keyword_spotter = _KeywordSpotter(keywords_spotter_config)
133 133
  134 + def reset_stream(self, s: OnlineStream):
  135 + self.keyword_spotter.reset(s)
  136 +
134 def create_stream(self, keywords: Optional[str] = None): 137 def create_stream(self, keywords: Optional[str] = None):
135 if keywords is None: 138 if keywords is None:
136 return self.keyword_spotter.create_stream() 139 return self.keyword_spotter.create_stream()
@@ -98,6 +98,9 @@ class TestKeywordSpotter(unittest.TestCase): @@ -98,6 +98,9 @@ class TestKeywordSpotter(unittest.TestCase):
98 if r: 98 if r:
99 print(f"{r} is detected.") 99 print(f"{r} is detected.")
100 results[i] += f"{r}/" 100 results[i] += f"{r}/"
  101 +
  102 + keyword_spotter.reset_stream(s)
  103 +
101 if len(ready_list) == 0: 104 if len(ready_list) == 0:
102 break 105 break
103 keyword_spotter.decode_streams(ready_list) 106 keyword_spotter.decode_streams(ready_list)
@@ -158,6 +161,9 @@ class TestKeywordSpotter(unittest.TestCase): @@ -158,6 +161,9 @@ class TestKeywordSpotter(unittest.TestCase):
158 if r: 161 if r:
159 print(f"{r} is detected.") 162 print(f"{r} is detected.")
160 results[i] += f"{r}/" 163 results[i] += f"{r}/"
  164 +
  165 + keyword_spotter.reset_stream(s)
  166 +
161 if len(ready_list) == 0: 167 if len(ready_list) == 0:
162 break 168 break
163 keyword_spotter.decode_streams(ready_list) 169 keyword_spotter.decode_streams(ready_list)
@@ -1076,6 +1076,10 @@ class SherpaOnnxKeywordSpotterWrapper { @@ -1076,6 +1076,10 @@ class SherpaOnnxKeywordSpotterWrapper {
1076 SherpaOnnxDecodeKeywordStream(spotter, stream) 1076 SherpaOnnxDecodeKeywordStream(spotter, stream)
1077 } 1077 }
1078 1078
  1079 + func reset() {
  1080 + SherpaOnnxResetKeywordStream(spotter, stream)
  1081 + }
  1082 +
1079 func getResult() -> SherpaOnnxKeywordResultWrapper { 1083 func getResult() -> SherpaOnnxKeywordResultWrapper {
1080 let result: UnsafePointer<SherpaOnnxKeywordResult>? = SherpaOnnxGetKeywordResult( 1084 let result: UnsafePointer<SherpaOnnxKeywordResult>? = SherpaOnnxGetKeywordResult(
1081 spotter, stream) 1085 spotter, stream)
@@ -70,6 +70,9 @@ func run() { @@ -70,6 +70,9 @@ func run() {
70 spotter.decode() 70 spotter.decode()
71 let keyword = spotter.getResult().keyword 71 let keyword = spotter.getResult().keyword
72 if keyword != "" { 72 if keyword != "" {
  73 + // Remember to call reset() right after detecting a keyword
  74 + spotter.reset()
  75 +
73 print("Detected: \(keyword)") 76 print("Detected: \(keyword)")
74 } 77 }
75 } 78 }
@@ -17,6 +17,7 @@ set(exported_functions @@ -17,6 +17,7 @@ set(exported_functions
17 SherpaOnnxIsKeywordStreamReady 17 SherpaOnnxIsKeywordStreamReady
18 SherpaOnnxOnlineStreamAcceptWaveform 18 SherpaOnnxOnlineStreamAcceptWaveform
19 SherpaOnnxOnlineStreamInputFinished 19 SherpaOnnxOnlineStreamInputFinished
  20 + SherpaOnnxResetKeywordStream
20 ) 21 )
21 set(mangled_exported_functions) 22 set(mangled_exported_functions)
22 foreach(x IN LISTS exported_functions) 23 foreach(x IN LISTS exported_functions)
@@ -102,15 +102,17 @@ if (navigator.mediaDevices.getUserMedia) { @@ -102,15 +102,17 @@ if (navigator.mediaDevices.getUserMedia) {
102 recognizer_stream.acceptWaveform(expectedSampleRate, samples); 102 recognizer_stream.acceptWaveform(expectedSampleRate, samples);
103 while (recognizer.isReady(recognizer_stream)) { 103 while (recognizer.isReady(recognizer_stream)) {
104 recognizer.decode(recognizer_stream); 104 recognizer.decode(recognizer_stream);
105 - }  
106 105
  106 + let result = recognizer.getResult(recognizer_stream);
107 107
108 - let result = recognizer.getResult(recognizer_stream); 108 + if (result.keyword.length > 0) {
  109 + console.log(result)
  110 + lastResult = result;
  111 + resultList.push(JSON.stringify(result));
109 112
110 - if (result.keyword.length > 0) {  
111 - console.log(result)  
112 - lastResult = result;  
113 - resultList.push(JSON.stringify(result)); 113 + // remember to reset the stream right after detecting a keyword
  114 + recognizer.reset(recognizer_stream);
  115 + }
114 } 116 }
115 117
116 118
@@ -296,8 +296,11 @@ class Kws { @@ -296,8 +296,11 @@ class Kws {
296 } 296 }
297 297
298 decode(stream) { 298 decode(stream) {
299 - return this.Module._SherpaOnnxDecodeKeywordStream(  
300 - this.handle, stream.handle); 299 + this.Module._SherpaOnnxDecodeKeywordStream(this.handle, stream.handle);
  300 + }
  301 +
  302 + reset(stream) {
  303 + this.Module._SherpaOnnxResetKeywordStream(this.handle, stream.handle);
301 } 304 }
302 305
303 getResult(stream) { 306 getResult(stream) {