Committed by
GitHub
Fix keyword spotting. (#1689)
Reset the stream right after detecting a keyword
正在显示
43 个修改的文件
包含
781 行增加
和
261 行删除
| @@ -574,29 +574,6 @@ echo "sherpa_onnx version: $sherpa_onnx_version" | @@ -574,29 +574,6 @@ echo "sherpa_onnx version: $sherpa_onnx_version" | ||
| 574 | pwd | 574 | pwd |
| 575 | ls -lh | 575 | ls -lh |
| 576 | 576 | ||
| 577 | -repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 | ||
| 578 | -log "Start testing ${repo}" | ||
| 579 | - | ||
| 580 | -pushd $dir | ||
| 581 | -curl -LS -O https://github.com/pkufool/keyword-spotting-models/releases/download/v0.1/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz | ||
| 582 | -tar xf sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz | ||
| 583 | -rm sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz | ||
| 584 | -popd | ||
| 585 | - | ||
| 586 | -repo=$dir/$repo | ||
| 587 | -ls -lh $repo | ||
| 588 | - | ||
| 589 | -python3 ./python-api-examples/keyword-spotter.py \ | ||
| 590 | - --tokens=$repo/tokens.txt \ | ||
| 591 | - --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 592 | - --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 593 | - --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 594 | - --keywords-file=$repo/test_wavs/test_keywords.txt \ | ||
| 595 | - $repo/test_wavs/0.wav \ | ||
| 596 | - $repo/test_wavs/1.wav | ||
| 597 | - | ||
| 598 | -rm -rf $repo | ||
| 599 | - | ||
| 600 | if [[ x$OS != x'windows-latest' ]]; then | 577 | if [[ x$OS != x'windows-latest' ]]; then |
| 601 | echo "OS: $OS" | 578 | echo "OS: $OS" |
| 602 | 579 | ||
| @@ -612,15 +589,7 @@ if [[ x$OS != x'windows-latest' ]]; then | @@ -612,15 +589,7 @@ if [[ x$OS != x'windows-latest' ]]; then | ||
| 612 | repo=$dir/$repo | 589 | repo=$dir/$repo |
| 613 | ls -lh $repo | 590 | ls -lh $repo |
| 614 | 591 | ||
| 615 | - python3 ./python-api-examples/keyword-spotter.py \ | ||
| 616 | - --tokens=$repo/tokens.txt \ | ||
| 617 | - --encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 618 | - --decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 619 | - --joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ | ||
| 620 | - --keywords-file=$repo/test_wavs/test_keywords.txt \ | ||
| 621 | - $repo/test_wavs/3.wav \ | ||
| 622 | - $repo/test_wavs/4.wav \ | ||
| 623 | - $repo/test_wavs/5.wav | 592 | + python3 ./python-api-examples/keyword-spotter.py |
| 624 | 593 | ||
| 625 | python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose | 594 | python3 sherpa-onnx/python/tests/test_keyword_spotter.py --verbose |
| 626 | 595 |
| @@ -79,6 +79,27 @@ jobs: | @@ -79,6 +79,27 @@ jobs: | ||
| 79 | otool -L ./install/lib/libsherpa-onnx-c-api.dylib | 79 | otool -L ./install/lib/libsherpa-onnx-c-api.dylib |
| 80 | fi | 80 | fi |
| 81 | 81 | ||
| 82 | + - name: Test kws (zh) | ||
| 83 | + shell: bash | ||
| 84 | + run: | | ||
| 85 | + gcc -o kws-c-api ./c-api-examples/kws-c-api.c \ | ||
| 86 | + -I ./build/install/include \ | ||
| 87 | + -L ./build/install/lib/ \ | ||
| 88 | + -l sherpa-onnx-c-api \ | ||
| 89 | + -l onnxruntime | ||
| 90 | + | ||
| 91 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 92 | + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 93 | + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 94 | + | ||
| 95 | + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH | ||
| 96 | + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH | ||
| 97 | + | ||
| 98 | + ./kws-c-api | ||
| 99 | + | ||
| 100 | + rm ./kws-c-api | ||
| 101 | + rm -rf sherpa-onnx-kws-* | ||
| 102 | + | ||
| 82 | - name: Test Kokoro TTS (en) | 103 | - name: Test Kokoro TTS (en) |
| 83 | shell: bash | 104 | shell: bash |
| 84 | run: | | 105 | run: | |
| @@ -81,6 +81,28 @@ jobs: | @@ -81,6 +81,28 @@ jobs: | ||
| 81 | otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib | 81 | otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib |
| 82 | fi | 82 | fi |
| 83 | 83 | ||
| 84 | + - name: Test KWS (zh) | ||
| 85 | + shell: bash | ||
| 86 | + run: | | ||
| 87 | + g++ -std=c++17 -o kws-cxx-api ./cxx-api-examples/kws-cxx-api.cc \ | ||
| 88 | + -I ./build/install/include \ | ||
| 89 | + -L ./build/install/lib/ \ | ||
| 90 | + -l sherpa-onnx-cxx-api \ | ||
| 91 | + -l sherpa-onnx-c-api \ | ||
| 92 | + -l onnxruntime | ||
| 93 | + | ||
| 94 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 95 | + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 96 | + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 97 | + | ||
| 98 | + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH | ||
| 99 | + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH | ||
| 100 | + | ||
| 101 | + ./kws-cxx-api | ||
| 102 | + | ||
| 103 | + rm kws-cxx-api | ||
| 104 | + rm -rf sherpa-onnx-kws-* | ||
| 105 | + | ||
| 84 | - name: Test Kokoro TTS (en) | 106 | - name: Test Kokoro TTS (en) |
| 85 | shell: bash | 107 | shell: bash |
| 86 | run: | | 108 | run: | |
| @@ -151,13 +151,15 @@ class MainActivity : AppCompatActivity() { | @@ -151,13 +151,15 @@ class MainActivity : AppCompatActivity() { | ||
| 151 | stream.acceptWaveform(samples, sampleRate = sampleRateInHz) | 151 | stream.acceptWaveform(samples, sampleRate = sampleRateInHz) |
| 152 | while (kws.isReady(stream)) { | 152 | while (kws.isReady(stream)) { |
| 153 | kws.decode(stream) | 153 | kws.decode(stream) |
| 154 | - } | ||
| 155 | 154 | ||
| 156 | val text = kws.getResult(stream).keyword | 155 | val text = kws.getResult(stream).keyword |
| 157 | 156 | ||
| 158 | var textToDisplay = lastText | 157 | var textToDisplay = lastText |
| 159 | 158 | ||
| 160 | if (text.isNotBlank()) { | 159 | if (text.isNotBlank()) { |
| 160 | + // Remember to reset the stream right after detecting a keyword | ||
| 161 | + | ||
| 162 | + kws.reset(stream) | ||
| 161 | if (lastText.isBlank()) { | 163 | if (lastText.isBlank()) { |
| 162 | textToDisplay = "$idx: $text" | 164 | textToDisplay = "$idx: $text" |
| 163 | } else { | 165 | } else { |
| @@ -173,6 +175,7 @@ class MainActivity : AppCompatActivity() { | @@ -173,6 +175,7 @@ class MainActivity : AppCompatActivity() { | ||
| 173 | } | 175 | } |
| 174 | } | 176 | } |
| 175 | } | 177 | } |
| 178 | + } | ||
| 176 | 179 | ||
| 177 | private fun initMicrophone(): Boolean { | 180 | private fun initMicrophone(): Boolean { |
| 178 | if (ActivityCompat.checkSelfPermission( | 181 | if (ActivityCompat.checkSelfPermission( |
| @@ -4,6 +4,9 @@ include_directories(${CMAKE_SOURCE_DIR}) | @@ -4,6 +4,9 @@ include_directories(${CMAKE_SOURCE_DIR}) | ||
| 4 | add_executable(decode-file-c-api decode-file-c-api.c) | 4 | add_executable(decode-file-c-api decode-file-c-api.c) |
| 5 | target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs) | 5 | target_link_libraries(decode-file-c-api sherpa-onnx-c-api cargs) |
| 6 | 6 | ||
| 7 | +add_executable(kws-c-api kws-c-api.c) | ||
| 8 | +target_link_libraries(kws-c-api sherpa-onnx-c-api) | ||
| 9 | + | ||
| 7 | if(SHERPA_ONNX_ENABLE_TTS) | 10 | if(SHERPA_ONNX_ENABLE_TTS) |
| 8 | add_executable(offline-tts-c-api offline-tts-c-api.c) | 11 | add_executable(offline-tts-c-api offline-tts-c-api.c) |
| 9 | target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs) | 12 | target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs) |
c-api-examples/kws-c-api.c
0 → 100644
| 1 | +// c-api-examples/kws-c-api.c | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +// | ||
| 5 | +// This file demonstrates how to use the keyword spotter with sherpa-onnx's C API. | |
| 6 | +// clang-format off | ||
| 7 | +// | ||
| 8 | +// Usage | ||
| 9 | +// | ||
| 10 | +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 11 | +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 12 | +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 13 | +// | ||
| 14 | +// ./kws-c-api | ||
| 15 | +// | ||
| 16 | +// clang-format on | ||
| 17 | +#include <stdio.h> | ||
| 18 | +#include <stdlib.h> // exit | ||
| 19 | +#include <string.h> // memset | ||
| 20 | + | ||
| 21 | +#include "sherpa-onnx/c-api/c-api.h" | ||
| 22 | + | ||
| 23 | +int32_t main() { | ||
| 24 | + SherpaOnnxKeywordSpotterConfig config; | ||
| 25 | + | ||
| 26 | + memset(&config, 0, sizeof(config)); | ||
| 27 | + config.model_config.transducer.encoder = | ||
| 28 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" | ||
| 29 | + "encoder-epoch-12-avg-2-chunk-16-left-64.onnx"; | ||
| 30 | + | ||
| 31 | + config.model_config.transducer.decoder = | ||
| 32 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" | ||
| 33 | + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx"; | ||
| 34 | + | ||
| 35 | + config.model_config.transducer.joiner = | ||
| 36 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" | ||
| 37 | + "joiner-epoch-12-avg-2-chunk-16-left-64.onnx"; | ||
| 38 | + | ||
| 39 | + config.model_config.tokens = | ||
| 40 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"; | ||
| 41 | + | ||
| 42 | + config.model_config.provider = "cpu"; | ||
| 43 | + config.model_config.num_threads = 1; | ||
| 44 | + config.model_config.debug = 1; | ||
| 45 | + | ||
| 46 | + config.keywords_file = | ||
| 47 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/" | ||
| 48 | + "test_keywords.txt"; | ||
| 49 | + | ||
| 50 | + const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&config); | ||
| 51 | + if (!kws) { | ||
| 52 | + fprintf(stderr, "Please check your config"); | ||
| 53 | + exit(-1); | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + fprintf(stderr, | ||
| 57 | + "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n"); | ||
| 58 | + | ||
| 59 | + const char *wav_filename = | ||
| 60 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"; | ||
| 61 | + | ||
| 62 | + float tail_paddings[8000] = {0}; // 0.5 seconds | ||
| 63 | + | ||
| 64 | + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename); | ||
| 65 | + if (wave == NULL) { | ||
| 66 | + fprintf(stderr, "Failed to read %s\n", wav_filename); | ||
| 67 | + exit(-1); | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws); | ||
| 71 | + if (!stream) { | ||
| 72 | + fprintf(stderr, "Failed to create stream\n"); | ||
| 73 | + exit(-1); | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, | ||
| 77 | + wave->num_samples); | ||
| 78 | + | ||
| 79 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, | ||
| 80 | + sizeof(tail_paddings) / sizeof(float)); | ||
| 81 | + SherpaOnnxOnlineStreamInputFinished(stream); | ||
| 82 | + while (SherpaOnnxIsKeywordStreamReady(kws, stream)) { | ||
| 83 | + SherpaOnnxDecodeKeywordStream(kws, stream); | ||
| 84 | + const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream); | ||
| 85 | + if (r && r->json && strlen(r->keyword)) { | ||
| 86 | + fprintf(stderr, "Detected keyword: %s\n", r->json); | ||
| 87 | + | ||
| 88 | + // Remember to reset the keyword stream right after a keyword is detected | ||
| 89 | + SherpaOnnxResetKeywordStream(kws, stream); | ||
| 90 | + } | ||
| 91 | + SherpaOnnxDestroyKeywordResult(r); | ||
| 92 | + } | ||
| 93 | + SherpaOnnxDestroyOnlineStream(stream); | ||
| 94 | + | ||
| 95 | + // -------------------------------------------------------------------------- | ||
| 96 | + | ||
| 97 | + fprintf(stderr, "--Use pre-defined keywords + add a new keyword--\n"); | ||
| 98 | + | ||
| 99 | + stream = SherpaOnnxCreateKeywordStreamWithKeywords(kws, "y ǎn y uán @演员"); | ||
| 100 | + | ||
| 101 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, | ||
| 102 | + wave->num_samples); | ||
| 103 | + | ||
| 104 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, | ||
| 105 | + sizeof(tail_paddings) / sizeof(float)); | ||
| 106 | + SherpaOnnxOnlineStreamInputFinished(stream); | ||
| 107 | + while (SherpaOnnxIsKeywordStreamReady(kws, stream)) { | ||
| 108 | + SherpaOnnxDecodeKeywordStream(kws, stream); | ||
| 109 | + const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream); | ||
| 110 | + if (r && r->json && strlen(r->keyword)) { | ||
| 111 | + fprintf(stderr, "Detected keyword: %s\n", r->json); | ||
| 112 | + | ||
| 113 | + // Remember to reset the keyword stream right after a keyword is detected | |
| 114 | + SherpaOnnxResetKeywordStream(kws, stream); | ||
| 115 | + } | ||
| 116 | + SherpaOnnxDestroyKeywordResult(r); | ||
| 117 | + } | ||
| 118 | + SherpaOnnxDestroyOnlineStream(stream); | ||
| 119 | + | ||
| 120 | + // -------------------------------------------------------------------------- | ||
| 121 | + | ||
| 122 | + fprintf(stderr, "--Use pre-defined keywords + add two new keywords--\n"); | ||
| 123 | + | ||
| 124 | + stream = SherpaOnnxCreateKeywordStreamWithKeywords( | ||
| 125 | + kws, "y ǎn y uán @演员/zh ī m íng @知名"); | ||
| 126 | + | ||
| 127 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, | ||
| 128 | + wave->num_samples); | ||
| 129 | + | ||
| 130 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, | ||
| 131 | + sizeof(tail_paddings) / sizeof(float)); | ||
| 132 | + SherpaOnnxOnlineStreamInputFinished(stream); | ||
| 133 | + while (SherpaOnnxIsKeywordStreamReady(kws, stream)) { | ||
| 134 | + SherpaOnnxDecodeKeywordStream(kws, stream); | ||
| 135 | + const SherpaOnnxKeywordResult *r = SherpaOnnxGetKeywordResult(kws, stream); | ||
| 136 | + if (r && r->json && strlen(r->keyword)) { | ||
| 137 | + fprintf(stderr, "Detected keyword: %s\n", r->json); | ||
| 138 | + | ||
| 139 | + // Remember to reset the keyword stream right after a keyword is detected | |
| 140 | + SherpaOnnxResetKeywordStream(kws, stream); | ||
| 141 | + } | ||
| 142 | + SherpaOnnxDestroyKeywordResult(r); | ||
| 143 | + } | ||
| 144 | + SherpaOnnxDestroyOnlineStream(stream); | ||
| 145 | + | ||
| 146 | + SherpaOnnxFreeWave(wave); | ||
| 147 | + SherpaOnnxDestroyKeywordSpotter(kws); | ||
| 148 | + | ||
| 149 | + return 0; | ||
| 150 | +} |
| @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR}) | @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR}) | ||
| 3 | add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc) | 3 | add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc) |
| 4 | target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api) | 4 | target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api) |
| 5 | 5 | ||
| 6 | +add_executable(kws-cxx-api ./kws-cxx-api.cc) | ||
| 7 | +target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api) | ||
| 8 | + | ||
| 6 | add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc) | 9 | add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc) |
| 7 | target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api) | 10 | target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-onnx-cxx-api) |
| 8 | 11 |
cxx-api-examples/kws-cxx-api.cc
0 → 100644
| 1 | +// cxx-api-examples/kws-cxx-api.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +// | ||
| 5 | +// This file demonstrates how to use the keyword spotter with sherpa-onnx's C++ API. | |
| 6 | +// clang-format off | ||
| 7 | +// | ||
| 8 | +// Usage | ||
| 9 | +// | ||
| 10 | +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 11 | +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 12 | +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 13 | +// | ||
| 14 | +// ./kws-cxx-api | ||
| 15 | +// | ||
| 16 | +// clang-format on | ||
| 17 | +#include <array> | ||
| 18 | +#include <iostream> | ||
| 19 | + | ||
| 20 | +#include "sherpa-onnx/c-api/cxx-api.h" | ||
| 21 | + | ||
| 22 | +int32_t main() { | ||
| 23 | + using namespace sherpa_onnx::cxx; // NOLINT | ||
| 24 | + | ||
| 25 | + KeywordSpotterConfig config; | ||
| 26 | + config.model_config.transducer.encoder = | ||
| 27 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" | ||
| 28 | + "encoder-epoch-12-avg-2-chunk-16-left-64.onnx"; | ||
| 29 | + | ||
| 30 | + config.model_config.transducer.decoder = | ||
| 31 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" | ||
| 32 | + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx"; | ||
| 33 | + | ||
| 34 | + config.model_config.transducer.joiner = | ||
| 35 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/" | ||
| 36 | + "joiner-epoch-12-avg-2-chunk-16-left-64.onnx"; | ||
| 37 | + | ||
| 38 | + config.model_config.tokens = | ||
| 39 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"; | ||
| 40 | + | ||
| 41 | + config.model_config.provider = "cpu"; | ||
| 42 | + config.model_config.num_threads = 1; | ||
| 43 | + config.model_config.debug = 1; | ||
| 44 | + | ||
| 45 | + config.keywords_file = | ||
| 46 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/" | ||
| 47 | + "test_keywords.txt"; | ||
| 48 | + | ||
| 49 | + KeywordSpotter kws = KeywordSpotter::Create(config); | ||
| 50 | + if (!kws.Get()) { | ||
| 51 | + std::cerr << "Please check your config\n"; | ||
| 52 | + return -1; | ||
| 53 | + } | ||
| 54 | + | ||
| 55 | + std::cout | ||
| 56 | + << "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n"; | ||
| 57 | + | ||
| 58 | + std::string wave_filename = | ||
| 59 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"; | ||
| 60 | + | ||
| 61 | + std::array<float, 8000> tail_paddings = {0}; // 0.5 seconds | ||
| 62 | + | ||
| 63 | + Wave wave = ReadWave(wave_filename); | ||
| 64 | + if (wave.samples.empty()) { | ||
| 65 | + std::cerr << "Failed to read: '" << wave_filename << "'\n"; | ||
| 66 | + return -1; | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + OnlineStream stream = kws.CreateStream(); | ||
| 70 | + if (!stream.Get()) { | ||
| 71 | + std::cerr << "Failed to create stream\n"; | ||
| 72 | + return -1; | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), | ||
| 76 | + wave.samples.size()); | ||
| 77 | + | ||
| 78 | + stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(), | ||
| 79 | + tail_paddings.size()); | ||
| 80 | + stream.InputFinished(); | ||
| 81 | + | ||
| 82 | + while (kws.IsReady(&stream)) { | ||
| 83 | + kws.Decode(&stream); | ||
| 84 | + auto r = kws.GetResult(&stream); | ||
| 85 | + if (!r.keyword.empty()) { | ||
| 86 | + std::cout << "Detected keyword: " << r.json << "\n"; | ||
| 87 | + | ||
| 88 | + // Remember to reset the keyword stream right after a keyword is detected | ||
| 89 | + kws.Reset(&stream); | ||
| 90 | + } | ||
| 91 | + } | ||
| 92 | + | ||
| 93 | + // -------------------------------------------------------------------------- | ||
| 94 | + | ||
| 95 | + std::cout << "--Use pre-defined keywords + add a new keyword--\n"; | ||
| 96 | + | ||
| 97 | + stream = kws.CreateStream("y ǎn y uán @演员"); | ||
| 98 | + | ||
| 99 | + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), | ||
| 100 | + wave.samples.size()); | ||
| 101 | + | ||
| 102 | + stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(), | ||
| 103 | + tail_paddings.size()); | ||
| 104 | + stream.InputFinished(); | ||
| 105 | + | ||
| 106 | + while (kws.IsReady(&stream)) { | ||
| 107 | + kws.Decode(&stream); | ||
| 108 | + auto r = kws.GetResult(&stream); | ||
| 109 | + if (!r.keyword.empty()) { | ||
| 110 | + std::cout << "Detected keyword: " << r.json << "\n"; | ||
| 111 | + | ||
| 112 | + // Remember to reset the keyword stream right after a keyword is detected | ||
| 113 | + kws.Reset(&stream); | ||
| 114 | + } | ||
| 115 | + } | ||
| 116 | + | ||
| 117 | + // -------------------------------------------------------------------------- | ||
| 118 | + | ||
| 119 | + std::cout << "--Use pre-defined keywords + add two new keywords--\n"; | ||
| 120 | + | ||
| 121 | + stream = kws.CreateStream("y ǎn y uán @演员/zh ī m íng @知名"); | ||
| 122 | + | ||
| 123 | + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), | ||
| 124 | + wave.samples.size()); | ||
| 125 | + | ||
| 126 | + stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(), | ||
| 127 | + tail_paddings.size()); | ||
| 128 | + stream.InputFinished(); | ||
| 129 | + | ||
| 130 | + while (kws.IsReady(&stream)) { | ||
| 131 | + kws.Decode(&stream); | ||
| 132 | + auto r = kws.GetResult(&stream); | ||
| 133 | + if (!r.keyword.empty()) { | ||
| 134 | + std::cout << "Detected keyword: " << r.json << "\n"; | ||
| 135 | + | ||
| 136 | + // Remember to reset the keyword stream right after a keyword is detected | ||
| 137 | + kws.Reset(&stream); | ||
| 138 | + } | ||
| 139 | + } | ||
| 140 | + return 0; | ||
| 141 | +} |
| @@ -73,6 +73,8 @@ void main(List<String> arguments) async { | @@ -73,6 +73,8 @@ void main(List<String> arguments) async { | ||
| 73 | spotter.decode(stream); | 73 | spotter.decode(stream); |
| 74 | final result = spotter.getResult(stream); | 74 | final result = spotter.getResult(stream); |
| 75 | if (result.keyword != '') { | 75 | if (result.keyword != '') { |
| 76 | + // Remember to reset the stream right after detecting a keyword | ||
| 77 | + spotter.reset(stream); | ||
| 76 | print('Detected: ${result.keyword}'); | 78 | print('Detected: ${result.keyword}'); |
| 77 | } | 79 | } |
| 78 | } | 80 | } |
| @@ -53,6 +53,8 @@ class KeywordSpotterDemo | @@ -53,6 +53,8 @@ class KeywordSpotterDemo | ||
| 53 | var result = kws.GetResult(s); | 53 | var result = kws.GetResult(s); |
| 54 | if (result.Keyword != string.Empty) | 54 | if (result.Keyword != string.Empty) |
| 55 | { | 55 | { |
| 56 | + // Remember to call Reset() right after detecting a keyword | ||
| 57 | + kws.Reset(s); | ||
| 56 | Console.WriteLine("Detected: {0}", result.Keyword); | 58 | Console.WriteLine("Detected: {0}", result.Keyword); |
| 57 | } | 59 | } |
| 58 | } | 60 | } |
| @@ -70,6 +72,8 @@ class KeywordSpotterDemo | @@ -70,6 +72,8 @@ class KeywordSpotterDemo | ||
| 70 | var result = kws.GetResult(s); | 72 | var result = kws.GetResult(s); |
| 71 | if (result.Keyword != string.Empty) | 73 | if (result.Keyword != string.Empty) |
| 72 | { | 74 | { |
| 75 | + // Remember to call Reset() right after detecting a keyword | ||
| 76 | + kws.Reset(s); | ||
| 73 | Console.WriteLine("Detected: {0}", result.Keyword); | 77 | Console.WriteLine("Detected: {0}", result.Keyword); |
| 74 | } | 78 | } |
| 75 | } | 79 | } |
| @@ -89,6 +93,8 @@ class KeywordSpotterDemo | @@ -89,6 +93,8 @@ class KeywordSpotterDemo | ||
| 89 | var result = kws.GetResult(s); | 93 | var result = kws.GetResult(s); |
| 90 | if (result.Keyword != string.Empty) | 94 | if (result.Keyword != string.Empty) |
| 91 | { | 95 | { |
| 96 | + // Remember to call Reset() right after detecting a keyword | ||
| 97 | + kws.Reset(s); | ||
| 92 | Console.WriteLine("Detected: {0}", result.Keyword); | 98 | Console.WriteLine("Detected: {0}", result.Keyword); |
| 93 | } | 99 | } |
| 94 | } | 100 | } |
| @@ -107,13 +107,16 @@ class KeywordSpotterDemo | @@ -107,13 +107,16 @@ class KeywordSpotterDemo | ||
| 107 | while (kws.IsReady(s)) | 107 | while (kws.IsReady(s)) |
| 108 | { | 108 | { |
| 109 | kws.Decode(s); | 109 | kws.Decode(s); |
| 110 | - } | ||
| 111 | 110 | ||
| 112 | var result = kws.GetResult(s); | 111 | var result = kws.GetResult(s); |
| 113 | if (result.Keyword != string.Empty) | 112 | if (result.Keyword != string.Empty) |
| 114 | { | 113 | { |
| 114 | + // Remember to call Reset() right after detecting a keyword | ||
| 115 | + kws.Reset(s); | ||
| 116 | + | ||
| 115 | Console.WriteLine("Detected: {0}", result.Keyword); | 117 | Console.WriteLine("Detected: {0}", result.Keyword); |
| 116 | } | 118 | } |
| 119 | + } | ||
| 117 | 120 | ||
| 118 | Thread.Sleep(200); // ms | 121 | Thread.Sleep(200); // ms |
| 119 | } | 122 | } |
| @@ -168,6 +168,10 @@ class KeywordSpotter { | @@ -168,6 +168,10 @@ class KeywordSpotter { | ||
| 168 | SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr); | 168 | SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr); |
| 169 | } | 169 | } |
| 170 | 170 | ||
| 171 | + void reset(OnlineStream stream) { | ||
| 172 | + SherpaOnnxBindings.resetKeywordStream?.call(ptr, stream.ptr); | ||
| 173 | + } | ||
| 174 | + | ||
| 171 | Pointer<SherpaOnnxKeywordSpotter> ptr; | 175 | Pointer<SherpaOnnxKeywordSpotter> ptr; |
| 172 | KeywordSpotterConfig config; | 176 | KeywordSpotterConfig config; |
| 173 | } | 177 | } |
| @@ -667,6 +667,12 @@ typedef DecodeKeywordStreamNative = Void Function( | @@ -667,6 +667,12 @@ typedef DecodeKeywordStreamNative = Void Function( | ||
| 667 | typedef DecodeKeywordStream = void Function( | 667 | typedef DecodeKeywordStream = void Function( |
| 668 | Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>); | 668 | Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>); |
| 669 | 669 | ||
| 670 | +typedef ResetKeywordStreamNative = Void Function( | ||
| 671 | + Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>); | ||
| 672 | + | ||
| 673 | +typedef ResetKeywordStream = void Function( | ||
| 674 | + Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>); | ||
| 675 | + | ||
| 670 | typedef GetKeywordResultAsJsonNative = Pointer<Utf8> Function( | 676 | typedef GetKeywordResultAsJsonNative = Pointer<Utf8> Function( |
| 671 | Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>); | 677 | Pointer<SherpaOnnxKeywordSpotter>, Pointer<SherpaOnnxOnlineStream>); |
| 672 | 678 | ||
| @@ -1157,6 +1163,7 @@ class SherpaOnnxBindings { | @@ -1157,6 +1163,7 @@ class SherpaOnnxBindings { | ||
| 1157 | static CreateKeywordStreamWithKeywords? createKeywordStreamWithKeywords; | 1163 | static CreateKeywordStreamWithKeywords? createKeywordStreamWithKeywords; |
| 1158 | static IsKeywordStreamReady? isKeywordStreamReady; | 1164 | static IsKeywordStreamReady? isKeywordStreamReady; |
| 1159 | static DecodeKeywordStream? decodeKeywordStream; | 1165 | static DecodeKeywordStream? decodeKeywordStream; |
| 1166 | + static ResetKeywordStream? resetKeywordStream; | ||
| 1160 | static GetKeywordResultAsJson? getKeywordResultAsJson; | 1167 | static GetKeywordResultAsJson? getKeywordResultAsJson; |
| 1161 | static FreeKeywordResultJson? freeKeywordResultJson; | 1168 | static FreeKeywordResultJson? freeKeywordResultJson; |
| 1162 | 1169 | ||
| @@ -1459,6 +1466,11 @@ class SherpaOnnxBindings { | @@ -1459,6 +1466,11 @@ class SherpaOnnxBindings { | ||
| 1459 | 'SherpaOnnxDecodeKeywordStream') | 1466 | 'SherpaOnnxDecodeKeywordStream') |
| 1460 | .asFunction(); | 1467 | .asFunction(); |
| 1461 | 1468 | ||
| 1469 | + resetKeywordStream ??= dynamicLibrary | ||
| 1470 | + .lookup<NativeFunction<ResetKeywordStreamNative>>( | ||
| 1471 | + 'SherpaOnnxResetKeywordStream') | ||
| 1472 | + .asFunction(); | ||
| 1473 | + | ||
| 1462 | getKeywordResultAsJson ??= dynamicLibrary | 1474 | getKeywordResultAsJson ??= dynamicLibrary |
| 1463 | .lookup<NativeFunction<GetKeywordResultAsJsonNative>>( | 1475 | .lookup<NativeFunction<GetKeywordResultAsJsonNative>>( |
| 1464 | 'SherpaOnnxGetKeywordResultAsJson') | 1476 | 'SherpaOnnxGetKeywordResultAsJson') |
| @@ -43,6 +43,8 @@ func main() { | @@ -43,6 +43,8 @@ func main() { | ||
| 43 | spotter.Decode(stream) | 43 | spotter.Decode(stream) |
| 44 | result := spotter.GetResult(stream) | 44 | result := spotter.GetResult(stream) |
| 45 | if result.Keyword != "" { | 45 | if result.Keyword != "" { |
| 46 | + // You have to reset the stream right after detecting a keyword | ||
| 47 | + spotter.Reset(stream) | ||
| 46 | log.Printf("Detected %v\n", result.Keyword) | 48 | log.Printf("Detected %v\n", result.Keyword) |
| 47 | } | 49 | } |
| 48 | } | 50 | } |
| @@ -46,7 +46,7 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper( | @@ -46,7 +46,7 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper( | ||
| 46 | SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf); | 46 | SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf); |
| 47 | SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize); | 47 | SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize); |
| 48 | 48 | ||
| 49 | - SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c); | 49 | + const SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c); |
| 50 | 50 | ||
| 51 | if (c.model_config.transducer.encoder) { | 51 | if (c.model_config.transducer.encoder) { |
| 52 | delete[] c.model_config.transducer.encoder; | 52 | delete[] c.model_config.transducer.encoder; |
| @@ -100,7 +100,8 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper( | @@ -100,7 +100,8 @@ static Napi::External<SherpaOnnxKeywordSpotter> CreateKeywordSpotterWrapper( | ||
| 100 | } | 100 | } |
| 101 | 101 | ||
| 102 | return Napi::External<SherpaOnnxKeywordSpotter>::New( | 102 | return Napi::External<SherpaOnnxKeywordSpotter>::New( |
| 103 | - env, kws, [](Napi::Env env, SherpaOnnxKeywordSpotter *kws) { | 103 | + env, const_cast<SherpaOnnxKeywordSpotter *>(kws), |
| 104 | + [](Napi::Env env, SherpaOnnxKeywordSpotter *kws) { | ||
| 104 | SherpaOnnxDestroyKeywordSpotter(kws); | 105 | SherpaOnnxDestroyKeywordSpotter(kws); |
| 105 | }); | 106 | }); |
| 106 | } | 107 | } |
| @@ -125,13 +126,14 @@ static Napi::External<SherpaOnnxOnlineStream> CreateKeywordStreamWrapper( | @@ -125,13 +126,14 @@ static Napi::External<SherpaOnnxOnlineStream> CreateKeywordStreamWrapper( | ||
| 125 | return {}; | 126 | return {}; |
| 126 | } | 127 | } |
| 127 | 128 | ||
| 128 | - SherpaOnnxKeywordSpotter *kws = | 129 | + const SherpaOnnxKeywordSpotter *kws = |
| 129 | info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); | 130 | info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); |
| 130 | 131 | ||
| 131 | - SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws); | 132 | + const SherpaOnnxOnlineStream *stream = SherpaOnnxCreateKeywordStream(kws); |
| 132 | 133 | ||
| 133 | return Napi::External<SherpaOnnxOnlineStream>::New( | 134 | return Napi::External<SherpaOnnxOnlineStream>::New( |
| 134 | - env, stream, [](Napi::Env env, SherpaOnnxOnlineStream *stream) { | 135 | + env, const_cast<SherpaOnnxOnlineStream *>(stream), |
| 136 | + [](Napi::Env env, SherpaOnnxOnlineStream *stream) { | ||
| 135 | SherpaOnnxDestroyOnlineStream(stream); | 137 | SherpaOnnxDestroyOnlineStream(stream); |
| 136 | }); | 138 | }); |
| 137 | } | 139 | } |
| @@ -162,10 +164,10 @@ static Napi::Boolean IsKeywordStreamReadyWrapper( | @@ -162,10 +164,10 @@ static Napi::Boolean IsKeywordStreamReadyWrapper( | ||
| 162 | return {}; | 164 | return {}; |
| 163 | } | 165 | } |
| 164 | 166 | ||
| 165 | - SherpaOnnxKeywordSpotter *kws = | 167 | + const SherpaOnnxKeywordSpotter *kws = |
| 166 | info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); | 168 | info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); |
| 167 | 169 | ||
| 168 | - SherpaOnnxOnlineStream *stream = | 170 | + const SherpaOnnxOnlineStream *stream = |
| 169 | info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); | 171 | info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); |
| 170 | 172 | ||
| 171 | int32_t is_ready = SherpaOnnxIsKeywordStreamReady(kws, stream); | 173 | int32_t is_ready = SherpaOnnxIsKeywordStreamReady(kws, stream); |
| @@ -198,15 +200,49 @@ static void DecodeKeywordStreamWrapper(const Napi::CallbackInfo &info) { | @@ -198,15 +200,49 @@ static void DecodeKeywordStreamWrapper(const Napi::CallbackInfo &info) { | ||
| 198 | return; | 200 | return; |
| 199 | } | 201 | } |
| 200 | 202 | ||
| 201 | - SherpaOnnxKeywordSpotter *kws = | 203 | + const SherpaOnnxKeywordSpotter *kws = |
| 202 | info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); | 204 | info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); |
| 203 | 205 | ||
| 204 | - SherpaOnnxOnlineStream *stream = | 206 | + const SherpaOnnxOnlineStream *stream = |
| 205 | info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); | 207 | info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); |
| 206 | 208 | ||
| 207 | SherpaOnnxDecodeKeywordStream(kws, stream); | 209 | SherpaOnnxDecodeKeywordStream(kws, stream); |
| 208 | } | 210 | } |
| 209 | 211 | ||
| 212 | +static void ResetKeywordStreamWrapper(const Napi::CallbackInfo &info) { | ||
| 213 | + Napi::Env env = info.Env(); | ||
| 214 | + if (info.Length() != 2) { | ||
| 215 | + std::ostringstream os; | ||
| 216 | + os << "Expect only 2 arguments. Given: " << info.Length(); | ||
| 217 | + | ||
| 218 | + Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException(); | ||
| 219 | + | ||
| 220 | + return; | ||
| 221 | + } | ||
| 222 | + | ||
| 223 | + if (!info[0].IsExternal()) { | ||
| 224 | + Napi::TypeError::New(env, "Argument 0 should be a keyword spotter pointer.") | ||
| 225 | + .ThrowAsJavaScriptException(); | ||
| 226 | + | ||
| 227 | + return; | ||
| 228 | + } | ||
| 229 | + | ||
| 230 | + if (!info[1].IsExternal()) { | ||
| 231 | + Napi::TypeError::New(env, "Argument 1 should be an online stream pointer.") | ||
| 232 | + .ThrowAsJavaScriptException(); | ||
| 233 | + | ||
| 234 | + return; | ||
| 235 | + } | ||
| 236 | + | ||
| 237 | + const SherpaOnnxKeywordSpotter *kws = | ||
| 238 | + info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); | ||
| 239 | + | ||
| 240 | + const SherpaOnnxOnlineStream *stream = | ||
| 241 | + info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); | ||
| 242 | + | ||
| 243 | + SherpaOnnxResetKeywordStream(kws, stream); | ||
| 244 | +} | ||
| 245 | + | ||
| 210 | static Napi::String GetKeywordResultAsJsonWrapper( | 246 | static Napi::String GetKeywordResultAsJsonWrapper( |
| 211 | const Napi::CallbackInfo &info) { | 247 | const Napi::CallbackInfo &info) { |
| 212 | Napi::Env env = info.Env(); | 248 | Napi::Env env = info.Env(); |
| @@ -233,10 +269,10 @@ static Napi::String GetKeywordResultAsJsonWrapper( | @@ -233,10 +269,10 @@ static Napi::String GetKeywordResultAsJsonWrapper( | ||
| 233 | return {}; | 269 | return {}; |
| 234 | } | 270 | } |
| 235 | 271 | ||
| 236 | - SherpaOnnxKeywordSpotter *kws = | 272 | + const SherpaOnnxKeywordSpotter *kws = |
| 237 | info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); | 273 | info[0].As<Napi::External<SherpaOnnxKeywordSpotter>>().Data(); |
| 238 | 274 | ||
| 239 | - SherpaOnnxOnlineStream *stream = | 275 | + const SherpaOnnxOnlineStream *stream = |
| 240 | info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); | 276 | info[1].As<Napi::External<SherpaOnnxOnlineStream>>().Data(); |
| 241 | 277 | ||
| 242 | const char *json = SherpaOnnxGetKeywordResultAsJson(kws, stream); | 278 | const char *json = SherpaOnnxGetKeywordResultAsJson(kws, stream); |
| @@ -261,6 +297,9 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports) { | @@ -261,6 +297,9 @@ void InitKeywordSpotting(Napi::Env env, Napi::Object exports) { | ||
| 261 | exports.Set(Napi::String::New(env, "decodeKeywordStream"), | 297 | exports.Set(Napi::String::New(env, "decodeKeywordStream"), |
| 262 | Napi::Function::New(env, DecodeKeywordStreamWrapper)); | 298 | Napi::Function::New(env, DecodeKeywordStreamWrapper)); |
| 263 | 299 | ||
| 300 | + exports.Set(Napi::String::New(env, "resetKeywordStream"), | ||
| 301 | + Napi::Function::New(env, ResetKeywordStreamWrapper)); | ||
| 302 | + | ||
| 264 | exports.Set(Napi::String::New(env, "getKeywordResultAsJson"), | 303 | exports.Set(Napi::String::New(env, "getKeywordResultAsJson"), |
| 265 | Napi::Function::New(env, GetKeywordResultAsJsonWrapper)); | 304 | Napi::Function::New(env, GetKeywordResultAsJsonWrapper)); |
| 266 | } | 305 | } |
| @@ -56,6 +56,8 @@ public class KyewordSpotterFromFile { | @@ -56,6 +56,8 @@ public class KyewordSpotterFromFile { | ||
| 56 | 56 | ||
| 57 | String keyword = kws.getResult(stream).getKeyword(); | 57 | String keyword = kws.getResult(stream).getKeyword(); |
| 58 | if (!keyword.isEmpty()) { | 58 | if (!keyword.isEmpty()) { |
| 59 | + // Remember to reset the stream right after detecting a keyword | ||
| 60 | + kws.reset(stream); | ||
| 59 | System.out.printf("Detected keyword: %s\n", keyword); | 61 | System.out.printf("Detected keyword: %s\n", keyword); |
| 60 | } | 62 | } |
| 61 | } | 63 | } |
| @@ -41,6 +41,9 @@ while (kws.isReady(stream)) { | @@ -41,6 +41,9 @@ while (kws.isReady(stream)) { | ||
| 41 | const keyword = kws.getResult(stream).keyword; | 41 | const keyword = kws.getResult(stream).keyword; |
| 42 | if (keyword != '') { | 42 | if (keyword != '') { |
| 43 | detectedKeywords.push(keyword); | 43 | detectedKeywords.push(keyword); |
| 44 | + | ||
| 45 | + // remember to reset the stream right after detecting a keyword | ||
| 46 | + kws.reset(stream); | ||
| 44 | } | 47 | } |
| 45 | } | 48 | } |
| 46 | console.log(detectedKeywords); | 49 | console.log(detectedKeywords); |
| @@ -169,6 +169,8 @@ def main(): | @@ -169,6 +169,8 @@ def main(): | ||
| 169 | 169 | ||
| 170 | print("Started! Please speak") | 170 | print("Started! Please speak") |
| 171 | 171 | ||
| 172 | + idx = 0 | ||
| 173 | + | ||
| 172 | sample_rate = 16000 | 174 | sample_rate = 16000 |
| 173 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms | 175 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms |
| 174 | stream = keyword_spotter.create_stream() | 176 | stream = keyword_spotter.create_stream() |
| @@ -181,7 +183,10 @@ def main(): | @@ -181,7 +183,10 @@ def main(): | ||
| 181 | keyword_spotter.decode_stream(stream) | 183 | keyword_spotter.decode_stream(stream) |
| 182 | result = keyword_spotter.get_result(stream) | 184 | result = keyword_spotter.get_result(stream) |
| 183 | if result: | 185 | if result: |
| 184 | - print("\r{}".format(result), end="", flush=True) | 186 | + print(f"{idx}: {result }") |
| 187 | + idx += 1 | ||
| 188 | + # Remember to reset stream right after detecting a keyword | ||
| 189 | + keyword_spotter.reset_stream(stream) | ||
| 185 | 190 | ||
| 186 | 191 | ||
| 187 | if __name__ == "__main__": | 192 | if __name__ == "__main__": |
| @@ -18,122 +18,6 @@ import numpy as np | @@ -18,122 +18,6 @@ import numpy as np | ||
| 18 | import sherpa_onnx | 18 | import sherpa_onnx |
| 19 | 19 | ||
| 20 | 20 | ||
| 21 | -def get_args(): | ||
| 22 | - parser = argparse.ArgumentParser( | ||
| 23 | - formatter_class=argparse.ArgumentDefaultsHelpFormatter | ||
| 24 | - ) | ||
| 25 | - | ||
| 26 | - parser.add_argument( | ||
| 27 | - "--tokens", | ||
| 28 | - type=str, | ||
| 29 | - help="Path to tokens.txt", | ||
| 30 | - ) | ||
| 31 | - | ||
| 32 | - parser.add_argument( | ||
| 33 | - "--encoder", | ||
| 34 | - type=str, | ||
| 35 | - help="Path to the transducer encoder model", | ||
| 36 | - ) | ||
| 37 | - | ||
| 38 | - parser.add_argument( | ||
| 39 | - "--decoder", | ||
| 40 | - type=str, | ||
| 41 | - help="Path to the transducer decoder model", | ||
| 42 | - ) | ||
| 43 | - | ||
| 44 | - parser.add_argument( | ||
| 45 | - "--joiner", | ||
| 46 | - type=str, | ||
| 47 | - help="Path to the transducer joiner model", | ||
| 48 | - ) | ||
| 49 | - | ||
| 50 | - parser.add_argument( | ||
| 51 | - "--num-threads", | ||
| 52 | - type=int, | ||
| 53 | - default=1, | ||
| 54 | - help="Number of threads for neural network computation", | ||
| 55 | - ) | ||
| 56 | - | ||
| 57 | - parser.add_argument( | ||
| 58 | - "--provider", | ||
| 59 | - type=str, | ||
| 60 | - default="cpu", | ||
| 61 | - help="Valid values: cpu, cuda, coreml", | ||
| 62 | - ) | ||
| 63 | - | ||
| 64 | - parser.add_argument( | ||
| 65 | - "--max-active-paths", | ||
| 66 | - type=int, | ||
| 67 | - default=4, | ||
| 68 | - help=""" | ||
| 69 | - It specifies number of active paths to keep during decoding. | ||
| 70 | - """, | ||
| 71 | - ) | ||
| 72 | - | ||
| 73 | - parser.add_argument( | ||
| 74 | - "--num-trailing-blanks", | ||
| 75 | - type=int, | ||
| 76 | - default=1, | ||
| 77 | - help="""The number of trailing blanks a keyword should be followed. Setting | ||
| 78 | - to a larger value (e.g. 8) when your keywords has overlapping tokens | ||
| 79 | - between each other. | ||
| 80 | - """, | ||
| 81 | - ) | ||
| 82 | - | ||
| 83 | - parser.add_argument( | ||
| 84 | - "--keywords-file", | ||
| 85 | - type=str, | ||
| 86 | - help=""" | ||
| 87 | - The file containing keywords, one words/phrases per line, and for each | ||
| 88 | - phrase the bpe/cjkchar/pinyin are separated by a space. For example: | ||
| 89 | - | ||
| 90 | - ▁HE LL O ▁WORLD | ||
| 91 | - x iǎo ài t óng x ué | ||
| 92 | - """, | ||
| 93 | - ) | ||
| 94 | - | ||
| 95 | - parser.add_argument( | ||
| 96 | - "--keywords-score", | ||
| 97 | - type=float, | ||
| 98 | - default=1.0, | ||
| 99 | - help=""" | ||
| 100 | - The boosting score of each token for keywords. The larger the easier to | ||
| 101 | - survive beam search. | ||
| 102 | - """, | ||
| 103 | - ) | ||
| 104 | - | ||
| 105 | - parser.add_argument( | ||
| 106 | - "--keywords-threshold", | ||
| 107 | - type=float, | ||
| 108 | - default=0.25, | ||
| 109 | - help=""" | ||
| 110 | - The trigger threshold (i.e. probability) of the keyword. The larger the | ||
| 111 | - harder to trigger. | ||
| 112 | - """, | ||
| 113 | - ) | ||
| 114 | - | ||
| 115 | - parser.add_argument( | ||
| 116 | - "sound_files", | ||
| 117 | - type=str, | ||
| 118 | - nargs="+", | ||
| 119 | - help="The input sound file(s) to decode. Each file must be of WAVE" | ||
| 120 | - "format with a single channel, and each sample has 16-bit, " | ||
| 121 | - "i.e., int16_t. " | ||
| 122 | - "The sample rate of the file can be arbitrary and does not need to " | ||
| 123 | - "be 16 kHz", | ||
| 124 | - ) | ||
| 125 | - | ||
| 126 | - return parser.parse_args() | ||
| 127 | - | ||
| 128 | - | ||
| 129 | -def assert_file_exists(filename: str): | ||
| 130 | - assert Path(filename).is_file(), ( | ||
| 131 | - f"{filename} does not exist!\n" | ||
| 132 | - "Please refer to " | ||
| 133 | - "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it" | ||
| 134 | - ) | ||
| 135 | - | ||
| 136 | - | ||
| 137 | def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | 21 | def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: |
| 138 | """ | 22 | """ |
| 139 | Args: | 23 | Args: |
| @@ -159,83 +43,74 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | @@ -159,83 +43,74 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | ||
| 159 | return samples_float32, f.getframerate() | 43 | return samples_float32, f.getframerate() |
| 160 | 44 | ||
| 161 | 45 | ||
| 162 | -def main(): | ||
| 163 | - args = get_args() | ||
| 164 | - assert_file_exists(args.tokens) | ||
| 165 | - assert_file_exists(args.encoder) | ||
| 166 | - assert_file_exists(args.decoder) | ||
| 167 | - assert_file_exists(args.joiner) | ||
| 168 | - | ||
| 169 | - assert Path( | ||
| 170 | - args.keywords_file | ||
| 171 | - ).is_file(), ( | ||
| 172 | - f"keywords_file : {args.keywords_file} not exist, please provide a valid path." | 46 | +def create_keyword_spotter(): |
| 47 | + kws = sherpa_onnx.KeywordSpotter( | ||
| 48 | + tokens="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt", | ||
| 49 | + encoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx", | ||
| 50 | + decoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx", | ||
| 51 | + joiner="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx", | ||
| 52 | + num_threads=2, | ||
| 53 | + keywords_file="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt", | ||
| 54 | + provider="cpu", | ||
| 173 | ) | 55 | ) |
| 174 | 56 | ||
| 175 | - keyword_spotter = sherpa_onnx.KeywordSpotter( | ||
| 176 | - tokens=args.tokens, | ||
| 177 | - encoder=args.encoder, | ||
| 178 | - decoder=args.decoder, | ||
| 179 | - joiner=args.joiner, | ||
| 180 | - num_threads=args.num_threads, | ||
| 181 | - max_active_paths=args.max_active_paths, | ||
| 182 | - keywords_file=args.keywords_file, | ||
| 183 | - keywords_score=args.keywords_score, | ||
| 184 | - keywords_threshold=args.keywords_threshold, | ||
| 185 | - num_trailing_blanks=args.num_trailing_blanks, | ||
| 186 | - provider=args.provider, | ||
| 187 | - ) | 57 | + return kws |
| 188 | 58 | ||
| 189 | - print("Started!") | ||
| 190 | - start_time = time.time() | ||
| 191 | 59 | ||
| 192 | - streams = [] | ||
| 193 | - total_duration = 0 | ||
| 194 | - for wave_filename in args.sound_files: | ||
| 195 | - assert_file_exists(wave_filename) | ||
| 196 | - samples, sample_rate = read_wave(wave_filename) | ||
| 197 | - duration = len(samples) / sample_rate | ||
| 198 | - total_duration += duration | 60 | +def main(): |
| 61 | + kws = create_keyword_spotter() | ||
| 199 | 62 | ||
| 200 | - s = keyword_spotter.create_stream() | 63 | + wave_filename = ( |
| 64 | + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav" | ||
| 65 | + ) | ||
| 201 | 66 | ||
| 202 | - s.accept_waveform(sample_rate, samples) | 67 | + samples, sample_rate = read_wave(wave_filename) |
| 203 | 68 | ||
| 204 | tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) | 69 | tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) |
| 70 | + | ||
| 71 | + print("----------Use pre-defined keywords----------") | ||
| 72 | + s = kws.create_stream() | ||
| 73 | + s.accept_waveform(sample_rate, samples) | ||
| 205 | s.accept_waveform(sample_rate, tail_paddings) | 74 | s.accept_waveform(sample_rate, tail_paddings) |
| 75 | + s.input_finished() | ||
| 76 | + while kws.is_ready(s): | ||
| 77 | + kws.decode_stream(s) | ||
| 78 | + r = kws.get_result(s) | ||
| 79 | + if r != "": | ||
| 80 | + # Remember to call reset right after detected a keyword | ||
| 81 | + kws.reset_stream(s) | ||
| 82 | + | ||
| 83 | + print(f"Detected {r}") | ||
| 206 | 84 | ||
| 85 | + print("----------Use pre-defined keywords + add a new keyword----------") | ||
| 86 | + | ||
| 87 | + s = kws.create_stream("y ǎn y uán @演员") | ||
| 88 | + s.accept_waveform(sample_rate, samples) | ||
| 89 | + s.accept_waveform(sample_rate, tail_paddings) | ||
| 207 | s.input_finished() | 90 | s.input_finished() |
| 91 | + while kws.is_ready(s): | ||
| 92 | + kws.decode_stream(s) | ||
| 93 | + r = kws.get_result(s) | ||
| 94 | + if r != "": | ||
| 95 | + # Remember to call reset right after detected a keyword | ||
| 96 | + kws.reset_stream(s) | ||
| 208 | 97 | ||
| 209 | - streams.append(s) | ||
| 210 | - | ||
| 211 | - results = [""] * len(streams) | ||
| 212 | - while True: | ||
| 213 | - ready_list = [] | ||
| 214 | - for i, s in enumerate(streams): | ||
| 215 | - if keyword_spotter.is_ready(s): | ||
| 216 | - ready_list.append(s) | ||
| 217 | - r = keyword_spotter.get_result(s) | ||
| 218 | - if r: | ||
| 219 | - results[i] += f"{r}/" | ||
| 220 | - print(f"{r} is detected.") | ||
| 221 | - if len(ready_list) == 0: | ||
| 222 | - break | ||
| 223 | - keyword_spotter.decode_streams(ready_list) | ||
| 224 | - end_time = time.time() | ||
| 225 | - print("Done!") | ||
| 226 | - | ||
| 227 | - for wave_filename, result in zip(args.sound_files, results): | ||
| 228 | - print(f"{wave_filename}\n{result}") | ||
| 229 | - print("-" * 10) | ||
| 230 | - | ||
| 231 | - elapsed_seconds = end_time - start_time | ||
| 232 | - rtf = elapsed_seconds / total_duration | ||
| 233 | - print(f"num_threads: {args.num_threads}") | ||
| 234 | - print(f"Wave duration: {total_duration:.3f} s") | ||
| 235 | - print(f"Elapsed time: {elapsed_seconds:.3f} s") | ||
| 236 | - print( | ||
| 237 | - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" | ||
| 238 | - ) | 98 | + print(f"Detected {r}") |
| 99 | + | ||
| 100 | + print("----------Use pre-defined keywords + add 2 new keywords----------") | ||
| 101 | + | ||
| 102 | + s = kws.create_stream("y ǎn y uán @演员/zh ī m íng @知名") | ||
| 103 | + s.accept_waveform(sample_rate, samples) | ||
| 104 | + s.accept_waveform(sample_rate, tail_paddings) | ||
| 105 | + s.input_finished() | ||
| 106 | + while kws.is_ready(s): | ||
| 107 | + kws.decode_stream(s) | ||
| 108 | + r = kws.get_result(s) | ||
| 109 | + if r != "": | ||
| 110 | + # Remember to call reset right after detected a keyword | ||
| 111 | + kws.reset_stream(s) | ||
| 112 | + | ||
| 113 | + print(f"Detected {r}") | ||
| 239 | 114 | ||
| 240 | 115 | ||
| 241 | if __name__ == "__main__": | 116 | if __name__ == "__main__": |
| @@ -46,6 +46,11 @@ namespace SherpaOnnx | @@ -46,6 +46,11 @@ namespace SherpaOnnx | ||
| 46 | Decode(_handle.Handle, stream.Handle); | 46 | Decode(_handle.Handle, stream.Handle); |
| 47 | } | 47 | } |
| 48 | 48 | ||
| 49 | + public void Reset(OnlineStream stream) | ||
| 50 | + { | ||
| 51 | + Reset(_handle.Handle, stream.Handle); | ||
| 52 | + } | ||
| 53 | + | ||
| 49 | // The caller should ensure all passed streams are ready for decoding. | 54 | // The caller should ensure all passed streams are ready for decoding. |
| 50 | public void Decode(IEnumerable<OnlineStream> streams) | 55 | public void Decode(IEnumerable<OnlineStream> streams) |
| 51 | { | 56 | { |
| @@ -110,6 +115,9 @@ namespace SherpaOnnx | @@ -110,6 +115,9 @@ namespace SherpaOnnx | ||
| 110 | [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeKeywordStream")] | 115 | [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeKeywordStream")] |
| 111 | private static extern void Decode(IntPtr handle, IntPtr stream); | 116 | private static extern void Decode(IntPtr handle, IntPtr stream); |
| 112 | 117 | ||
| 118 | + [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxResetKeywordStream")] | ||
| 119 | + private static extern void Reset(IntPtr handle, IntPtr stream); | ||
| 120 | + | ||
| 113 | [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeMultipleKeywordStreams")] | 121 | [DllImport(Dll.Filename, EntryPoint = "SherpaOnnxDecodeMultipleKeywordStreams")] |
| 114 | private static extern void Decode(IntPtr handle, IntPtr[] streams, int n); | 122 | private static extern void Decode(IntPtr handle, IntPtr[] streams, int n); |
| 115 | 123 |
| @@ -1584,6 +1584,11 @@ func (spotter *KeywordSpotter) Decode(s *OnlineStream) { | @@ -1584,6 +1584,11 @@ func (spotter *KeywordSpotter) Decode(s *OnlineStream) { | ||
| 1584 | C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl) | 1584 | C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl) |
| 1585 | } | 1585 | } |
| 1586 | 1586 | ||
| 1587 | +// You MUST call it right after detecting a keyword | ||
| 1588 | +func (spotter *KeywordSpotter) Reset(s *OnlineStream) { | ||
| 1589 | + C.SherpaOnnxResetKeywordStream(spotter.impl, s.impl) | ||
| 1590 | +} | ||
| 1591 | + | ||
| 1587 | // Get the current result of stream since the last invoke of Reset() | 1592 | // Get the current result of stream since the last invoke of Reset() |
| 1588 | func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult { | 1593 | func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult { |
| 1589 | p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl) | 1594 | p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl) |
| @@ -20,6 +20,10 @@ class KeywordSpotter { | @@ -20,6 +20,10 @@ class KeywordSpotter { | ||
| 20 | addon.decodeKeywordStream(this.handle, stream.handle); | 20 | addon.decodeKeywordStream(this.handle, stream.handle); |
| 21 | } | 21 | } |
| 22 | 22 | ||
| 23 | + reset(stream) { | ||
| 24 | + addon.resetKeywordStream(this.handle, stream.handle); | ||
| 25 | + } | ||
| 26 | + | ||
| 23 | getResult(stream) { | 27 | getResult(stream) { |
| 24 | const jsonStr = addon.getKeywordResultAsJson(this.handle, stream.handle); | 28 | const jsonStr = addon.getKeywordResultAsJson(this.handle, stream.handle); |
| 25 | 29 |
| @@ -678,7 +678,7 @@ struct SherpaOnnxKeywordSpotter { | @@ -678,7 +678,7 @@ struct SherpaOnnxKeywordSpotter { | ||
| 678 | std::unique_ptr<sherpa_onnx::KeywordSpotter> impl; | 678 | std::unique_ptr<sherpa_onnx::KeywordSpotter> impl; |
| 679 | }; | 679 | }; |
| 680 | 680 | ||
| 681 | -SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( | 681 | +const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( |
| 682 | const SherpaOnnxKeywordSpotterConfig *config) { | 682 | const SherpaOnnxKeywordSpotterConfig *config) { |
| 683 | sherpa_onnx::KeywordSpotterConfig spotter_config; | 683 | sherpa_onnx::KeywordSpotterConfig spotter_config; |
| 684 | 684 | ||
| @@ -755,37 +755,42 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( | @@ -755,37 +755,42 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( | ||
| 755 | return spotter; | 755 | return spotter; |
| 756 | } | 756 | } |
| 757 | 757 | ||
| 758 | -void SherpaOnnxDestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) { | 758 | +void SherpaOnnxDestroyKeywordSpotter(const SherpaOnnxKeywordSpotter *spotter) { |
| 759 | delete spotter; | 759 | delete spotter; |
| 760 | } | 760 | } |
| 761 | 761 | ||
| 762 | -SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( | 762 | +const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( |
| 763 | const SherpaOnnxKeywordSpotter *spotter) { | 763 | const SherpaOnnxKeywordSpotter *spotter) { |
| 764 | SherpaOnnxOnlineStream *stream = | 764 | SherpaOnnxOnlineStream *stream = |
| 765 | new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); | 765 | new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); |
| 766 | return stream; | 766 | return stream; |
| 767 | } | 767 | } |
| 768 | 768 | ||
| 769 | -SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords( | 769 | +const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStreamWithKeywords( |
| 770 | const SherpaOnnxKeywordSpotter *spotter, const char *keywords) { | 770 | const SherpaOnnxKeywordSpotter *spotter, const char *keywords) { |
| 771 | SherpaOnnxOnlineStream *stream = | 771 | SherpaOnnxOnlineStream *stream = |
| 772 | new SherpaOnnxOnlineStream(spotter->impl->CreateStream(keywords)); | 772 | new SherpaOnnxOnlineStream(spotter->impl->CreateStream(keywords)); |
| 773 | return stream; | 773 | return stream; |
| 774 | } | 774 | } |
| 775 | 775 | ||
| 776 | -int32_t SherpaOnnxIsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter, | ||
| 777 | - SherpaOnnxOnlineStream *stream) { | 776 | +int32_t SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter, |
| 777 | + const SherpaOnnxOnlineStream *stream) { | ||
| 778 | return spotter->impl->IsReady(stream->impl.get()); | 778 | return spotter->impl->IsReady(stream->impl.get()); |
| 779 | } | 779 | } |
| 780 | 780 | ||
| 781 | -void SherpaOnnxDecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter, | ||
| 782 | - SherpaOnnxOnlineStream *stream) { | ||
| 783 | - return spotter->impl->DecodeStream(stream->impl.get()); | 781 | +void SherpaOnnxDecodeKeywordStream(const SherpaOnnxKeywordSpotter *spotter, |
| 782 | + const SherpaOnnxOnlineStream *stream) { | ||
| 783 | + spotter->impl->DecodeStream(stream->impl.get()); | ||
| 784 | } | 784 | } |
| 785 | 785 | ||
| 786 | -void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter, | ||
| 787 | - SherpaOnnxOnlineStream **streams, | ||
| 788 | - int32_t n) { | 786 | +void SherpaOnnxResetKeywordStream(const SherpaOnnxKeywordSpotter *spotter, |
| 787 | + const SherpaOnnxOnlineStream *stream) { | ||
| 788 | + spotter->impl->Reset(stream->impl.get()); | ||
| 789 | +} | ||
| 790 | + | ||
| 791 | +void SherpaOnnxDecodeMultipleKeywordStreams( | ||
| 792 | + const SherpaOnnxKeywordSpotter *spotter, | ||
| 793 | + const SherpaOnnxOnlineStream **streams, int32_t n) { | ||
| 789 | std::vector<sherpa_onnx::OnlineStream *> ss(n); | 794 | std::vector<sherpa_onnx::OnlineStream *> ss(n); |
| 790 | for (int32_t i = 0; i != n; ++i) { | 795 | for (int32_t i = 0; i != n; ++i) { |
| 791 | ss[i] = streams[i]->impl.get(); | 796 | ss[i] = streams[i]->impl.get(); |
| @@ -794,7 +799,8 @@ void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter, | @@ -794,7 +799,8 @@ void SherpaOnnxDecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter, | ||
| 794 | } | 799 | } |
| 795 | 800 | ||
| 796 | const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult( | 801 | const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult( |
| 797 | - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { | 802 | + const SherpaOnnxKeywordSpotter *spotter, |
| 803 | + const SherpaOnnxOnlineStream *stream) { | ||
| 798 | const sherpa_onnx::KeywordResult &result = | 804 | const sherpa_onnx::KeywordResult &result = |
| 799 | spotter->impl->GetResult(stream->impl.get()); | 805 | spotter->impl->GetResult(stream->impl.get()); |
| 800 | const auto &keyword = result.keyword; | 806 | const auto &keyword = result.keyword; |
| @@ -869,8 +875,9 @@ void SherpaOnnxDestroyKeywordResult(const SherpaOnnxKeywordResult *r) { | @@ -869,8 +875,9 @@ void SherpaOnnxDestroyKeywordResult(const SherpaOnnxKeywordResult *r) { | ||
| 869 | } | 875 | } |
| 870 | } | 876 | } |
| 871 | 877 | ||
| 872 | -const char *SherpaOnnxGetKeywordResultAsJson(SherpaOnnxKeywordSpotter *spotter, | ||
| 873 | - SherpaOnnxOnlineStream *stream) { | 878 | +const char *SherpaOnnxGetKeywordResultAsJson( |
| 879 | + const SherpaOnnxKeywordSpotter *spotter, | ||
| 880 | + const SherpaOnnxOnlineStream *stream) { | ||
| 874 | const sherpa_onnx::KeywordResult &result = | 881 | const sherpa_onnx::KeywordResult &result = |
| 875 | spotter->impl->GetResult(stream->impl.get()); | 882 | spotter->impl->GetResult(stream->impl.get()); |
| 876 | 883 |
| @@ -600,7 +600,7 @@ SHERPA_ONNX_API const char *SherpaOnnxGetOfflineStreamResultAsJson( | @@ -600,7 +600,7 @@ SHERPA_ONNX_API const char *SherpaOnnxGetOfflineStreamResultAsJson( | ||
| 600 | SHERPA_ONNX_API void SherpaOnnxDestroyOfflineStreamResultJson(const char *s); | 600 | SHERPA_ONNX_API void SherpaOnnxDestroyOfflineStreamResultJson(const char *s); |
| 601 | 601 | ||
| 602 | // ============================================================ | 602 | // ============================================================ |
| 603 | -// For Keyword Spot | 603 | +// For Keyword Spotter |
| 604 | // ============================================================ | 604 | // ============================================================ |
| 605 | SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { | 605 | SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { |
| 606 | /// The triggered keyword. | 606 | /// The triggered keyword. |
| @@ -660,21 +660,21 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter | @@ -660,21 +660,21 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter | ||
| 660 | /// @param config Config for the keyword spotter. | 660 | /// @param config Config for the keyword spotter. |
| 661 | /// @return Return a pointer to the spotter. The user has to invoke | 661 | /// @return Return a pointer to the spotter. The user has to invoke |
| 662 | /// SherpaOnnxDestroyKeywordSpotter() to free it to avoid memory leak. | 662 | /// SherpaOnnxDestroyKeywordSpotter() to free it to avoid memory leak. |
| 663 | -SHERPA_ONNX_API SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( | 663 | +SHERPA_ONNX_API const SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( |
| 664 | const SherpaOnnxKeywordSpotterConfig *config); | 664 | const SherpaOnnxKeywordSpotterConfig *config); |
| 665 | 665 | ||
| 666 | /// Free a pointer returned by SherpaOnnxCreateKeywordSpotter() | 666 | /// Free a pointer returned by SherpaOnnxCreateKeywordSpotter() |
| 667 | /// | 667 | /// |
| 668 | /// @param p A pointer returned by SherpaOnnxCreateKeywordSpotter() | 668 | /// @param p A pointer returned by SherpaOnnxCreateKeywordSpotter() |
| 669 | SHERPA_ONNX_API void SherpaOnnxDestroyKeywordSpotter( | 669 | SHERPA_ONNX_API void SherpaOnnxDestroyKeywordSpotter( |
| 670 | - SherpaOnnxKeywordSpotter *spotter); | 670 | + const SherpaOnnxKeywordSpotter *spotter); |
| 671 | 671 | ||
| 672 | /// Create an online stream for accepting wave samples. | 672 | /// Create an online stream for accepting wave samples. |
| 673 | /// | 673 | /// |
| 674 | /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter() | 674 | /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter() |
| 675 | /// @return Return a pointer to an OnlineStream. The user has to invoke | 675 | /// @return Return a pointer to an OnlineStream. The user has to invoke |
| 676 | /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. | 676 | /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. |
| 677 | -SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( | 677 | +SHERPA_ONNX_API const SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( |
| 678 | const SherpaOnnxKeywordSpotter *spotter); | 678 | const SherpaOnnxKeywordSpotter *spotter); |
| 679 | 679 | ||
| 680 | /// Create an online stream for accepting wave samples with the specified hot | 680 | /// Create an online stream for accepting wave samples with the specified hot |
| @@ -684,7 +684,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( | @@ -684,7 +684,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *SherpaOnnxCreateKeywordStream( | ||
| 684 | /// @param keywords A pointer points to the keywords that you set | 684 | /// @param keywords A pointer points to the keywords that you set |
| 685 | /// @return Return a pointer to an OnlineStream. The user has to invoke | 685 | /// @return Return a pointer to an OnlineStream. The user has to invoke |
| 686 | /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. | 686 | /// SherpaOnnxDestroyOnlineStream() to free it to avoid memory leak. |
| 687 | -SHERPA_ONNX_API SherpaOnnxOnlineStream * | 687 | +SHERPA_ONNX_API const SherpaOnnxOnlineStream * |
| 688 | SherpaOnnxCreateKeywordStreamWithKeywords( | 688 | SherpaOnnxCreateKeywordStreamWithKeywords( |
| 689 | const SherpaOnnxKeywordSpotter *spotter, const char *keywords); | 689 | const SherpaOnnxKeywordSpotter *spotter, const char *keywords); |
| 690 | 690 | ||
| @@ -693,15 +693,22 @@ SherpaOnnxCreateKeywordStreamWithKeywords( | @@ -693,15 +693,22 @@ SherpaOnnxCreateKeywordStreamWithKeywords( | ||
| 693 | /// | 693 | /// |
| 694 | /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter | 694 | /// @param spotter A pointer returned by SherpaOnnxCreateKeywordSpotter |
| 695 | /// @param stream A pointer returned by SherpaOnnxCreateKeywordStream | 695 | /// @param stream A pointer returned by SherpaOnnxCreateKeywordStream |
| 696 | -SHERPA_ONNX_API int32_t SherpaOnnxIsKeywordStreamReady( | ||
| 697 | - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); | 696 | +SHERPA_ONNX_API int32_t |
| 697 | +SherpaOnnxIsKeywordStreamReady(const SherpaOnnxKeywordSpotter *spotter, | ||
| 698 | + const SherpaOnnxOnlineStream *stream); | ||
| 698 | 699 | ||
| 699 | /// Call this function to run the neural network model and decoding. | 700 | /// Call this function to run the neural network model and decoding. |
| 700 | // | 701 | // |
| 701 | /// Precondition for this function: SherpaOnnxIsKeywordStreamReady() MUST | 702 | /// Precondition for this function: SherpaOnnxIsKeywordStreamReady() MUST |
| 702 | /// return 1. | 703 | /// return 1. |
| 703 | SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream( | 704 | SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream( |
| 704 | - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); | 705 | + const SherpaOnnxKeywordSpotter *spotter, |
| 706 | + const SherpaOnnxOnlineStream *stream); | ||
| 707 | + | ||
| 708 | +/// Reset the given stream. Please call it right after a keyword is detected. | ||
| 709 | +SHERPA_ONNX_API void SherpaOnnxResetKeywordStream( | ||
| 710 | + const SherpaOnnxKeywordSpotter *spotter, | ||
| 711 | + const SherpaOnnxOnlineStream *stream); | ||
| 705 | 712 | ||
| 706 | /// This function is similar to SherpaOnnxDecodeKeywordStream(). It decodes | 713 | /// This function is similar to SherpaOnnxDecodeKeywordStream(). It decodes |
| 707 | /// multiple OnlineStream in parallel. | 714 | /// multiple OnlineStream in parallel. |
| @@ -714,8 +721,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream( | @@ -714,8 +721,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeKeywordStream( | ||
| 714 | /// SherpaOnnxCreateKeywordStream() | 721 | /// SherpaOnnxCreateKeywordStream() |
| 715 | /// @param n Number of elements in the given streams array. | 722 | /// @param n Number of elements in the given streams array. |
| 716 | SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams( | 723 | SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams( |
| 717 | - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, | ||
| 718 | - int32_t n); | 724 | + const SherpaOnnxKeywordSpotter *spotter, |
| 725 | + const SherpaOnnxOnlineStream **streams, int32_t n); | ||
| 719 | 726 | ||
| 720 | /// Get the decoding results so far for an OnlineStream. | 727 | /// Get the decoding results so far for an OnlineStream. |
| 721 | /// | 728 | /// |
| @@ -725,7 +732,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams( | @@ -725,7 +732,8 @@ SHERPA_ONNX_API void SherpaOnnxDecodeMultipleKeywordStreams( | ||
| 725 | /// SherpaOnnxDestroyKeywordResult() to free the returned pointer to | 732 | /// SherpaOnnxDestroyKeywordResult() to free the returned pointer to |
| 726 | /// avoid memory leak. | 733 | /// avoid memory leak. |
| 727 | SHERPA_ONNX_API const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult( | 734 | SHERPA_ONNX_API const SherpaOnnxKeywordResult *SherpaOnnxGetKeywordResult( |
| 728 | - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); | 735 | + const SherpaOnnxKeywordSpotter *spotter, |
| 736 | + const SherpaOnnxOnlineStream *stream); | ||
| 729 | 737 | ||
| 730 | /// Destroy the pointer returned by SherpaOnnxGetKeywordResult(). | 738 | /// Destroy the pointer returned by SherpaOnnxGetKeywordResult(). |
| 731 | /// | 739 | /// |
| @@ -736,7 +744,8 @@ SHERPA_ONNX_API void SherpaOnnxDestroyKeywordResult( | @@ -736,7 +744,8 @@ SHERPA_ONNX_API void SherpaOnnxDestroyKeywordResult( | ||
| 736 | // the user has to call SherpaOnnxFreeKeywordResultJson() to free the returned | 744 | // the user has to call SherpaOnnxFreeKeywordResultJson() to free the returned |
| 737 | // pointer to avoid memory leak | 745 | // pointer to avoid memory leak |
| 738 | SHERPA_ONNX_API const char *SherpaOnnxGetKeywordResultAsJson( | 746 | SHERPA_ONNX_API const char *SherpaOnnxGetKeywordResultAsJson( |
| 739 | - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); | 747 | + const SherpaOnnxKeywordSpotter *spotter, |
| 748 | + const SherpaOnnxOnlineStream *stream); | ||
| 740 | 749 | ||
| 741 | SHERPA_ONNX_API void SherpaOnnxFreeKeywordResultJson(const char *s); | 750 | SHERPA_ONNX_API void SherpaOnnxFreeKeywordResultJson(const char *s); |
| 742 | 751 |
| @@ -391,4 +391,112 @@ GeneratedAudio OfflineTts::Generate(const std::string &text, | @@ -391,4 +391,112 @@ GeneratedAudio OfflineTts::Generate(const std::string &text, | ||
| 391 | return ans; | 391 | return ans; |
| 392 | } | 392 | } |
| 393 | 393 | ||
| 394 | +KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) { | ||
| 395 | + struct SherpaOnnxKeywordSpotterConfig c; | ||
| 396 | + memset(&c, 0, sizeof(c)); | ||
| 397 | + | ||
| 398 | + c.feat_config.sample_rate = config.feat_config.sample_rate; | ||
| 399 | + | ||
| 400 | + c.model_config.transducer.encoder = | ||
| 401 | + config.model_config.transducer.encoder.c_str(); | ||
| 402 | + c.model_config.transducer.decoder = | ||
| 403 | + config.model_config.transducer.decoder.c_str(); | ||
| 404 | + c.model_config.transducer.joiner = | ||
| 405 | + config.model_config.transducer.joiner.c_str(); | ||
| 406 | + c.feat_config.feature_dim = config.feat_config.feature_dim; | ||
| 407 | + | ||
| 408 | + c.model_config.paraformer.encoder = | ||
| 409 | + config.model_config.paraformer.encoder.c_str(); | ||
| 410 | + c.model_config.paraformer.decoder = | ||
| 411 | + config.model_config.paraformer.decoder.c_str(); | ||
| 412 | + | ||
| 413 | + c.model_config.zipformer2_ctc.model = | ||
| 414 | + config.model_config.zipformer2_ctc.model.c_str(); | ||
| 415 | + | ||
| 416 | + c.model_config.tokens = config.model_config.tokens.c_str(); | ||
| 417 | + c.model_config.num_threads = config.model_config.num_threads; | ||
| 418 | + c.model_config.provider = config.model_config.provider.c_str(); | ||
| 419 | + c.model_config.debug = config.model_config.debug; | ||
| 420 | + c.model_config.model_type = config.model_config.model_type.c_str(); | ||
| 421 | + c.model_config.modeling_unit = config.model_config.modeling_unit.c_str(); | ||
| 422 | + c.model_config.bpe_vocab = config.model_config.bpe_vocab.c_str(); | ||
| 423 | + c.model_config.tokens_buf = config.model_config.tokens_buf.c_str(); | ||
| 424 | + c.model_config.tokens_buf_size = config.model_config.tokens_buf.size(); | ||
| 425 | + | ||
| 426 | + c.max_active_paths = config.max_active_paths; | ||
| 427 | + c.num_trailing_blanks = config.num_trailing_blanks; | ||
| 428 | + c.keywords_score = config.keywords_score; | ||
| 429 | + c.keywords_threshold = config.keywords_threshold; | ||
| 430 | + c.keywords_file = config.keywords_file.c_str(); | ||
| 431 | + | ||
| 432 | + auto p = SherpaOnnxCreateKeywordSpotter(&c); | ||
| 433 | + return KeywordSpotter(p); | ||
| 434 | +} | ||
| 435 | + | ||
| 436 | +KeywordSpotter::KeywordSpotter(const SherpaOnnxKeywordSpotter *p) | ||
| 437 | + : MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter>(p) {} | ||
| 438 | + | ||
| 439 | +void KeywordSpotter::Destroy(const SherpaOnnxKeywordSpotter *p) const { | ||
| 440 | + SherpaOnnxDestroyKeywordSpotter(p); | ||
| 441 | +} | ||
| 442 | + | ||
| 443 | +OnlineStream KeywordSpotter::CreateStream() const { | ||
| 444 | + auto s = SherpaOnnxCreateKeywordStream(p_); | ||
| 445 | + return OnlineStream{s}; | ||
| 446 | +} | ||
| 447 | + | ||
| 448 | +OnlineStream KeywordSpotter::CreateStream(const std::string &keywords) const { | ||
| 449 | + auto s = SherpaOnnxCreateKeywordStreamWithKeywords(p_, keywords.c_str()); | ||
| 450 | + return OnlineStream{s}; | ||
| 451 | +} | ||
| 452 | + | ||
| 453 | +bool KeywordSpotter::IsReady(const OnlineStream *s) const { | ||
| 454 | + return SherpaOnnxIsKeywordStreamReady(p_, s->Get()); | ||
| 455 | +} | ||
| 456 | + | ||
| 457 | +void KeywordSpotter::Decode(const OnlineStream *s) const { | ||
| 458 | + return SherpaOnnxDecodeKeywordStream(p_, s->Get()); | ||
| 459 | +} | ||
| 460 | + | ||
| 461 | +void KeywordSpotter::Decode(const OnlineStream *ss, int32_t n) const { | ||
| 462 | + if (n <= 0) { | ||
| 463 | + return; | ||
| 464 | + } | ||
| 465 | + | ||
| 466 | + std::vector<const SherpaOnnxOnlineStream *> streams(n); | ||
| 467 | +  for (int32_t i = 0; i != n; ++i) { | ||
| 468 | + streams[i] = ss[i].Get(); | ||
| 469 | + } | ||
| 470 | + | ||
| 471 | + SherpaOnnxDecodeMultipleKeywordStreams(p_, streams.data(), n); | ||
| 472 | +} | ||
| 473 | + | ||
| 474 | +KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const { | ||
| 475 | + auto r = SherpaOnnxGetKeywordResult(p_, s->Get()); | ||
| 476 | + | ||
| 477 | + KeywordResult ans; | ||
| 478 | + ans.keyword = r->keyword; | ||
| 479 | + | ||
| 480 | + ans.tokens.resize(r->count); | ||
| 481 | + for (int32_t i = 0; i < r->count; ++i) { | ||
| 482 | + ans.tokens[i] = r->tokens_arr[i]; | ||
| 483 | + } | ||
| 484 | + | ||
| 485 | + if (r->timestamps) { | ||
| 486 | + ans.timestamps.resize(r->count); | ||
| 487 | + std::copy(r->timestamps, r->timestamps + r->count, ans.timestamps.data()); | ||
| 488 | + } | ||
| 489 | + | ||
| 490 | + ans.start_time = r->start_time; | ||
| 491 | + ans.json = r->json; | ||
| 492 | + | ||
| 493 | + SherpaOnnxDestroyKeywordResult(r); | ||
| 494 | + | ||
| 495 | + return ans; | ||
| 496 | +} | ||
| 497 | + | ||
| 498 | +void KeywordSpotter::Reset(const OnlineStream *s) const { | ||
| 499 | + SherpaOnnxResetKeywordStream(p_, s->Get()); | ||
| 500 | +} | ||
| 501 | + | ||
| 394 | } // namespace sherpa_onnx::cxx | 502 | } // namespace sherpa_onnx::cxx |
| @@ -406,6 +406,53 @@ class SHERPA_ONNX_API OfflineTts | @@ -406,6 +406,53 @@ class SHERPA_ONNX_API OfflineTts | ||
| 406 | explicit OfflineTts(const SherpaOnnxOfflineTts *p); | 406 | explicit OfflineTts(const SherpaOnnxOfflineTts *p); |
| 407 | }; | 407 | }; |
| 408 | 408 | ||
| 409 | +// ============================================================ | ||
| 410 | +// For Keyword Spotter | ||
| 411 | +// ============================================================ | ||
| 412 | + | ||
| 413 | +struct KeywordResult { | ||
| 414 | + std::string keyword; | ||
| 415 | + std::vector<std::string> tokens; | ||
| 416 | + std::vector<float> timestamps; | ||
| 417 | + float start_time; | ||
| 418 | + std::string json; | ||
| 419 | +}; | ||
| 420 | + | ||
| 421 | +struct KeywordSpotterConfig { | ||
| 422 | + FeatureConfig feat_config; | ||
| 423 | + OnlineModelConfig model_config; | ||
| 424 | + int32_t max_active_paths = 4; | ||
| 425 | + int32_t num_trailing_blanks = 1; | ||
| 426 | + float keywords_score = 1.0f; | ||
| 427 | + float keywords_threshold = 0.25f; | ||
| 428 | + std::string keywords_file; | ||
| 429 | +}; | ||
| 430 | + | ||
| 431 | +class SHERPA_ONNX_API KeywordSpotter | ||
| 432 | + : public MoveOnly<KeywordSpotter, SherpaOnnxKeywordSpotter> { | ||
| 433 | + public: | ||
| 434 | + static KeywordSpotter Create(const KeywordSpotterConfig &config); | ||
| 435 | + | ||
| 436 | + void Destroy(const SherpaOnnxKeywordSpotter *p) const; | ||
| 437 | + | ||
| 438 | + OnlineStream CreateStream() const; | ||
| 439 | + | ||
| 440 | + OnlineStream CreateStream(const std::string &keywords) const; | ||
| 441 | + | ||
| 442 | + bool IsReady(const OnlineStream *s) const; | ||
| 443 | + | ||
| 444 | + void Decode(const OnlineStream *s) const; | ||
| 445 | + | ||
| 446 | + void Decode(const OnlineStream *ss, int32_t n) const; | ||
| 447 | + | ||
| 448 | + void Reset(const OnlineStream *s) const; | ||
| 449 | + | ||
| 450 | + KeywordResult GetResult(const OnlineStream *s) const; | ||
| 451 | + | ||
| 452 | + private: | ||
| 453 | + explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p); | ||
| 454 | +}; | ||
| 455 | + | ||
| 409 | } // namespace sherpa_onnx::cxx | 456 | } // namespace sherpa_onnx::cxx |
| 410 | 457 | ||
| 411 | #endif // SHERPA_ONNX_C_API_CXX_API_H_ | 458 | #endif // SHERPA_ONNX_C_API_CXX_API_H_ |
| @@ -38,6 +38,8 @@ class KeywordSpotterImpl { | @@ -38,6 +38,8 @@ class KeywordSpotterImpl { | ||
| 38 | 38 | ||
| 39 | virtual bool IsReady(OnlineStream *s) const = 0; | 39 | virtual bool IsReady(OnlineStream *s) const = 0; |
| 40 | 40 | ||
| 41 | + virtual void Reset(OnlineStream *s) const = 0; | ||
| 42 | + | ||
| 41 | virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0; | 43 | virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0; |
| 42 | 44 | ||
| 43 | virtual KeywordResult GetResult(OnlineStream *s) const = 0; | 45 | virtual KeywordResult GetResult(OnlineStream *s) const = 0; |
| @@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | @@ -195,8 +195,24 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | ||
| 195 | return s->GetNumProcessedFrames() + model_->ChunkSize() < | 195 | return s->GetNumProcessedFrames() + model_->ChunkSize() < |
| 196 | s->NumFramesReady(); | 196 | s->NumFramesReady(); |
| 197 | } | 197 | } |
| 198 | + void Reset(OnlineStream *s) const override { InitOnlineStream(s); } | ||
| 198 | 199 | ||
| 199 | void DecodeStreams(OnlineStream **ss, int32_t n) const override { | 200 | void DecodeStreams(OnlineStream **ss, int32_t n) const override { |
| 201 | + for (int32_t i = 0; i < n; ++i) { | ||
| 202 | + auto s = ss[i]; | ||
| 203 | + auto r = s->GetKeywordResult(true); | ||
| 204 | + int32_t num_trailing_blanks = r.num_trailing_blanks; | ||
| 205 | + // assume subsampling_factor is 4 | ||
| 206 | + // assume frameshift is 0.01 second | ||
| 207 | +      float trailing_silence = num_trailing_blanks * 4 * 0.01; | ||
| 208 | + | ||
| 209 | +      // it resets automatically after detecting 1.5 seconds of silence | ||
| 210 | +      float threshold = 1.5; | ||
| 211 | +      if (trailing_silence > threshold) { | ||
| 212 | + Reset(s); | ||
| 213 | + } | ||
| 214 | + } | ||
| 215 | + | ||
| 200 | int32_t chunk_size = model_->ChunkSize(); | 216 | int32_t chunk_size = model_->ChunkSize(); |
| 201 | int32_t chunk_shift = model_->ChunkShift(); | 217 | int32_t chunk_shift = model_->ChunkShift(); |
| 202 | 218 |
| @@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const { | @@ -157,6 +157,8 @@ bool KeywordSpotter::IsReady(OnlineStream *s) const { | ||
| 157 | return impl_->IsReady(s); | 157 | return impl_->IsReady(s); |
| 158 | } | 158 | } |
| 159 | 159 | ||
| 160 | +void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); } | ||
| 161 | + | ||
| 160 | void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const { | 162 | void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const { |
| 161 | impl_->DecodeStreams(ss, n); | 163 | impl_->DecodeStreams(ss, n); |
| 162 | } | 164 | } |
| @@ -129,6 +129,9 @@ class KeywordSpotter { | @@ -129,6 +129,9 @@ class KeywordSpotter { | ||
| 129 | */ | 129 | */ |
| 130 | bool IsReady(OnlineStream *s) const; | 130 | bool IsReady(OnlineStream *s) const; |
| 131 | 131 | ||
| 132 | + // Remember to call it after detecting a keyword | ||
| 133 | + void Reset(OnlineStream *s) const; | ||
| 134 | + | ||
| 132 | /** Decode a single stream. */ | 135 | /** Decode a single stream. */ |
| 133 | void DecodeStream(OnlineStream *s) const { | 136 | void DecodeStream(OnlineStream *s) const { |
| 134 | OnlineStream *ss[1] = {s}; | 137 | OnlineStream *ss[1] = {s}; |
| @@ -106,13 +106,15 @@ as the device_name. | @@ -106,13 +106,15 @@ as the device_name. | ||
| 106 | 106 | ||
| 107 | while (spotter.IsReady(stream.get())) { | 107 | while (spotter.IsReady(stream.get())) { |
| 108 | spotter.DecodeStream(stream.get()); | 108 | spotter.DecodeStream(stream.get()); |
| 109 | - } | ||
| 110 | 109 | ||
| 111 | const auto r = spotter.GetResult(stream.get()); | 110 | const auto r = spotter.GetResult(stream.get()); |
| 112 | if (!r.keyword.empty()) { | 111 | if (!r.keyword.empty()) { |
| 113 | display.Print(keyword_index, r.AsJsonString()); | 112 | display.Print(keyword_index, r.AsJsonString()); |
| 114 | fflush(stderr); | 113 | fflush(stderr); |
| 115 | keyword_index++; | 114 | keyword_index++; |
| 115 | + | ||
| 116 | + spotter.Reset(stream.get()); | ||
| 117 | + } | ||
| 116 | } | 118 | } |
| 117 | } | 119 | } |
| 118 | 120 |
| @@ -150,13 +150,15 @@ for a list of pre-trained models to download. | @@ -150,13 +150,15 @@ for a list of pre-trained models to download. | ||
| 150 | while (!stop) { | 150 | while (!stop) { |
| 151 | while (spotter.IsReady(s.get())) { | 151 | while (spotter.IsReady(s.get())) { |
| 152 | spotter.DecodeStream(s.get()); | 152 | spotter.DecodeStream(s.get()); |
| 153 | - } | ||
| 154 | 153 | ||
| 155 | const auto r = spotter.GetResult(s.get()); | 154 | const auto r = spotter.GetResult(s.get()); |
| 156 | if (!r.keyword.empty()) { | 155 | if (!r.keyword.empty()) { |
| 157 | display.Print(keyword_index, r.AsJsonString()); | 156 | display.Print(keyword_index, r.AsJsonString()); |
| 158 | fflush(stderr); | 157 | fflush(stderr); |
| 159 | keyword_index++; | 158 | keyword_index++; |
| 159 | + | ||
| 160 | + spotter.Reset(s.get()); | ||
| 161 | + } | ||
| 160 | } | 162 | } |
| 161 | 163 | ||
| 162 | Pa_Sleep(20); // sleep for 20ms | 164 | Pa_Sleep(20); // sleep for 20ms |
| @@ -27,6 +27,10 @@ public class KeywordSpotter { | @@ -27,6 +27,10 @@ public class KeywordSpotter { | ||
| 27 | decode(ptr, s.getPtr()); | 27 | decode(ptr, s.getPtr()); |
| 28 | } | 28 | } |
| 29 | 29 | ||
| 30 | + public void reset(OnlineStream s) { | ||
| 31 | + reset(ptr, s.getPtr()); | ||
| 32 | + } | ||
| 33 | + | ||
| 30 | public boolean isReady(OnlineStream s) { | 34 | public boolean isReady(OnlineStream s) { |
| 31 | return isReady(ptr, s.getPtr()); | 35 | return isReady(ptr, s.getPtr()); |
| 32 | } | 36 | } |
| @@ -60,6 +64,8 @@ public class KeywordSpotter { | @@ -60,6 +64,8 @@ public class KeywordSpotter { | ||
| 60 | 64 | ||
| 61 | private native void decode(long ptr, long streamPtr); | 65 | private native void decode(long ptr, long streamPtr); |
| 62 | 66 | ||
| 67 | + private native void reset(long ptr, long streamPtr); | ||
| 68 | + | ||
| 63 | private native boolean isReady(long ptr, long streamPtr); | 69 | private native boolean isReady(long ptr, long streamPtr); |
| 64 | 70 | ||
| 65 | private native Object[] getResult(long ptr, long streamPtr); | 71 | private native Object[] getResult(long ptr, long streamPtr); |
| @@ -162,6 +162,15 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode( | @@ -162,6 +162,15 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode( | ||
| 162 | } | 162 | } |
| 163 | 163 | ||
| 164 | SHERPA_ONNX_EXTERN_C | 164 | SHERPA_ONNX_EXTERN_C |
| 165 | +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_reset( | ||
| 166 | + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) { | ||
| 167 | + auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr); | ||
| 168 | + auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr); | ||
| 169 | + | ||
| 170 | + kws->Reset(stream); | ||
| 171 | +} | ||
| 172 | + | ||
| 173 | +SHERPA_ONNX_EXTERN_C | ||
| 165 | JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream( | 174 | JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream( |
| 166 | JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) { | 175 | JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) { |
| 167 | auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr); | 176 | auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr); |
| @@ -49,6 +49,7 @@ class KeywordSpotter( | @@ -49,6 +49,7 @@ class KeywordSpotter( | ||
| 49 | } | 49 | } |
| 50 | 50 | ||
| 51 | fun decode(stream: OnlineStream) = decode(ptr, stream.ptr) | 51 | fun decode(stream: OnlineStream) = decode(ptr, stream.ptr) |
| 52 | + fun reset(stream: OnlineStream) = reset(ptr, stream.ptr) | ||
| 52 | fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) | 53 | fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) |
| 53 | fun getResult(stream: OnlineStream): KeywordSpotterResult { | 54 | fun getResult(stream: OnlineStream): KeywordSpotterResult { |
| 54 | val objArray = getResult(ptr, stream.ptr) | 55 | val objArray = getResult(ptr, stream.ptr) |
| @@ -74,6 +75,7 @@ class KeywordSpotter( | @@ -74,6 +75,7 @@ class KeywordSpotter( | ||
| 74 | private external fun createStream(ptr: Long, keywords: String): Long | 75 | private external fun createStream(ptr: Long, keywords: String): Long |
| 75 | private external fun isReady(ptr: Long, streamPtr: Long): Boolean | 76 | private external fun isReady(ptr: Long, streamPtr: Long): Boolean |
| 76 | private external fun decode(ptr: Long, streamPtr: Long) | 77 | private external fun decode(ptr: Long, streamPtr: Long) |
| 78 | + private external fun reset(ptr: Long, streamPtr: Long) | ||
| 77 | private external fun getResult(ptr: Long, streamPtr: Long): Array<Any> | 79 | private external fun getResult(ptr: Long, streamPtr: Long): Array<Any> |
| 78 | 80 | ||
| 79 | companion object { | 81 | companion object { |
| @@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) { | @@ -67,6 +67,7 @@ void PybindKeywordSpotter(py::module *m) { | ||
| 67 | py::arg("keywords"), py::call_guard<py::gil_scoped_release>()) | 67 | py::arg("keywords"), py::call_guard<py::gil_scoped_release>()) |
| 68 | .def("is_ready", &PyClass::IsReady, | 68 | .def("is_ready", &PyClass::IsReady, |
| 69 | py::call_guard<py::gil_scoped_release>()) | 69 | py::call_guard<py::gil_scoped_release>()) |
| 70 | + .def("reset", &PyClass::Reset, py::call_guard<py::gil_scoped_release>()) | ||
| 70 | .def("decode_stream", &PyClass::DecodeStream, | 71 | .def("decode_stream", &PyClass::DecodeStream, |
| 71 | py::call_guard<py::gil_scoped_release>()) | 72 | py::call_guard<py::gil_scoped_release>()) |
| 72 | .def( | 73 | .def( |
| @@ -105,7 +105,7 @@ class KeywordSpotter(object): | @@ -105,7 +105,7 @@ class KeywordSpotter(object): | ||
| 105 | 105 | ||
| 106 | provider_config = ProviderConfig( | 106 | provider_config = ProviderConfig( |
| 107 | provider=provider, | 107 | provider=provider, |
| 108 | - device = device, | 108 | + device=device, |
| 109 | ) | 109 | ) |
| 110 | 110 | ||
| 111 | model_config = OnlineModelConfig( | 111 | model_config = OnlineModelConfig( |
| @@ -131,6 +131,9 @@ class KeywordSpotter(object): | @@ -131,6 +131,9 @@ class KeywordSpotter(object): | ||
| 131 | ) | 131 | ) |
| 132 | self.keyword_spotter = _KeywordSpotter(keywords_spotter_config) | 132 | self.keyword_spotter = _KeywordSpotter(keywords_spotter_config) |
| 133 | 133 | ||
| 134 | + def reset_stream(self, s: OnlineStream): | ||
| 135 | + self.keyword_spotter.reset(s) | ||
| 136 | + | ||
| 134 | def create_stream(self, keywords: Optional[str] = None): | 137 | def create_stream(self, keywords: Optional[str] = None): |
| 135 | if keywords is None: | 138 | if keywords is None: |
| 136 | return self.keyword_spotter.create_stream() | 139 | return self.keyword_spotter.create_stream() |
| @@ -98,6 +98,9 @@ class TestKeywordSpotter(unittest.TestCase): | @@ -98,6 +98,9 @@ class TestKeywordSpotter(unittest.TestCase): | ||
| 98 | if r: | 98 | if r: |
| 99 | print(f"{r} is detected.") | 99 | print(f"{r} is detected.") |
| 100 | results[i] += f"{r}/" | 100 | results[i] += f"{r}/" |
| 101 | + | ||
| 102 | + keyword_spotter.reset_stream(s) | ||
| 103 | + | ||
| 101 | if len(ready_list) == 0: | 104 | if len(ready_list) == 0: |
| 102 | break | 105 | break |
| 103 | keyword_spotter.decode_streams(ready_list) | 106 | keyword_spotter.decode_streams(ready_list) |
| @@ -158,6 +161,9 @@ class TestKeywordSpotter(unittest.TestCase): | @@ -158,6 +161,9 @@ class TestKeywordSpotter(unittest.TestCase): | ||
| 158 | if r: | 161 | if r: |
| 159 | print(f"{r} is detected.") | 162 | print(f"{r} is detected.") |
| 160 | results[i] += f"{r}/" | 163 | results[i] += f"{r}/" |
| 164 | + | ||
| 165 | + keyword_spotter.reset_stream(s) | ||
| 166 | + | ||
| 161 | if len(ready_list) == 0: | 167 | if len(ready_list) == 0: |
| 162 | break | 168 | break |
| 163 | keyword_spotter.decode_streams(ready_list) | 169 | keyword_spotter.decode_streams(ready_list) |
| @@ -1076,6 +1076,10 @@ class SherpaOnnxKeywordSpotterWrapper { | @@ -1076,6 +1076,10 @@ class SherpaOnnxKeywordSpotterWrapper { | ||
| 1076 | SherpaOnnxDecodeKeywordStream(spotter, stream) | 1076 | SherpaOnnxDecodeKeywordStream(spotter, stream) |
| 1077 | } | 1077 | } |
| 1078 | 1078 | ||
| 1079 | + func reset() { | ||
| 1080 | + SherpaOnnxResetKeywordStream(spotter, stream) | ||
| 1081 | + } | ||
| 1082 | + | ||
| 1079 | func getResult() -> SherpaOnnxKeywordResultWrapper { | 1083 | func getResult() -> SherpaOnnxKeywordResultWrapper { |
| 1080 | let result: UnsafePointer<SherpaOnnxKeywordResult>? = SherpaOnnxGetKeywordResult( | 1084 | let result: UnsafePointer<SherpaOnnxKeywordResult>? = SherpaOnnxGetKeywordResult( |
| 1081 | spotter, stream) | 1085 | spotter, stream) |
| @@ -70,6 +70,9 @@ func run() { | @@ -70,6 +70,9 @@ func run() { | ||
| 70 | spotter.decode() | 70 | spotter.decode() |
| 71 | let keyword = spotter.getResult().keyword | 71 | let keyword = spotter.getResult().keyword |
| 72 | if keyword != "" { | 72 | if keyword != "" { |
| 73 | + // Remember to call reset() right after detecting a keyword | ||
| 74 | + spotter.reset() | ||
| 75 | + | ||
| 73 | print("Detected: \(keyword)") | 76 | print("Detected: \(keyword)") |
| 74 | } | 77 | } |
| 75 | } | 78 | } |
| @@ -17,6 +17,7 @@ set(exported_functions | @@ -17,6 +17,7 @@ set(exported_functions | ||
| 17 | SherpaOnnxIsKeywordStreamReady | 17 | SherpaOnnxIsKeywordStreamReady |
| 18 | SherpaOnnxOnlineStreamAcceptWaveform | 18 | SherpaOnnxOnlineStreamAcceptWaveform |
| 19 | SherpaOnnxOnlineStreamInputFinished | 19 | SherpaOnnxOnlineStreamInputFinished |
| 20 | + SherpaOnnxResetKeywordStream | ||
| 20 | ) | 21 | ) |
| 21 | set(mangled_exported_functions) | 22 | set(mangled_exported_functions) |
| 22 | foreach(x IN LISTS exported_functions) | 23 | foreach(x IN LISTS exported_functions) |
| @@ -102,8 +102,6 @@ if (navigator.mediaDevices.getUserMedia) { | @@ -102,8 +102,6 @@ if (navigator.mediaDevices.getUserMedia) { | ||
| 102 | recognizer_stream.acceptWaveform(expectedSampleRate, samples); | 102 | recognizer_stream.acceptWaveform(expectedSampleRate, samples); |
| 103 | while (recognizer.isReady(recognizer_stream)) { | 103 | while (recognizer.isReady(recognizer_stream)) { |
| 104 | recognizer.decode(recognizer_stream); | 104 | recognizer.decode(recognizer_stream); |
| 105 | - } | ||
| 106 | - | ||
| 107 | 105 | ||
| 108 | let result = recognizer.getResult(recognizer_stream); | 106 | let result = recognizer.getResult(recognizer_stream); |
| 109 | 107 | ||
| @@ -111,6 +109,10 @@ if (navigator.mediaDevices.getUserMedia) { | @@ -111,6 +109,10 @@ if (navigator.mediaDevices.getUserMedia) { | ||
| 111 | console.log(result) | 109 | console.log(result) |
| 112 | lastResult = result; | 110 | lastResult = result; |
| 113 | resultList.push(JSON.stringify(result)); | 111 | resultList.push(JSON.stringify(result)); |
| 112 | + | ||
| 113 | + // remember to reset the stream right after detecting a keyword | ||
| 114 | + recognizer.reset(recognizer_stream); | ||
| 115 | + } | ||
| 114 | } | 116 | } |
| 115 | 117 | ||
| 116 | 118 |
| @@ -296,8 +296,11 @@ class Kws { | @@ -296,8 +296,11 @@ class Kws { | ||
| 296 | } | 296 | } |
| 297 | 297 | ||
| 298 | decode(stream) { | 298 | decode(stream) { |
| 299 | - return this.Module._SherpaOnnxDecodeKeywordStream( | ||
| 300 | - this.handle, stream.handle); | 299 | + this.Module._SherpaOnnxDecodeKeywordStream(this.handle, stream.handle); |
| 300 | + } | ||
| 301 | + | ||
| 302 | + reset(stream) { | ||
| 303 | + this.Module._SherpaOnnxResetKeywordStream(this.handle, stream.handle); | ||
| 301 | } | 304 | } |
| 302 | 305 | ||
| 303 | getResult(stream) { | 306 | getResult(stream) { |
-
请 注册 或 登录 后发表评论