Committed by
GitHub
re-pull-request: allow tokens and hotwords to be loaded directly from buffered strings (#1339)
Co-authored-by: xiao <shawl336@163.com>
正在显示
12 个修改的文件
包含
414 行增加
和
16 行删除
# Builds the shared sherpa-onnx C API library and runs the C example that
# loads tokens and hotwords from in-memory buffers instead of from files.
name: c-api-test-loading-tokens-hotwords-from-memory

on:
  push:
    branches:
      - master
    tags:
      - 'v[0-9]+.[0-9]+.[0-9]+*'
    paths:
      - '.github/workflows/c-api.yaml'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'
      - 'sherpa-onnx/c-api/*'
      - 'c-api-examples/**'
      - 'ffmpeg-examples/**'
  pull_request:
    branches:
      - master
    paths:
      - '.github/workflows/c-api.yaml'
      - 'CMakeLists.txt'
      - 'cmake/**'
      - 'sherpa-onnx/csrc/*'
      - 'sherpa-onnx/c-api/*'
      - 'c-api-examples/**'
      - 'ffmpeg-examples/**'

  workflow_dispatch:

# With cancel-in-progress, runs that share a group on the same ref cancel one
# another, so the group name is made specific to this workflow.
# NOTE(review): the previous generic "c-api-" prefix looks copied from
# c-api.yaml -- confirm that no other workflow uses this new group name.
concurrency:
  group: c-api-test-loading-tokens-hotwords-from-memory-${{ github.ref }}
  cancel-in-progress: true

jobs:
  c_api:
    name: ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest]

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: ccache
        uses: hendrikmuhs/ccache-action@v1.2
        with:
          key: ${{ matrix.os }}-c-api-shared

      - name: Build sherpa-onnx
        shell: bash
        run: |
          export CMAKE_CXX_COMPILER_LAUNCHER=ccache
          export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
          cmake --version

          mkdir build
          cd build

          cmake \
            -D CMAKE_BUILD_TYPE=Release \
            -D BUILD_SHARED_LIBS=ON \
            -D CMAKE_INSTALL_PREFIX=./install \
            -D SHERPA_ONNX_ENABLE_BINARY=OFF \
            ..

          make -j2 install

          ls -lh install/lib
          ls -lh install/include

          if [[ ${{ matrix.os }} == ubuntu-latest ]]; then
            ldd ./install/lib/libsherpa-onnx-c-api.so
            echo "---"
            readelf -d ./install/lib/libsherpa-onnx-c-api.so
          fi

          if [[ ${{ matrix.os }} == macos-latest ]]; then
            otool -L ./install/lib/libsherpa-onnx-c-api.dylib
          fi

      - name: Test streaming zipformer with tokens and hotwords loaded from buffers
        shell: bash
        run: |
          gcc -o streaming-zipformer-buffered-tokens-hotwords-c-api ./c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c \
            -I ./build/install/include \
            -L ./build/install/lib/ \
            -l sherpa-onnx-c-api \
            -l onnxruntime

          ls -lh streaming-zipformer-buffered-tokens-hotwords-c-api

          if [[ ${{ matrix.os }} == ubuntu-latest ]]; then
            ldd ./streaming-zipformer-buffered-tokens-hotwords-c-api
            echo "----"
            readelf -d ./streaming-zipformer-buffered-tokens-hotwords-c-api
          fi

          curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
          tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
          rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
          # Use the "resolve" endpoint: a "blob" URL returns the HTML file
          # viewer page, not the model file itself.
          curl -SL -O https://huggingface.co/desh2608/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-small/resolve/main/data/lang_bpe_500/bpe.model
          cp bpe.model sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/
          rm bpe.model

          # The example reads hotwords.txt from the model directory, so the
          # file has to be created there rather than in the working directory.
          printf "▁A ▁T ▁P :1.5\n▁A ▁B ▁C :3.0" > sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/hotwords.txt

          ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17
          echo "---"
          ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs

          export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
          export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH

          ./streaming-zipformer-buffered-tokens-hotwords-c-api

          rm -rf sherpa-onnx-streaming-zipformer-*
| @@ -48,6 +48,10 @@ target_link_libraries(telespeech-c-api sherpa-onnx-c-api) | @@ -48,6 +48,10 @@ target_link_libraries(telespeech-c-api sherpa-onnx-c-api) | ||
| 48 | add_executable(vad-sense-voice-c-api vad-sense-voice-c-api.c) | 48 | add_executable(vad-sense-voice-c-api vad-sense-voice-c-api.c) |
| 49 | target_link_libraries(vad-sense-voice-c-api sherpa-onnx-c-api) | 49 | target_link_libraries(vad-sense-voice-c-api sherpa-onnx-c-api) |
| 50 | 50 | ||
| 51 | +add_executable(streaming-zipformer-buffered-tokens-hotwords-c-api | ||
| 52 | + streaming-zipformer-buffered-tokens-hotwords-c-api.c) | ||
| 53 | +target_link_libraries(streaming-zipformer-buffered-tokens-hotwords-c-api sherpa-onnx-c-api) | ||
| 54 | + | ||
| 51 | if(SHERPA_ONNX_HAS_ALSA) | 55 | if(SHERPA_ONNX_HAS_ALSA) |
| 52 | add_subdirectory(./asr-microphone-example) | 56 | add_subdirectory(./asr-microphone-example) |
| 53 | elseif((UNIX AND NOT APPLE) OR LINUX) | 57 | elseif((UNIX AND NOT APPLE) OR LINUX) |
| 1 | +// c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +// Copyright (c) 2024 Luo Xiao | ||
| 5 | + | ||
| 6 | +// | ||
| 7 | +// This file demonstrates how to use streaming Zipformer with sherpa-onnx's C | ||
| 8 | +// API, with tokens and hotwords loaded from in-memory buffers instead of from | ||
| 9 | +// external files. | ||
| 10 | +// clang-format off | ||
| 11 | +// | ||
| 12 | +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 | ||
| 13 | +// tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 | ||
| 14 | +// rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 | ||
| 15 | +// | ||
| 16 | +// clang-format on | ||
| 17 | + | ||
| 18 | +#include <stdio.h> | ||
| 19 | +#include <stdlib.h> | ||
| 20 | +#include <string.h> | ||
| 21 | + | ||
| 22 | +#include "sherpa-onnx/c-api/c-api.h" | ||
| 23 | + | ||
// Reads the whole file `filename` into a newly malloc()'ed buffer.
//
// On success, stores the buffer in *buffer_out and returns the number of
// bytes read (always > 0). The caller owns the buffer and must free() it.
// The buffer is NOT NUL-terminated.
//
// On failure (the file cannot be opened or sized, is empty, memory cannot be
// allocated, or the read comes up short), a message is printed to stderr,
// *buffer_out is set to NULL and 0 is returned. The return type is unsigned,
// so 0 -- not -1 -- is the error value; a caller testing `size < 1` therefore
// detects every failure, and free(*buffer_out) is always safe.
static size_t ReadFile(const char *filename, const char **buffer_out) {
  // Give the caller a well-defined pointer on every failure path.
  *buffer_out = NULL;

  FILE *file = fopen(filename, "rb");
  if (file == NULL) {
    fprintf(stderr, "Failed to open %s\n", filename);
    return 0;
  }

  // Determine the file size; both fseek() and ftell() can fail.
  long size = -1;
  if (fseek(file, 0L, SEEK_END) == 0) {
    size = ftell(file);
  }
  if (size <= 0 || fseek(file, 0L, SEEK_SET) != 0) {
    fprintf(stderr, "Failed to get the size of %s, or it is empty\n",
            filename);
    fclose(file);
    return 0;
  }

  // Read through a non-const pointer; only the finished buffer is published
  // through the const-qualified output parameter.
  char *buffer = (char *)malloc((size_t)size);
  if (buffer == NULL) {
    fclose(file);
    fprintf(stderr, "Memory error\n");
    return 0;
  }

  size_t read_bytes = fread(buffer, 1, (size_t)size, file);
  fclose(file);
  if (read_bytes != (size_t)size) {
    fprintf(stderr, "Errors occured in reading the file %s\n", filename);
    free(buffer);
    return 0;
  }

  *buffer_out = buffer;
  return read_bytes;
}
| 50 | + | ||
| 51 | +int32_t main() { | ||
| 52 | + const char *wav_filename = | ||
| 53 | + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"; | ||
| 54 | + const char *encoder_filename = | ||
| 55 | + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" | ||
| 56 | + "encoder-epoch-99-avg-1.onnx"; | ||
| 57 | + const char *decoder_filename = | ||
| 58 | + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" | ||
| 59 | + "decoder-epoch-99-avg-1.onnx"; | ||
| 60 | + const char *joiner_filename = | ||
| 61 | + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" | ||
| 62 | + "joiner-epoch-99-avg-1.onnx"; | ||
| 63 | + const char *provider = "cpu"; | ||
| 64 | + const char *modeling_unit = "bpe"; | ||
| 65 | + const char *tokens_filename = | ||
| 66 | + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt"; | ||
| 67 | + const char *hotwords_filename = | ||
| 68 | + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/hotwords.txt"; | ||
| 69 | + const char *bpe_vocab = | ||
| 70 | + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" | ||
| 71 | + "bpe.vocab"; | ||
| 72 | + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename); | ||
| 73 | + if (wave == NULL) { | ||
| 74 | + fprintf(stderr, "Failed to read %s\n", wav_filename); | ||
| 75 | + return -1; | ||
| 76 | + } | ||
| 77 | + | ||
| 78 | + // reading tokens and hotwords to buffers | ||
| 79 | + const char *tokens_buf; | ||
| 80 | + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); | ||
| 81 | + if (token_buf_size < 1) { | ||
| 82 | + fprintf(stderr, "Please check your tokens.txt!\n"); | ||
| 83 | + free(tokens_buf); | ||
| 84 | + return -1; | ||
| 85 | + } | ||
| 86 | + const char *hotwords_buf; | ||
| 87 | + size_t hotwords_buf_size = ReadFile(hotwords_filename, &hotwords_buf); | ||
| 88 | + if (hotwords_buf_size < 1) { | ||
| 89 | + fprintf(stderr, "Please check your hotwords.txt!\n"); | ||
| 90 | + free(hotwords_buf); | ||
| 91 | + return -1; | ||
| 92 | + } | ||
| 93 | + | ||
| 94 | + // Zipformer config | ||
| 95 | + SherpaOnnxOnlineTransducerModelConfig zipformer_config; | ||
| 96 | + memset(&zipformer_config, 0, sizeof(zipformer_config)); | ||
| 97 | + zipformer_config.encoder = encoder_filename; | ||
| 98 | + zipformer_config.decoder = decoder_filename; | ||
| 99 | + zipformer_config.joiner = joiner_filename; | ||
| 100 | + | ||
| 101 | + // Online model config | ||
| 102 | + SherpaOnnxOnlineModelConfig online_model_config; | ||
| 103 | + memset(&online_model_config, 0, sizeof(online_model_config)); | ||
| 104 | + online_model_config.debug = 1; | ||
| 105 | + online_model_config.num_threads = 1; | ||
| 106 | + online_model_config.provider = provider; | ||
| 107 | + online_model_config.tokens_buf = tokens_buf; | ||
| 108 | + online_model_config.tokens_buf_size = token_buf_size; | ||
| 109 | + online_model_config.transducer = zipformer_config; | ||
| 110 | + | ||
| 111 | + // Recognizer config | ||
| 112 | + SherpaOnnxOnlineRecognizerConfig recognizer_config; | ||
| 113 | + memset(&recognizer_config, 0, sizeof(recognizer_config)); | ||
| 114 | + recognizer_config.decoding_method = "modified_beam_search"; | ||
| 115 | + recognizer_config.model_config = online_model_config; | ||
| 116 | + recognizer_config.hotwords_buf = hotwords_buf; | ||
| 117 | + recognizer_config.hotwords_buf_size = hotwords_buf_size; | ||
| 118 | + | ||
| 119 | + SherpaOnnxOnlineRecognizer *recognizer = | ||
| 120 | + SherpaOnnxCreateOnlineRecognizer(&recognizer_config); | ||
| 121 | + | ||
| 122 | + free(tokens_buf); | ||
| 123 | + tokens_buf = NULL; | ||
| 124 | + free(hotwords_buf); | ||
| 125 | + hotwords_buf = NULL; | ||
| 126 | + | ||
| 127 | + if (recognizer == NULL) { | ||
| 128 | + fprintf(stderr, "Please check your config!\n"); | ||
| 129 | + SherpaOnnxFreeWave(wave); | ||
| 130 | + return -1; | ||
| 131 | + } | ||
| 132 | + | ||
| 133 | + SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); | ||
| 134 | + | ||
| 135 | + const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); | ||
| 136 | + int32_t segment_id = 0; | ||
| 137 | + | ||
| 138 | +// simulate streaming. You can choose an arbitrary N | ||
| 139 | +#define N 3200 | ||
| 140 | + | ||
| 141 | + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", | ||
| 142 | + wave->sample_rate, wave->num_samples, | ||
| 143 | + (float)wave->num_samples / wave->sample_rate); | ||
| 144 | + | ||
| 145 | + int32_t k = 0; | ||
| 146 | + while (k < wave->num_samples) { | ||
| 147 | + int32_t start = k; | ||
| 148 | + int32_t end = | ||
| 149 | + (start + N > wave->num_samples) ? wave->num_samples : (start + N); | ||
| 150 | + k += N; | ||
| 151 | + | ||
| 152 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, | ||
| 153 | + wave->samples + start, end - start); | ||
| 154 | + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) { | ||
| 155 | + SherpaOnnxDecodeOnlineStream(recognizer, stream); | ||
| 156 | + } | ||
| 157 | + | ||
| 158 | + const SherpaOnnxOnlineRecognizerResult *r = | ||
| 159 | + SherpaOnnxGetOnlineStreamResult(recognizer, stream); | ||
| 160 | + | ||
| 161 | + if (strlen(r->text)) { | ||
| 162 | + SherpaOnnxPrint(display, segment_id, r->text); | ||
| 163 | + } | ||
| 164 | + | ||
| 165 | + if (SherpaOnnxOnlineStreamIsEndpoint(recognizer, stream)) { | ||
| 166 | + if (strlen(r->text)) { | ||
| 167 | + ++segment_id; | ||
| 168 | + } | ||
| 169 | + SherpaOnnxOnlineStreamReset(recognizer, stream); | ||
| 170 | + } | ||
| 171 | + | ||
| 172 | + SherpaOnnxDestroyOnlineRecognizerResult(r); | ||
| 173 | + } | ||
| 174 | + | ||
| 175 | + // add some tail padding | ||
| 176 | + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate | ||
| 177 | + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, | ||
| 178 | + 4800); | ||
| 179 | + | ||
| 180 | + SherpaOnnxFreeWave(wave); | ||
| 181 | + | ||
| 182 | + SherpaOnnxOnlineStreamInputFinished(stream); | ||
| 183 | + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) { | ||
| 184 | + SherpaOnnxDecodeOnlineStream(recognizer, stream); | ||
| 185 | + } | ||
| 186 | + | ||
| 187 | + const SherpaOnnxOnlineRecognizerResult *r = | ||
| 188 | + SherpaOnnxGetOnlineStreamResult(recognizer, stream); | ||
| 189 | + | ||
| 190 | + if (strlen(r->text)) { | ||
| 191 | + SherpaOnnxPrint(display, segment_id, r->text); | ||
| 192 | + } | ||
| 193 | + | ||
| 194 | + SherpaOnnxDestroyOnlineRecognizerResult(r); | ||
| 195 | + | ||
| 196 | + SherpaOnnxDestroyDisplay(display); | ||
| 197 | + SherpaOnnxDestroyOnlineStream(stream); | ||
| 198 | + SherpaOnnxDestroyOnlineRecognizer(recognizer); | ||
| 199 | + fprintf(stderr, "\n"); | ||
| 200 | + | ||
| 201 | + return 0; | ||
| 202 | +} |
| @@ -73,6 +73,12 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( | @@ -73,6 +73,12 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( | ||
| 73 | 73 | ||
| 74 | recognizer_config.model_config.tokens = | 74 | recognizer_config.model_config.tokens = |
| 75 | SHERPA_ONNX_OR(config->model_config.tokens, ""); | 75 | SHERPA_ONNX_OR(config->model_config.tokens, ""); |
| 76 | + if (config->model_config.tokens_buf && | ||
| 77 | + config->model_config.tokens_buf_size > 0) { | ||
| 78 | + recognizer_config.model_config.tokens_buf = std::string( | ||
| 79 | + config->model_config.tokens_buf, config->model_config.tokens_buf_size); | ||
| 80 | + } | ||
| 81 | + | ||
| 76 | recognizer_config.model_config.num_threads = | 82 | recognizer_config.model_config.num_threads = |
| 77 | SHERPA_ONNX_OR(config->model_config.num_threads, 1); | 83 | SHERPA_ONNX_OR(config->model_config.num_threads, 1); |
| 78 | recognizer_config.model_config.provider_config.provider = | 84 | recognizer_config.model_config.provider_config.provider = |
| @@ -120,6 +126,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( | @@ -120,6 +126,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( | ||
| 120 | recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); | 126 | recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); |
| 121 | recognizer_config.hotwords_score = | 127 | recognizer_config.hotwords_score = |
| 122 | SHERPA_ONNX_OR(config->hotwords_score, 1.5); | 128 | SHERPA_ONNX_OR(config->hotwords_score, 1.5); |
| 129 | + if (config->hotwords_buf && config->hotwords_buf_size > 0) { | ||
| 130 | + recognizer_config.hotwords_buf = | ||
| 131 | + std::string(config->hotwords_buf, config->hotwords_buf_size); | ||
| 132 | + } | ||
| 123 | 133 | ||
| 124 | recognizer_config.blank_penalty = config->blank_penalty; | 134 | recognizer_config.blank_penalty = config->blank_penalty; |
| 125 | 135 |
| @@ -88,6 +88,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { | @@ -88,6 +88,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { | ||
| 88 | // - cjkchar+bpe | 88 | // - cjkchar+bpe |
| 89 | const char *modeling_unit; | 89 | const char *modeling_unit; |
| 90 | const char *bpe_vocab; | 90 | const char *bpe_vocab; |
| 91 | + /// If non-NULL (and tokens_buf_size > 0), the tokens are loaded directly from | ||
| 92 | + /// this buffer instead of from the tokens file. | ||
| 93 | + const char *tokens_buf; | ||
| 94 | + /// byte size excluding the trailing '\0' | ||
| 95 | + int32_t tokens_buf_size; | ||
| 91 | } SherpaOnnxOnlineModelConfig; | 96 | } SherpaOnnxOnlineModelConfig; |
| 92 | 97 | ||
| 93 | /// It expects 16 kHz 16-bit single channel wave format. | 98 | /// It expects 16 kHz 16-bit single channel wave format. |
| @@ -147,6 +152,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { | @@ -147,6 +152,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { | ||
| 147 | const char *rule_fsts; | 152 | const char *rule_fsts; |
| 148 | const char *rule_fars; | 153 | const char *rule_fars; |
| 149 | float blank_penalty; | 154 | float blank_penalty; |
| 155 | + | ||
| 156 | + /// If non-NULL (and hotwords_buf_size > 0), the hotwords are loaded directly from this buffer instead of from hotwords_file. | ||
| 157 | + const char *hotwords_buf; | ||
| 158 | + /// byte size excluding the trailing '\0' | ||
| 159 | + int32_t hotwords_buf_size; | ||
| 150 | } SherpaOnnxOnlineRecognizerConfig; | 160 | } SherpaOnnxOnlineRecognizerConfig; |
| 151 | 161 | ||
| 152 | SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult { | 162 | SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult { |
| @@ -56,8 +56,19 @@ bool OnlineModelConfig::Validate() const { | @@ -56,8 +56,19 @@ bool OnlineModelConfig::Validate() const { | ||
| 56 | return false; | 56 | return false; |
| 57 | } | 57 | } |
| 58 | 58 | ||
| 59 | - if (!FileExists(tokens)) { | ||
| 60 | - SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str()); | 59 | + if (!tokens_buf.empty() && FileExists(tokens)) { |
| 60 | + SHERPA_ONNX_LOGE( | ||
| 61 | + "you can not provide a tokens_buf and a tokens file: '%s', " | ||
| 62 | + "at the same time, which is confusing", | ||
| 63 | + tokens.c_str()); | ||
| 64 | + return false; | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + if (tokens_buf.empty() && !FileExists(tokens)) { | ||
| 68 | + SHERPA_ONNX_LOGE( | ||
| 69 | + "tokens: '%s' does not exist, you should provide " | ||
| 70 | + "either a tokens buffer or a tokens file", | ||
| 71 | + tokens.c_str()); | ||
| 61 | return false; | 72 | return false; |
| 62 | } | 73 | } |
| 63 | 74 |
| @@ -45,6 +45,11 @@ struct OnlineModelConfig { | @@ -45,6 +45,11 @@ struct OnlineModelConfig { | ||
| 45 | std::string modeling_unit = "cjkchar"; | 45 | std::string modeling_unit = "cjkchar"; |
| 46 | std::string bpe_vocab; | 46 | std::string bpe_vocab; |
| 47 | 47 | ||
| 48 | + /// if tokens_buf is non-empty, | ||
| 49 | + /// the tokens will be loaded from the buffered string instead of from the | ||
| 50 | + /// ${tokens} file | ||
| 51 | + std::string tokens_buf; | ||
| 52 | + | ||
| 48 | OnlineModelConfig() = default; | 53 | OnlineModelConfig() = default; |
| 49 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, | 54 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, |
| 50 | const OnlineParaformerModelConfig ¶former, | 55 | const OnlineParaformerModelConfig ¶former, |
| @@ -53,8 +58,7 @@ struct OnlineModelConfig { | @@ -53,8 +58,7 @@ struct OnlineModelConfig { | ||
| 53 | const OnlineNeMoCtcModelConfig &nemo_ctc, | 58 | const OnlineNeMoCtcModelConfig &nemo_ctc, |
| 54 | const ProviderConfig &provider_config, | 59 | const ProviderConfig &provider_config, |
| 55 | const std::string &tokens, int32_t num_threads, | 60 | const std::string &tokens, int32_t num_threads, |
| 56 | - int32_t warm_up, bool debug, | ||
| 57 | - const std::string &model_type, | 61 | + int32_t warm_up, bool debug, const std::string &model_type, |
| 58 | const std::string &modeling_unit, | 62 | const std::string &modeling_unit, |
| 59 | const std::string &bpe_vocab) | 63 | const std::string &bpe_vocab) |
| 60 | : transducer(transducer), | 64 | : transducer(transducer), |
| @@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 83 | : OnlineRecognizerImpl(config), | 83 | : OnlineRecognizerImpl(config), |
| 84 | config_(config), | 84 | config_(config), |
| 85 | model_(OnlineTransducerModel::Create(config.model_config)), | 85 | model_(OnlineTransducerModel::Create(config.model_config)), |
| 86 | - sym_(config.model_config.tokens), | ||
| 87 | endpoint_(config_.endpoint_config) { | 86 | endpoint_(config_.endpoint_config) { |
| 87 | + if (!config.model_config.tokens_buf.empty()) { | ||
| 88 | + sym_ = SymbolTable(config.model_config.tokens_buf, false); | ||
| 89 | + } else { | ||
| 90 | + /// assumes that tokens_buf and tokens are not both empty | ||
| 91 | + sym_ = SymbolTable(config.model_config.tokens, true); | ||
| 92 | + } | ||
| 93 | + | ||
| 88 | if (sym_.Contains("<unk>")) { | 94 | if (sym_.Contains("<unk>")) { |
| 89 | unk_id_ = sym_["<unk>"]; | 95 | unk_id_ = sym_["<unk>"]; |
| 90 | } | 96 | } |
| @@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 97 | config_.model_config.bpe_vocab); | 103 | config_.model_config.bpe_vocab); |
| 98 | } | 104 | } |
| 99 | 105 | ||
| 100 | - if (!config_.hotwords_file.empty()) { | 106 | + if (!config_.hotwords_buf.empty()) { |
| 107 | + InitHotwordsFromBufStr(); | ||
| 108 | + } else if (!config_.hotwords_file.empty()) { | ||
| 101 | InitHotwords(); | 109 | InitHotwords(); |
| 102 | } | 110 | } |
| 103 | 111 | ||
| @@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 108 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 116 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 109 | model_.get(), lm_.get(), config_.max_active_paths, | 117 | model_.get(), lm_.get(), config_.max_active_paths, |
| 110 | config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, | 118 | config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, |
| 111 | - config_.blank_penalty, | ||
| 112 | - config_.temperature_scale); | 119 | + config_.blank_penalty, config_.temperature_scale); |
| 113 | 120 | ||
| 114 | } else if (config.decoding_method == "greedy_search") { | 121 | } else if (config.decoding_method == "greedy_search") { |
| 115 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( | 122 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( |
| @@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 158 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( | 165 | decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( |
| 159 | model_.get(), lm_.get(), config_.max_active_paths, | 166 | model_.get(), lm_.get(), config_.max_active_paths, |
| 160 | config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, | 167 | config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, |
| 161 | - config_.blank_penalty, | ||
| 162 | - config_.temperature_scale); | 168 | + config_.blank_penalty, config_.temperature_scale); |
| 163 | 169 | ||
| 164 | } else if (config.decoding_method == "greedy_search") { | 170 | } else if (config.decoding_method == "greedy_search") { |
| 165 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( | 171 | decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( |
| @@ -446,6 +452,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -446,6 +452,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 446 | } | 452 | } |
| 447 | #endif | 453 | #endif |
| 448 | 454 | ||
| 455 | + void InitHotwordsFromBufStr() { | ||
| 456 | + // each line in hotwords_file contains space-separated words | ||
| 457 | + | ||
| 458 | + std::istringstream iss(config_.hotwords_buf); | ||
| 459 | + if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_, | ||
| 460 | + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { | ||
| 461 | + SHERPA_ONNX_LOGE( | ||
| 462 | + "Failed to encode some hotwords, skip them already, see logs above " | ||
| 463 | + "for details."); | ||
| 464 | + } | ||
| 465 | + hotwords_graph_ = std::make_shared<ContextGraph>( | ||
| 466 | + hotwords_, config_.hotwords_score, boost_scores_); | ||
| 467 | + } | ||
| 468 | + | ||
| 449 | void InitOnlineStream(OnlineStream *stream) const { | 469 | void InitOnlineStream(OnlineStream *stream) const { |
| 450 | auto r = decoder_->GetEmptyResult(); | 470 | auto r = decoder_->GetEmptyResult(); |
| 451 | 471 |
| @@ -44,10 +44,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -44,10 +44,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 44 | const OnlineRecognizerConfig &config) | 44 | const OnlineRecognizerConfig &config) |
| 45 | : OnlineRecognizerImpl(config), | 45 | : OnlineRecognizerImpl(config), |
| 46 | config_(config), | 46 | config_(config), |
| 47 | - symbol_table_(config.model_config.tokens), | ||
| 48 | endpoint_(config_.endpoint_config), | 47 | endpoint_(config_.endpoint_config), |
| 49 | model_( | 48 | model_( |
| 50 | std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) { | 49 | std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) { |
| 50 | + if (!config.model_config.tokens_buf.empty()) { | ||
| 51 | + symbol_table_ = SymbolTable(config.model_config.tokens_buf, false); | ||
| 52 | + } else { | ||
| 53 | + /// assumes that tokens_buf and tokens are not both empty | ||
| 54 | + symbol_table_ = SymbolTable(config.model_config.tokens, true); | ||
| 55 | + } | ||
| 56 | + | ||
| 51 | if (config.decoding_method == "greedy_search") { | 57 | if (config.decoding_method == "greedy_search") { |
| 52 | decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>( | 58 | decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>( |
| 53 | model_.get(), config_.blank_penalty); | 59 | model_.get(), config_.blank_penalty); |
| @@ -106,6 +106,11 @@ struct OnlineRecognizerConfig { | @@ -106,6 +106,11 @@ struct OnlineRecognizerConfig { | ||
| 106 | // If there are multiple FST archives, they are applied from left to right. | 106 | // If there are multiple FST archives, they are applied from left to right. |
| 107 | std::string rule_fars; | 107 | std::string rule_fars; |
| 108 | 108 | ||
| 109 | + /// used only for modified_beam_search, if hotwords_buf is non-empty, | ||
| 110 | + /// the hotwords will be loaded from the buffered string instead of from | ||
| 111 | + /// ${hotwords_file} | ||
| 112 | + std::string hotwords_buf; | ||
| 113 | + | ||
| 109 | OnlineRecognizerConfig() = default; | 114 | OnlineRecognizerConfig() = default; |
| 110 | 115 | ||
| 111 | OnlineRecognizerConfig( | 116 | OnlineRecognizerConfig( |
| @@ -20,9 +20,14 @@ | @@ -20,9 +20,14 @@ | ||
| 20 | 20 | ||
| 21 | namespace sherpa_onnx { | 21 | namespace sherpa_onnx { |
| 22 | 22 | ||
| 23 | -SymbolTable::SymbolTable(const std::string &filename) { | ||
| 24 | - std::ifstream is(filename); | ||
| 25 | - Init(is); | 23 | +SymbolTable::SymbolTable(const std::string &filename, bool is_file) { |
| 24 | + if (is_file) { | ||
| 25 | + std::ifstream is(filename); | ||
| 26 | + Init(is); | ||
| 27 | + } else { | ||
| 28 | + std::istringstream iss(filename); | ||
| 29 | + Init(iss); | ||
| 30 | + } | ||
| 26 | } | 31 | } |
| 27 | 32 | ||
| 28 | #if __ANDROID_API__ >= 9 | 33 | #if __ANDROID_API__ >= 9 |
| @@ -19,13 +19,13 @@ namespace sherpa_onnx { | @@ -19,13 +19,13 @@ namespace sherpa_onnx { | ||
| 19 | class SymbolTable { | 19 | class SymbolTable { |
| 20 | public: | 20 | public: |
| 21 | SymbolTable() = default; | 21 | SymbolTable() = default; |
| 22 | - /// Construct a symbol table from a file. | 22 | + /// Construct a symbol table from a file or from a buffered string. |
| 23 | /// Each line in the file contains two fields: | 23 | /// Each line in the file contains two fields: |
| 24 | /// | 24 | /// |
| 25 | /// sym ID | 25 | /// sym ID |
| 26 | /// | 26 | /// |
| 27 | /// Fields are separated by space(s). | 27 | /// Fields are separated by space(s). |
| 28 | - explicit SymbolTable(const std::string &filename); | 28 | + explicit SymbolTable(const std::string &filename, bool is_file = true); |
| 29 | 29 | ||
| 30 | #if __ANDROID_API__ >= 9 | 30 | #if __ANDROID_API__ >= 9 |
| 31 | SymbolTable(AAssetManager *mgr, const std::string &filename); | 31 | SymbolTable(AAssetManager *mgr, const std::string &filename); |
-
请 注册 或 登录 后发表评论