Committed by
GitHub
Allow more online models to load tokens file from the memory (#1352)
Co-authored-by: xiao <shawl336@6163.com>
正在显示
15 个修改的文件
包含
735 行增加
和
15 行删除
| @@ -120,3 +120,99 @@ jobs: | @@ -120,3 +120,99 @@ jobs: | ||
| 120 | ./streaming-zipformer-buffered-tokens-hotwords-c-api | 120 | ./streaming-zipformer-buffered-tokens-hotwords-c-api |
| 121 | 121 | ||
| 122 | rm -rf sherpa-onnx-streaming-zipformer-* | 122 | rm -rf sherpa-onnx-streaming-zipformer-* |
| 123 | + | ||
| 124 | + - name: Test streaming paraformer with tokens loaded from buffers | ||
| 125 | + shell: bash | ||
| 126 | + run: | | ||
| 127 | + gcc -o streaming-paraformer-buffered-tokens-c-api ./c-api-examples/streaming-paraformer-buffered-tokens-c-api.c \ | ||
| 128 | + -I ./build/install/include \ | ||
| 129 | + -L ./build/install/lib/ \ | ||
| 130 | + -l sherpa-onnx-c-api \ | ||
| 131 | + -l onnxruntime | ||
| 132 | + | ||
| 133 | + ls -lh streaming-paraformer-buffered-tokens-c-api | ||
| 134 | + | ||
| 135 | + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then | ||
| 136 | + ldd ./streaming-paraformer-buffered-tokens-c-api | ||
| 137 | + echo "----" | ||
| 138 | + readelf -d ./streaming-paraformer-buffered-tokens-c-api | ||
| 139 | + fi | ||
| 140 | + | ||
| 141 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | ||
| 142 | + tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | ||
| 143 | + rm sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | ||
| 144 | + | ||
| 145 | + ls -lh sherpa-onnx-streaming-paraformer-bilingual-zh-en | ||
| 146 | + echo "---" | ||
| 147 | + ls -lh sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs | ||
| 148 | + | ||
| 149 | + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH | ||
| 150 | + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH | ||
| 151 | + | ||
| 152 | + ./streaming-paraformer-buffered-tokens-c-api | ||
| 153 | + | ||
| 154 | + rm -rf sherpa-onnx-streaming-paraformer-* | ||
| 155 | + | ||
| 156 | + - name: Test streaming ctc with tokens loaded from buffers | ||
| 157 | + shell: bash | ||
| 158 | + run: | | ||
| 159 | + gcc -o streaming-ctc-buffered-tokens-c-api ./c-api-examples/streaming-ctc-buffered-tokens-c-api.c \ | ||
| 160 | + -I ./build/install/include \ | ||
| 161 | + -L ./build/install/lib/ \ | ||
| 162 | + -l sherpa-onnx-c-api \ | ||
| 163 | + -l onnxruntime | ||
| 164 | + | ||
| 165 | + ls -lh streaming-ctc-buffered-tokens-c-api | ||
| 166 | + | ||
| 167 | + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then | ||
| 168 | + ldd ./streaming-ctc-buffered-tokens-c-api | ||
| 169 | + echo "----" | ||
| 170 | + readelf -d ./streaming-ctc-buffered-tokens-c-api | ||
| 171 | + fi | ||
| 172 | + | ||
| 173 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 174 | + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 175 | + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 176 | + | ||
| 177 | + ls -lh sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 | ||
| 178 | + echo "---" | ||
| 179 | + ls -lh sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs | ||
| 180 | + | ||
| 181 | + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH | ||
| 182 | + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH | ||
| 183 | + | ||
| 184 | + ./streaming-ctc-buffered-tokens-c-api | ||
| 185 | + | ||
| 186 | + rm -rf sherpa-onnx-streaming-ctc-* | ||
| 187 | + | ||
| 188 | + - name: Test keywords spotting with tokens and keywords loaded from buffers | ||
| 189 | + shell: bash | ||
| 190 | + run: | | ||
| 191 | + gcc -o keywords-spotter-buffered-tokens-keywords-c-api ./c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c \ | ||
| 192 | + -I ./build/install/include \ | ||
| 193 | + -L ./build/install/lib/ \ | ||
| 194 | + -l sherpa-onnx-c-api \ | ||
| 195 | + -l onnxruntime | ||
| 196 | + | ||
| 197 | + ls -lh keywords-spotter-buffered-tokens-keywords-c-api | ||
| 198 | + | ||
| 199 | + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then | ||
| 200 | + ldd ./keywords-spotter-buffered-tokens-keywords-c-api | ||
| 201 | + echo "----" | ||
| 202 | + readelf -d ./keywords-spotter-buffered-tokens-keywords-c-api | ||
| 203 | + fi | ||
| 204 | + | ||
| 205 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 206 | + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 207 | + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 208 | + | ||
| 209 | + ls -lh sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile | ||
| 210 | + echo "---" | ||
| 211 | + ls -lh sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/test_wavs | ||
| 212 | + | ||
| 213 | + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH | ||
| 214 | + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH | ||
| 215 | + | ||
| 216 | + ./keywords-spotter-buffered-tokens-keywords-c-api | ||
| 217 | + | ||
| 218 | + rm -rf sherpa-onnx-kws-zipformer-* |
| @@ -52,6 +52,18 @@ add_executable(streaming-zipformer-buffered-tokens-hotwords-c-api | @@ -52,6 +52,18 @@ add_executable(streaming-zipformer-buffered-tokens-hotwords-c-api | ||
| 52 | streaming-zipformer-buffered-tokens-hotwords-c-api.c) | 52 | streaming-zipformer-buffered-tokens-hotwords-c-api.c) |
| 53 | target_link_libraries(streaming-zipformer-buffered-tokens-hotwords-c-api sherpa-onnx-c-api) | 53 | target_link_libraries(streaming-zipformer-buffered-tokens-hotwords-c-api sherpa-onnx-c-api) |
| 54 | 54 | ||
| 55 | +add_executable(streaming-paraformer-buffered-tokens-c-api | ||
| 56 | + streaming-paraformer-buffered-tokens-c-api.c) | ||
| 57 | +target_link_libraries(streaming-paraformer-buffered-tokens-c-api sherpa-onnx-c-api) | ||
| 58 | + | ||
| 59 | +add_executable(streaming-ctc-buffered-tokens-c-api | ||
| 60 | + streaming-ctc-buffered-tokens-c-api.c) | ||
| 61 | +target_link_libraries(streaming-ctc-buffered-tokens-c-api sherpa-onnx-c-api) | ||
| 62 | + | ||
| 63 | +add_executable(keywords-spotter-buffered-tokens-keywords-c-api | ||
| 64 | + keywords-spotter-buffered-tokens-keywords-c-api.c) | ||
| 65 | +target_link_libraries(keywords-spotter-buffered-tokens-keywords-c-api sherpa-onnx-c-api) | ||
| 66 | + | ||
| 55 | if(SHERPA_ONNX_HAS_ALSA) | 67 | if(SHERPA_ONNX_HAS_ALSA) |
| 56 | add_subdirectory(./asr-microphone-example) | 68 | add_subdirectory(./asr-microphone-example) |
| 57 | elseif((UNIX AND NOT APPLE) OR LINUX) | 69 | elseif((UNIX AND NOT APPLE) OR LINUX) |
| 1 | +// c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +// Copyright (c) 2024 Luo Xiao | ||
| 5 | + | ||
| 6 | +// | ||
| 7 | +// This file demonstrates how to use keywords spotter with sherpa-onnx's C | ||
| 8 | +// API and with tokens and keywords loaded from buffered strings instead of from | ||
| 9 | +// external files API. | ||
| 10 | +// clang-format off | ||
| 11 | +// | ||
| 12 | +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 13 | +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 14 | +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 | ||
| 15 | +// | ||
| 16 | +// clang-format on | ||
| 17 | + | ||
| 18 | +#include <stdio.h> | ||
| 19 | +#include <stdlib.h> | ||
| 20 | +#include <string.h> | ||
| 21 | + | ||
| 22 | +#include "sherpa-onnx/c-api/c-api.h" | ||
| 23 | + | ||
| 24 | +static size_t ReadFile(const char *filename, const char **buffer_out) { | ||
| 25 | + FILE *file = fopen(filename, "r"); | ||
| 26 | + if (file == NULL) { | ||
| 27 | + fprintf(stderr, "Failed to open %s\n", filename); | ||
| 28 | + return -1; | ||
| 29 | + } | ||
| 30 | + fseek(file, 0L, SEEK_END); | ||
| 31 | + long size = ftell(file); | ||
| 32 | + rewind(file); | ||
| 33 | + *buffer_out = malloc(size); | ||
| 34 | + if (*buffer_out == NULL) { | ||
| 35 | + fclose(file); | ||
| 36 | + fprintf(stderr, "Memory error\n"); | ||
| 37 | + return -1; | ||
| 38 | + } | ||
| 39 | + size_t read_bytes = fread(*buffer_out, 1, size, file); | ||
| 40 | + if (read_bytes != size) { | ||
| 41 | + printf("Errors occured in reading the file %s\n", filename); | ||
| 42 | + free((void *)*buffer_out); | ||
| 43 | + *buffer_out = NULL; | ||
| 44 | + fclose(file); | ||
| 45 | + return -1; | ||
| 46 | + } | ||
| 47 | + fclose(file); | ||
| 48 | + return read_bytes; | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +int32_t main() { | ||
| 52 | + const char *wav_filename = | ||
| 53 | + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/test_wavs/" | ||
| 54 | + "6.wav"; | ||
| 55 | + const char *encoder_filename = | ||
| 56 | + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" | ||
| 57 | + "encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"; | ||
| 58 | + const char *decoder_filename = | ||
| 59 | + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" | ||
| 60 | + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx"; | ||
| 61 | + const char *joiner_filename = | ||
| 62 | + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" | ||
| 63 | + "joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"; | ||
| 64 | + const char *provider = "cpu"; | ||
| 65 | + const char *tokens_filename = | ||
| 66 | + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/tokens.txt"; | ||
| 67 | + const char *keywords_filename = | ||
| 68 | + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/test_wavs/" | ||
| 69 | + "test_keywords.txt"; | ||
| 70 | + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename); | ||
| 71 | + if (wave == NULL) { | ||
| 72 | + fprintf(stderr, "Failed to read %s\n", wav_filename); | ||
| 73 | + return -1; | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + // reading tokens and keywords to buffers | ||
| 77 | + const char *tokens_buf; | ||
| 78 | + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); | ||
| 79 | + if (token_buf_size < 1) { | ||
| 80 | + fprintf(stderr, "Please check your tokens.txt!\n"); | ||
| 81 | + free((void *)tokens_buf); | ||
| 82 | + return -1; | ||
| 83 | + } | ||
| 84 | + const char *keywords_buf; | ||
| 85 | + size_t keywords_buf_size = ReadFile(keywords_filename, &keywords_buf); | ||
| 86 | + if (keywords_buf_size < 1) { | ||
| 87 | + fprintf(stderr, "Please check your keywords.txt!\n"); | ||
| 88 | + free((void *)keywords_buf); | ||
| 89 | + return -1; | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + // Zipformer config | ||
| 93 | + SherpaOnnxOnlineTransducerModelConfig zipformer_config; | ||
| 94 | + memset(&zipformer_config, 0, sizeof(zipformer_config)); | ||
| 95 | + zipformer_config.encoder = encoder_filename; | ||
| 96 | + zipformer_config.decoder = decoder_filename; | ||
| 97 | + zipformer_config.joiner = joiner_filename; | ||
| 98 | + | ||
| 99 | + // Online model config | ||
| 100 | + SherpaOnnxOnlineModelConfig online_model_config; | ||
| 101 | + memset(&online_model_config, 0, sizeof(online_model_config)); | ||
| 102 | + online_model_config.debug = 1; | ||
| 103 | + online_model_config.num_threads = 1; | ||
| 104 | + online_model_config.provider = provider; | ||
| 105 | + online_model_config.tokens_buf = tokens_buf; | ||
| 106 | + online_model_config.tokens_buf_size = token_buf_size; | ||
| 107 | + online_model_config.transducer = zipformer_config; | ||
| 108 | + | ||
| 109 | + // Keywords-spotter config | ||
| 110 | + SherpaOnnxKeywordSpotterConfig keywords_spotter_config; | ||
| 111 | + memset(&keywords_spotter_config, 0, sizeof(keywords_spotter_config)); | ||
| 112 | + keywords_spotter_config.max_active_paths = 4; | ||
| 113 | + keywords_spotter_config.keywords_threshold = 0.1; | ||
| 114 | + keywords_spotter_config.keywords_score = 3.0; | ||
| 115 | + keywords_spotter_config.model_config = online_model_config; | ||
| 116 | + keywords_spotter_config.keywords_buf = keywords_buf; | ||
| 117 | + keywords_spotter_config.keywords_buf_size = keywords_buf_size; | ||
| 118 | + | ||
| 119 | + SherpaOnnxKeywordSpotter *keywords_spotter = | ||
| 120 | + SherpaOnnxCreateKeywordSpotter(&keywords_spotter_config); | ||
| 121 | + | ||
| 122 | + free((void *)tokens_buf); | ||
| 123 | + tokens_buf = NULL; | ||
| 124 | + free((void *)keywords_buf); | ||
| 125 | + keywords_buf = NULL; | ||
| 126 | + | ||
| 127 | + if (keywords_spotter == NULL) { | ||
| 128 | + fprintf(stderr, "Please check your config!\n"); | ||
| 129 | + SherpaOnnxFreeWave(wave); | ||
| 130 | + return -1; | ||
| 131 | + } | ||
| 132 | + | ||
| 133 | + SherpaOnnxOnlineStream *stream = | ||
| 134 | + SherpaOnnxCreateKeywordStream(keywords_spotter); | ||
| 135 | + | ||
| 136 | + const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); | ||
| 137 | + int32_t segment_id = 0; | ||
| 138 | + | ||
| 139 | +// simulate streaming. You can choose an arbitrary N | ||
| 140 | +#define N 3200 | ||
| 141 | + | ||
| 142 | + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", | ||
| 143 | + wave->sample_rate, wave->num_samples, | ||
| 144 | + (float)wave->num_samples / wave->sample_rate); | ||
| 145 | + | ||
| 146 | + int32_t k = 0; | ||
| 147 | + while (k < wave->num_samples) { | ||
| 148 | + int32_t start = k; | ||
| 149 | + int32_t end = | ||
| 150 | + (start + N > wave->num_samples) ? wave->num_samples : (start + N); | ||
| 151 | + k += N; | ||
| 152 | + | ||
| 153 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, | ||
| 154 | + wave->samples + start, end - start); | ||
| 155 | + while (SherpaOnnxIsKeywordStreamReady(keywords_spotter, stream)) { | ||
| 156 | + SherpaOnnxDecodeKeywordStream(keywords_spotter, stream); | ||
| 157 | + } | ||
| 158 | + | ||
| 159 | + const SherpaOnnxKeywordResult *r = | ||
| 160 | + SherpaOnnxGetKeywordResult(keywords_spotter, stream); | ||
| 161 | + | ||
| 162 | + if (strlen(r->keyword)) { | ||
| 163 | + SherpaOnnxPrint(display, segment_id, r->keyword); | ||
| 164 | + } | ||
| 165 | + | ||
| 166 | + SherpaOnnxDestroyKeywordResult(r); | ||
| 167 | + } | ||
| 168 | + | ||
| 169 | + // add some tail padding | ||
| 170 | + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate | ||
| 171 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, | ||
| 172 | + 4800); | ||
| 173 | + | ||
| 174 | + SherpaOnnxFreeWave(wave); | ||
| 175 | + | ||
| 176 | + SherpaOnnxOnlineStreamInputFinished(stream); | ||
| 177 | + while (SherpaOnnxIsKeywordStreamReady(keywords_spotter, stream)) { | ||
| 178 | + SherpaOnnxDecodeKeywordStream(keywords_spotter, stream); | ||
| 179 | + } | ||
| 180 | + | ||
| 181 | + const SherpaOnnxKeywordResult *r = | ||
| 182 | + SherpaOnnxGetKeywordResult(keywords_spotter, stream); | ||
| 183 | + | ||
| 184 | + if (strlen(r->keyword)) { | ||
| 185 | + SherpaOnnxPrint(display, segment_id, r->keyword); | ||
| 186 | + } | ||
| 187 | + | ||
| 188 | + SherpaOnnxDestroyKeywordResult(r); | ||
| 189 | + | ||
| 190 | + SherpaOnnxDestroyDisplay(display); | ||
| 191 | + SherpaOnnxDestroyOnlineStream(stream); | ||
| 192 | + SherpaOnnxDestroyKeywordSpotter(keywords_spotter); | ||
| 193 | + fprintf(stderr, "\n"); | ||
| 194 | + | ||
| 195 | + return 0; | ||
| 196 | +} |
| 1 | +// c-api-examples/streaming-ctc-buffered-tokens-c-api.c | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +// Copyright (c) 2024 Luo Xiao | ||
| 5 | + | ||
| 6 | +// | ||
| 7 | +// This file demonstrates how to use streaming Zipformer2 Ctc with sherpa-onnx's | ||
| 8 | +// C API and with tokens loaded from buffered strings instead of | ||
| 9 | +// from external files API. | ||
| 10 | +// clang-format off | ||
| 11 | +// | ||
| 12 | +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 13 | +// tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 14 | +// rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 | ||
| 15 | +// | ||
| 16 | +// clang-format on | ||
| 17 | + | ||
| 18 | +#include <stdio.h> | ||
| 19 | +#include <stdlib.h> | ||
| 20 | +#include <string.h> | ||
| 21 | + | ||
| 22 | +#include "sherpa-onnx/c-api/c-api.h" | ||
| 23 | + | ||
| 24 | +static size_t ReadFile(const char *filename, const char **buffer_out) { | ||
| 25 | + FILE *file = fopen(filename, "r"); | ||
| 26 | + if (file == NULL) { | ||
| 27 | + fprintf(stderr, "Failed to open %s\n", filename); | ||
| 28 | + return -1; | ||
| 29 | + } | ||
| 30 | + fseek(file, 0L, SEEK_END); | ||
| 31 | + long size = ftell(file); | ||
| 32 | + rewind(file); | ||
| 33 | + *buffer_out = malloc(size); | ||
| 34 | + if (*buffer_out == NULL) { | ||
| 35 | + fclose(file); | ||
| 36 | + fprintf(stderr, "Memory error\n"); | ||
| 37 | + return -1; | ||
| 38 | + } | ||
| 39 | + size_t read_bytes = fread(*buffer_out, 1, size, file); | ||
| 40 | + if (read_bytes != size) { | ||
| 41 | + printf("Errors occured in reading the file %s\n", filename); | ||
| 42 | + free((void *)*buffer_out); | ||
| 43 | + *buffer_out = NULL; | ||
| 44 | + fclose(file); | ||
| 45 | + return -1; | ||
| 46 | + } | ||
| 47 | + fclose(file); | ||
| 48 | + return read_bytes; | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +int32_t main() { | ||
| 52 | + const char *wav_filename = | ||
| 53 | + "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/" | ||
| 54 | + "DEV_T0000000000.wav"; | ||
| 55 | + const char *model_filename = | ||
| 56 | + "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/" | ||
| 57 | + "ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx"; | ||
| 58 | + const char *tokens_filename = | ||
| 59 | + "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt"; | ||
| 60 | + const char *provider = "cpu"; | ||
| 61 | + | ||
| 62 | + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename); | ||
| 63 | + if (wave == NULL) { | ||
| 64 | + fprintf(stderr, "Failed to read %s\n", wav_filename); | ||
| 65 | + return -1; | ||
| 66 | + } | ||
| 67 | + | ||
| 68 | + // reading tokens to buffers | ||
| 69 | + const char *tokens_buf; | ||
| 70 | + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); | ||
| 71 | + if (token_buf_size < 1) { | ||
| 72 | + fprintf(stderr, "Please check your tokens.txt!\n"); | ||
| 73 | + free((void *)tokens_buf); | ||
| 74 | + return -1; | ||
| 75 | + } | ||
| 76 | + | ||
| 77 | + // Zipformer2Ctc config | ||
| 78 | + SherpaOnnxOnlineZipformer2CtcModelConfig zipformer2_ctc_config; | ||
| 79 | + memset(&zipformer2_ctc_config, 0, sizeof(zipformer2_ctc_config)); | ||
| 80 | + zipformer2_ctc_config.model = model_filename; | ||
| 81 | + | ||
| 82 | + // Online model config | ||
| 83 | + SherpaOnnxOnlineModelConfig online_model_config; | ||
| 84 | + memset(&online_model_config, 0, sizeof(online_model_config)); | ||
| 85 | + online_model_config.debug = 1; | ||
| 86 | + online_model_config.num_threads = 1; | ||
| 87 | + online_model_config.provider = provider; | ||
| 88 | + online_model_config.tokens_buf = tokens_buf; | ||
| 89 | + online_model_config.tokens_buf_size = token_buf_size; | ||
| 90 | + online_model_config.zipformer2_ctc = zipformer2_ctc_config; | ||
| 91 | + | ||
| 92 | + // Recognizer config | ||
| 93 | + SherpaOnnxOnlineRecognizerConfig recognizer_config; | ||
| 94 | + memset(&recognizer_config, 0, sizeof(recognizer_config)); | ||
| 95 | + recognizer_config.decoding_method = "greedy_search"; | ||
| 96 | + recognizer_config.model_config = online_model_config; | ||
| 97 | + | ||
| 98 | + SherpaOnnxOnlineRecognizer *recognizer = | ||
| 99 | + SherpaOnnxCreateOnlineRecognizer(&recognizer_config); | ||
| 100 | + | ||
| 101 | + free((void *)tokens_buf); | ||
| 102 | + tokens_buf = NULL; | ||
| 103 | + | ||
| 104 | + if (recognizer == NULL) { | ||
| 105 | + fprintf(stderr, "Please check your config!\n"); | ||
| 106 | + SherpaOnnxFreeWave(wave); | ||
| 107 | + return -1; | ||
| 108 | + } | ||
| 109 | + | ||
| 110 | + SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); | ||
| 111 | + | ||
| 112 | + const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); | ||
| 113 | + int32_t segment_id = 0; | ||
| 114 | + | ||
| 115 | +// simulate streaming. You can choose an arbitrary N | ||
| 116 | +#define N 3200 | ||
| 117 | + | ||
| 118 | + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", | ||
| 119 | + wave->sample_rate, wave->num_samples, | ||
| 120 | + (float)wave->num_samples / wave->sample_rate); | ||
| 121 | + | ||
| 122 | + int32_t k = 0; | ||
| 123 | + while (k < wave->num_samples) { | ||
| 124 | + int32_t start = k; | ||
| 125 | + int32_t end = | ||
| 126 | + (start + N > wave->num_samples) ? wave->num_samples : (start + N); | ||
| 127 | + k += N; | ||
| 128 | + | ||
| 129 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, | ||
| 130 | + wave->samples + start, end - start); | ||
| 131 | + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) { | ||
| 132 | + SherpaOnnxDecodeOnlineStream(recognizer, stream); | ||
| 133 | + } | ||
| 134 | + | ||
| 135 | + const SherpaOnnxOnlineRecognizerResult *r = | ||
| 136 | + SherpaOnnxGetOnlineStreamResult(recognizer, stream); | ||
| 137 | + | ||
| 138 | + if (strlen(r->text)) { | ||
| 139 | + SherpaOnnxPrint(display, segment_id, r->text); | ||
| 140 | + } | ||
| 141 | + | ||
| 142 | + if (SherpaOnnxOnlineStreamIsEndpoint(recognizer, stream)) { | ||
| 143 | + if (strlen(r->text)) { | ||
| 144 | + ++segment_id; | ||
| 145 | + } | ||
| 146 | + SherpaOnnxOnlineStreamReset(recognizer, stream); | ||
| 147 | + } | ||
| 148 | + | ||
| 149 | + SherpaOnnxDestroyOnlineRecognizerResult(r); | ||
| 150 | + } | ||
| 151 | + | ||
| 152 | + // add some tail padding | ||
| 153 | + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate | ||
| 154 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, | ||
| 155 | + 4800); | ||
| 156 | + | ||
| 157 | + SherpaOnnxFreeWave(wave); | ||
| 158 | + | ||
| 159 | + SherpaOnnxOnlineStreamInputFinished(stream); | ||
| 160 | + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) { | ||
| 161 | + SherpaOnnxDecodeOnlineStream(recognizer, stream); | ||
| 162 | + } | ||
| 163 | + | ||
| 164 | + const SherpaOnnxOnlineRecognizerResult *r = | ||
| 165 | + SherpaOnnxGetOnlineStreamResult(recognizer, stream); | ||
| 166 | + | ||
| 167 | + if (strlen(r->text)) { | ||
| 168 | + SherpaOnnxPrint(display, segment_id, r->text); | ||
| 169 | + } | ||
| 170 | + | ||
| 171 | + SherpaOnnxDestroyOnlineRecognizerResult(r); | ||
| 172 | + | ||
| 173 | + SherpaOnnxDestroyDisplay(display); | ||
| 174 | + SherpaOnnxDestroyOnlineStream(stream); | ||
| 175 | + SherpaOnnxDestroyOnlineRecognizer(recognizer); | ||
| 176 | + fprintf(stderr, "\n"); | ||
| 177 | + | ||
| 178 | + return 0; | ||
| 179 | +} |
| 1 | +// c-api-examples/streaming-paraformer-buffered-tokens-c-api.c | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +// Copyright (c) 2024 Luo Xiao | ||
| 5 | + | ||
| 6 | +// | ||
| 7 | +// This file demonstrates how to use streaming Paraformer with sherpa-onnx's C | ||
| 8 | +// API and with tokens loaded from buffered strings instead of from | ||
| 9 | +// external files API. | ||
| 10 | +// clang-format off | ||
| 11 | +// | ||
| 12 | +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | ||
| 13 | +// tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | ||
| 14 | +// rm sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | ||
| 15 | +// | ||
| 16 | +// clang-format on | ||
| 17 | + | ||
| 18 | +#include <stdio.h> | ||
| 19 | +#include <stdlib.h> | ||
| 20 | +#include <string.h> | ||
| 21 | + | ||
| 22 | +#include "sherpa-onnx/c-api/c-api.h" | ||
| 23 | + | ||
| 24 | +static size_t ReadFile(const char *filename, const char **buffer_out) { | ||
| 25 | + FILE *file = fopen(filename, "r"); | ||
| 26 | + if (file == NULL) { | ||
| 27 | + fprintf(stderr, "Failed to open %s\n", filename); | ||
| 28 | + return -1; | ||
| 29 | + } | ||
| 30 | + fseek(file, 0L, SEEK_END); | ||
| 31 | + long size = ftell(file); | ||
| 32 | + rewind(file); | ||
| 33 | + *buffer_out = malloc(size); | ||
| 34 | + if (*buffer_out == NULL) { | ||
| 35 | + fclose(file); | ||
| 36 | + fprintf(stderr, "Memory error\n"); | ||
| 37 | + return -1; | ||
| 38 | + } | ||
| 39 | + size_t read_bytes = fread(*buffer_out, 1, size, file); | ||
| 40 | + if (read_bytes != size) { | ||
| 41 | + printf("Errors occured in reading the file %s\n", filename); | ||
| 42 | + free((void *)*buffer_out); | ||
| 43 | + *buffer_out = NULL; | ||
| 44 | + fclose(file); | ||
| 45 | + return -1; | ||
| 46 | + } | ||
| 47 | + fclose(file); | ||
| 48 | + return read_bytes; | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +int32_t main() { | ||
| 52 | + const char *wav_filename = | ||
| 53 | + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav"; | ||
| 54 | + const char *encoder_filename = | ||
| 55 | + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx"; | ||
| 56 | + const char *decoder_filename = | ||
| 57 | + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx"; | ||
| 58 | + const char *tokens_filename = | ||
| 59 | + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt"; | ||
| 60 | + const char *provider = "cpu"; | ||
| 61 | + | ||
| 62 | + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename); | ||
| 63 | + if (wave == NULL) { | ||
| 64 | + fprintf(stderr, "Failed to read %s\n", wav_filename); | ||
| 65 | + return -1; | ||
| 66 | + } | ||
| 67 | + | ||
| 68 | + // reading tokens to buffers | ||
| 69 | + const char *tokens_buf; | ||
| 70 | + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); | ||
| 71 | + if (token_buf_size < 1) { | ||
| 72 | + fprintf(stderr, "Please check your tokens.txt!\n"); | ||
| 73 | + free((void *)tokens_buf); | ||
| 74 | + return -1; | ||
| 75 | + } | ||
| 76 | + | ||
| 77 | + // Paraformer config | ||
| 78 | + SherpaOnnxOnlineParaformerModelConfig paraformer_config; | ||
| 79 | + memset(¶former_config, 0, sizeof(paraformer_config)); | ||
| 80 | + paraformer_config.encoder = encoder_filename; | ||
| 81 | + paraformer_config.decoder = decoder_filename; | ||
| 82 | + | ||
| 83 | + // Online model config | ||
| 84 | + SherpaOnnxOnlineModelConfig online_model_config; | ||
| 85 | + memset(&online_model_config, 0, sizeof(online_model_config)); | ||
| 86 | + online_model_config.debug = 1; | ||
| 87 | + online_model_config.num_threads = 1; | ||
| 88 | + online_model_config.provider = provider; | ||
| 89 | + online_model_config.tokens_buf = tokens_buf; | ||
| 90 | + online_model_config.tokens_buf_size = token_buf_size; | ||
| 91 | + online_model_config.paraformer = paraformer_config; | ||
| 92 | + | ||
| 93 | + // Recognizer config | ||
| 94 | + SherpaOnnxOnlineRecognizerConfig recognizer_config; | ||
| 95 | + memset(&recognizer_config, 0, sizeof(recognizer_config)); | ||
| 96 | + recognizer_config.decoding_method = "greedy_search"; | ||
| 97 | + recognizer_config.model_config = online_model_config; | ||
| 98 | + | ||
| 99 | + SherpaOnnxOnlineRecognizer *recognizer = | ||
| 100 | + SherpaOnnxCreateOnlineRecognizer(&recognizer_config); | ||
| 101 | + | ||
| 102 | + free((void *)tokens_buf); | ||
| 103 | + tokens_buf = NULL; | ||
| 104 | + | ||
| 105 | + if (recognizer == NULL) { | ||
| 106 | + fprintf(stderr, "Please check your config!\n"); | ||
| 107 | + SherpaOnnxFreeWave(wave); | ||
| 108 | + return -1; | ||
| 109 | + } | ||
| 110 | + | ||
| 111 | + SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); | ||
| 112 | + | ||
| 113 | + const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); | ||
| 114 | + int32_t segment_id = 0; | ||
| 115 | + | ||
| 116 | +// simulate streaming. You can choose an arbitrary N | ||
| 117 | +#define N 3200 | ||
| 118 | + | ||
| 119 | + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", | ||
| 120 | + wave->sample_rate, wave->num_samples, | ||
| 121 | + (float)wave->num_samples / wave->sample_rate); | ||
| 122 | + | ||
| 123 | + int32_t k = 0; | ||
| 124 | + while (k < wave->num_samples) { | ||
| 125 | + int32_t start = k; | ||
| 126 | + int32_t end = | ||
| 127 | + (start + N > wave->num_samples) ? wave->num_samples : (start + N); | ||
| 128 | + k += N; | ||
| 129 | + | ||
| 130 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, | ||
| 131 | + wave->samples + start, end - start); | ||
| 132 | + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) { | ||
| 133 | + SherpaOnnxDecodeOnlineStream(recognizer, stream); | ||
| 134 | + } | ||
| 135 | + | ||
| 136 | + const SherpaOnnxOnlineRecognizerResult *r = | ||
| 137 | + SherpaOnnxGetOnlineStreamResult(recognizer, stream); | ||
| 138 | + | ||
| 139 | + if (strlen(r->text)) { | ||
| 140 | + SherpaOnnxPrint(display, segment_id, r->text); | ||
| 141 | + } | ||
| 142 | + | ||
| 143 | + if (SherpaOnnxOnlineStreamIsEndpoint(recognizer, stream)) { | ||
| 144 | + if (strlen(r->text)) { | ||
| 145 | + ++segment_id; | ||
| 146 | + } | ||
| 147 | + SherpaOnnxOnlineStreamReset(recognizer, stream); | ||
| 148 | + } | ||
| 149 | + | ||
| 150 | + SherpaOnnxDestroyOnlineRecognizerResult(r); | ||
| 151 | + } | ||
| 152 | + | ||
| 153 | + // add some tail padding | ||
| 154 | + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate | ||
| 155 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, | ||
| 156 | + 4800); | ||
| 157 | + | ||
| 158 | + SherpaOnnxFreeWave(wave); | ||
| 159 | + | ||
| 160 | + SherpaOnnxOnlineStreamInputFinished(stream); | ||
| 161 | + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) { | ||
| 162 | + SherpaOnnxDecodeOnlineStream(recognizer, stream); | ||
| 163 | + } | ||
| 164 | + | ||
| 165 | + const SherpaOnnxOnlineRecognizerResult *r = | ||
| 166 | + SherpaOnnxGetOnlineStreamResult(recognizer, stream); | ||
| 167 | + | ||
| 168 | + if (strlen(r->text)) { | ||
| 169 | + SherpaOnnxPrint(display, segment_id, r->text); | ||
| 170 | + } | ||
| 171 | + | ||
| 172 | + SherpaOnnxDestroyOnlineRecognizerResult(r); | ||
| 173 | + | ||
| 174 | + SherpaOnnxDestroyDisplay(display); | ||
| 175 | + SherpaOnnxDestroyOnlineStream(stream); | ||
| 176 | + SherpaOnnxDestroyOnlineRecognizer(recognizer); | ||
| 177 | + fprintf(stderr, "\n"); | ||
| 178 | + | ||
| 179 | + return 0; | ||
| 180 | +} |
| @@ -5,7 +5,7 @@ | @@ -5,7 +5,7 @@ | ||
| 5 | 5 | ||
| 6 | // | 6 | // |
| 7 | // This file demonstrates how to use streaming Zipformer with sherpa-onnx's C | 7 | // This file demonstrates how to use streaming Zipformer with sherpa-onnx's C |
| 8 | -// and with tokens and hotwords loaded from buffered strings instead of from | 8 | +// API and with tokens and hotwords loaded from buffered strings instead of from |
| 9 | // external files API. | 9 | // external files API. |
| 10 | // clang-format off | 10 | // clang-format off |
| 11 | // | 11 | // |
| @@ -667,6 +667,12 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( | @@ -667,6 +667,12 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( | ||
| 667 | 667 | ||
| 668 | spotter_config.model_config.tokens = | 668 | spotter_config.model_config.tokens = |
| 669 | SHERPA_ONNX_OR(config->model_config.tokens, ""); | 669 | SHERPA_ONNX_OR(config->model_config.tokens, ""); |
| 670 | + if (config->model_config.tokens_buf && | ||
| 671 | + config->model_config.tokens_buf_size > 0) { | ||
| 672 | + spotter_config.model_config.tokens_buf = std::string( | ||
| 673 | + config->model_config.tokens_buf, config->model_config.tokens_buf_size); | ||
| 674 | + } | ||
| 675 | + | ||
| 670 | spotter_config.model_config.num_threads = | 676 | spotter_config.model_config.num_threads = |
| 671 | SHERPA_ONNX_OR(config->model_config.num_threads, 1); | 677 | SHERPA_ONNX_OR(config->model_config.num_threads, 1); |
| 672 | spotter_config.model_config.provider_config.provider = | 678 | spotter_config.model_config.provider_config.provider = |
| @@ -691,6 +697,10 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( | @@ -691,6 +697,10 @@ SherpaOnnxKeywordSpotter *SherpaOnnxCreateKeywordSpotter( | ||
| 691 | SHERPA_ONNX_OR(config->keywords_threshold, 0.25); | 697 | SHERPA_ONNX_OR(config->keywords_threshold, 0.25); |
| 692 | 698 | ||
| 693 | spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, ""); | 699 | spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, ""); |
| 700 | + if (config->keywords_buf && config->keywords_buf_size > 0) { | ||
| 701 | + spotter_config.keywords_buf = | ||
| 702 | + std::string(config->keywords_buf, config->keywords_buf_size); | ||
| 703 | + } | ||
| 694 | 704 | ||
| 695 | if (config->model_config.debug) { | 705 | if (config->model_config.debug) { |
| 696 | SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); | 706 | SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); |
| @@ -88,8 +88,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { | @@ -88,8 +88,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { | ||
| 88 | // - cjkchar+bpe | 88 | // - cjkchar+bpe |
| 89 | const char *modeling_unit; | 89 | const char *modeling_unit; |
| 90 | const char *bpe_vocab; | 90 | const char *bpe_vocab; |
| 91 | - /// if non-null, loading the tokens from the buffered string directly in | ||
| 92 | - /// prioriy | 91 | + /// if non-null, loading the tokens from the buffer instead of from the |
| 92 | + /// "tokens" file | ||
| 93 | const char *tokens_buf; | 93 | const char *tokens_buf; |
| 94 | /// byte size excluding the trailing '\0' | 94 | /// byte size excluding the trailing '\0' |
| 95 | int32_t tokens_buf_size; | 95 | int32_t tokens_buf_size; |
| @@ -637,6 +637,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { | @@ -637,6 +637,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { | ||
| 637 | float keywords_score; | 637 | float keywords_score; |
| 638 | float keywords_threshold; | 638 | float keywords_threshold; |
| 639 | const char *keywords_file; | 639 | const char *keywords_file; |
| 640 | + /// if non-null, loading the keywords from the buffer instead of from the | ||
| 641 | + /// keywords_file | ||
| 642 | + const char *keywords_buf; | ||
| 643 | + /// byte size excluding the trailing '\0' | ||
| 644 | + int32_t keywords_buf_size; | ||
| 640 | } SherpaOnnxKeywordSpotterConfig; | 645 | } SherpaOnnxKeywordSpotterConfig; |
| 641 | 646 | ||
| 642 | SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter | 647 | SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter |
| @@ -66,15 +66,25 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | @@ -66,15 +66,25 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | ||
| 66 | public: | 66 | public: |
| 67 | explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config) | 67 | explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config) |
| 68 | : config_(config), | 68 | : config_(config), |
| 69 | - model_(OnlineTransducerModel::Create(config.model_config)), | ||
| 70 | - sym_(config.model_config.tokens) { | 69 | + model_(OnlineTransducerModel::Create(config.model_config)) { |
| 70 | + if (!config.model_config.tokens_buf.empty()) { | ||
| 71 | + sym_ = SymbolTable(config.model_config.tokens_buf, false); | ||
| 72 | + } else { | ||
| 73 | + /// assuming tokens_buf and tokens are guaranteed not being both empty | ||
| 74 | + sym_ = SymbolTable(config.model_config.tokens, true); | ||
| 75 | + } | ||
| 76 | + | ||
| 71 | if (sym_.Contains("<unk>")) { | 77 | if (sym_.Contains("<unk>")) { |
| 72 | unk_id_ = sym_["<unk>"]; | 78 | unk_id_ = sym_["<unk>"]; |
| 73 | } | 79 | } |
| 74 | 80 | ||
| 75 | model_->SetFeatureDim(config.feat_config.feature_dim); | 81 | model_->SetFeatureDim(config.feat_config.feature_dim); |
| 76 | 82 | ||
| 77 | - InitKeywords(); | 83 | + if (config.keywords_buf.empty()) { |
| 84 | + InitKeywords(); | ||
| 85 | + } else { | ||
| 86 | + InitKeywordsFromBufStr(); | ||
| 87 | + } | ||
| 78 | 88 | ||
| 79 | decoder_ = std::make_unique<TransducerKeywordDecoder>( | 89 | decoder_ = std::make_unique<TransducerKeywordDecoder>( |
| 80 | model_.get(), config_.max_active_paths, config_.num_trailing_blanks, | 90 | model_.get(), config_.max_active_paths, config_.num_trailing_blanks, |
| @@ -305,6 +315,12 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | @@ -305,6 +315,12 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | ||
| 305 | } | 315 | } |
| 306 | #endif | 316 | #endif |
| 307 | 317 | ||
| 318 | + void InitKeywordsFromBufStr() { | ||
| 319 | + // keywords_buf's content is supposed to be same as the keywords_file's | ||
| 320 | + std::istringstream is(config_.keywords_buf); | ||
| 321 | + InitKeywords(is); | ||
| 322 | + } | ||
| 323 | + | ||
| 308 | void InitOnlineStream(OnlineStream *stream) const { | 324 | void InitOnlineStream(OnlineStream *stream) const { |
| 309 | auto r = decoder_->GetEmptyResult(); | 325 | auto r = decoder_->GetEmptyResult(); |
| 310 | SHERPA_ONNX_CHECK_EQ(r.hyps.Size(), 1); | 326 | SHERPA_ONNX_CHECK_EQ(r.hyps.Size(), 1); |
| @@ -89,8 +89,17 @@ void KeywordSpotterConfig::Register(ParseOptions *po) { | @@ -89,8 +89,17 @@ void KeywordSpotterConfig::Register(ParseOptions *po) { | ||
| 89 | } | 89 | } |
| 90 | 90 | ||
| 91 | bool KeywordSpotterConfig::Validate() const { | 91 | bool KeywordSpotterConfig::Validate() const { |
| 92 | - if (keywords_file.empty()) { | ||
| 93 | - SHERPA_ONNX_LOGE("Please provide --keywords-file."); | 92 | + if (!keywords_file.empty() && !keywords_buf.empty()) { |
| 93 | + SHERPA_ONNX_LOGE( | ||
| 94 | + "you can not provide a keywords_buf and a keywords file: '%s', " | ||
| 95 | + "at the same time, which is confusing", | ||
| 96 | + keywords_file.c_str()); | ||
| 97 | + return false; | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + if (keywords_file.empty() && keywords_buf.empty()) { | ||
| 101 | + SHERPA_ONNX_LOGE( | ||
| 102 | + "Please provide either a keywords-file or the keywords-buf"); | ||
| 94 | return false; | 103 | return false; |
| 95 | } | 104 | } |
| 96 | 105 | ||
| @@ -99,7 +108,7 @@ bool KeywordSpotterConfig::Validate() const { | @@ -99,7 +108,7 @@ bool KeywordSpotterConfig::Validate() const { | ||
| 99 | // keywords file will be packaged into the sherpa-onnx-wasm-kws-main.data file | 108 | // keywords file will be packaged into the sherpa-onnx-wasm-kws-main.data file |
| 100 | // Solution: take keyword_file variable is directly | 109 | // Solution: take keyword_file variable is directly |
| 101 | // parsed as a string of keywords | 110 | // parsed as a string of keywords |
| 102 | - if (!std::ifstream(keywords_file.c_str()).good()) { | 111 | + if (keywords_buf.empty() && !std::ifstream(keywords_file.c_str()).good()) { |
| 103 | SHERPA_ONNX_LOGE("Keywords file '%s' does not exist.", | 112 | SHERPA_ONNX_LOGE("Keywords file '%s' does not exist.", |
| 104 | keywords_file.c_str()); | 113 | keywords_file.c_str()); |
| 105 | return false; | 114 | return false; |
| @@ -69,6 +69,11 @@ struct KeywordSpotterConfig { | @@ -69,6 +69,11 @@ struct KeywordSpotterConfig { | ||
| 69 | 69 | ||
| 70 | std::string keywords_file; | 70 | std::string keywords_file; |
| 71 | 71 | ||
| 72 | + /// if keywords_buf is non-empty, | ||
| 73 | + /// the keywords will be loaded from the buffer instead of from the | ||
| 74 | + /// "keywrods_file" | ||
| 75 | + std::string keywords_buf; | ||
| 76 | + | ||
| 72 | KeywordSpotterConfig() = default; | 77 | KeywordSpotterConfig() = default; |
| 73 | 78 | ||
| 74 | KeywordSpotterConfig(const FeatureExtractorConfig &feat_config, | 79 | KeywordSpotterConfig(const FeatureExtractorConfig &feat_config, |
| @@ -46,8 +46,8 @@ struct OnlineModelConfig { | @@ -46,8 +46,8 @@ struct OnlineModelConfig { | ||
| 46 | std::string bpe_vocab; | 46 | std::string bpe_vocab; |
| 47 | 47 | ||
| 48 | /// if tokens_buf is non-empty, | 48 | /// if tokens_buf is non-empty, |
| 49 | - /// the tokens will be loaded from the buffered string instead of from the | ||
| 50 | - /// ${tokens} file | 49 | + /// the tokens will be loaded from the buffer instead of from the |
| 50 | + /// "tokens" file | ||
| 51 | std::string tokens_buf; | 51 | std::string tokens_buf; |
| 52 | 52 | ||
| 53 | OnlineModelConfig() = default; | 53 | OnlineModelConfig() = default; |
| @@ -71,8 +71,14 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -71,8 +71,14 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 71 | : OnlineRecognizerImpl(config), | 71 | : OnlineRecognizerImpl(config), |
| 72 | config_(config), | 72 | config_(config), |
| 73 | model_(OnlineCtcModel::Create(config.model_config)), | 73 | model_(OnlineCtcModel::Create(config.model_config)), |
| 74 | - sym_(config.model_config.tokens), | ||
| 75 | endpoint_(config_.endpoint_config) { | 74 | endpoint_(config_.endpoint_config) { |
| 75 | + if (!config.model_config.tokens_buf.empty()) { | ||
| 76 | + sym_ = SymbolTable(config.model_config.tokens_buf, false); | ||
| 77 | + } else { | ||
| 78 | + /// assuming tokens_buf and tokens are guaranteed not being both empty | ||
| 79 | + sym_ = SymbolTable(config.model_config.tokens, true); | ||
| 80 | + } | ||
| 81 | + | ||
| 76 | if (!config.model_config.wenet_ctc.model.empty()) { | 82 | if (!config.model_config.wenet_ctc.model.empty()) { |
| 77 | // WeNet CTC models assume input samples are in the range | 83 | // WeNet CTC models assume input samples are in the range |
| 78 | // [-32768, 32767], so we set normalize_samples to false | 84 | // [-32768, 32767], so we set normalize_samples to false |
| @@ -99,8 +99,14 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | @@ -99,8 +99,14 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | ||
| 99 | : OnlineRecognizerImpl(config), | 99 | : OnlineRecognizerImpl(config), |
| 100 | config_(config), | 100 | config_(config), |
| 101 | model_(config.model_config), | 101 | model_(config.model_config), |
| 102 | - sym_(config.model_config.tokens), | ||
| 103 | endpoint_(config_.endpoint_config) { | 102 | endpoint_(config_.endpoint_config) { |
| 103 | + if (!config.model_config.tokens_buf.empty()) { | ||
| 104 | + sym_ = SymbolTable(config.model_config.tokens_buf, false); | ||
| 105 | + } else { | ||
| 106 | + /// assuming tokens_buf and tokens are guaranteed not being both empty | ||
| 107 | + sym_ = SymbolTable(config.model_config.tokens, true); | ||
| 108 | + } | ||
| 109 | + | ||
| 104 | if (config.decoding_method != "greedy_search") { | 110 | if (config.decoding_method != "greedy_search") { |
| 105 | SHERPA_ONNX_LOGE( | 111 | SHERPA_ONNX_LOGE( |
| 106 | "Unsupported decoding method: %s. Support only greedy_search at " | 112 | "Unsupported decoding method: %s. Support only greedy_search at " |
| @@ -107,8 +107,8 @@ struct OnlineRecognizerConfig { | @@ -107,8 +107,8 @@ struct OnlineRecognizerConfig { | ||
| 107 | std::string rule_fars; | 107 | std::string rule_fars; |
| 108 | 108 | ||
| 109 | /// used only for modified_beam_search, if hotwords_buf is non-empty, | 109 | /// used only for modified_beam_search, if hotwords_buf is non-empty, |
| 110 | - /// the hotwords will be loaded from the buffered string instead of from | ||
| 111 | - /// ${hotwords_file} | 110 | + /// the hotwords will be loaded from the buffered string instead of from the |
| 111 | + /// "hotwords_file" | ||
| 112 | std::string hotwords_buf; | 112 | std::string hotwords_buf; |
| 113 | 113 | ||
| 114 | OnlineRecognizerConfig() = default; | 114 | OnlineRecognizerConfig() = default; |
-
请 注册 或 登录 后发表评论