lxiao336
Committed by GitHub

re-pull-request allow tokens and hotwords be loaded from buffered string driectly (#1339)

Co-authored-by: xiao <shawl336@163.com>
  1 +name: c-api-test-loading-tokens-hotwords-from-memory
  2 +
  3 +on:
  4 + push:
  5 + branches:
  6 + - master
  7 + tags:
  8 + - 'v[0-9]+.[0-9]+.[0-9]+*'
  9 + paths:
  10 + - '.github/workflows/c-api.yaml'
  11 + - 'CMakeLists.txt'
  12 + - 'cmake/**'
  13 + - 'sherpa-onnx/csrc/*'
  14 + - 'sherpa-onnx/c-api/*'
  15 + - 'c-api-examples/**'
  16 + - 'ffmpeg-examples/**'
  17 + pull_request:
  18 + branches:
  19 + - master
  20 + paths:
  21 + - '.github/workflows/c-api.yaml'
  22 + - 'CMakeLists.txt'
  23 + - 'cmake/**'
  24 + - 'sherpa-onnx/csrc/*'
  25 + - 'sherpa-onnx/c-api/*'
  26 + - 'c-api-examples/**'
  27 + - 'ffmpeg-examples/**'
  28 +
  29 + workflow_dispatch:
  30 +
  31 +concurrency:
  32 + group: c-api-${{ github.ref }}
  33 + cancel-in-progress: true
  34 +
  35 +jobs:
  36 + c_api:
  37 + name: ${{ matrix.os }}
  38 + runs-on: ${{ matrix.os }}
  39 + strategy:
  40 + fail-fast: false
  41 + matrix:
  42 + os: [ubuntu-latest, macos-latest]
  43 +
  44 + steps:
  45 + - uses: actions/checkout@v4
  46 + with:
  47 + fetch-depth: 0
  48 +
  49 + - name: ccache
  50 + uses: hendrikmuhs/ccache-action@v1.2
  51 + with:
  52 + key: ${{ matrix.os }}-c-api-shared
  53 +
  54 + - name: Build sherpa-onnx
  55 + shell: bash
  56 + run: |
  57 + export CMAKE_CXX_COMPILER_LAUNCHER=ccache
  58 + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
  59 + cmake --version
  60 +
  61 + mkdir build
  62 + cd build
  63 +
  64 + cmake \
  65 + -D CMAKE_BUILD_TYPE=Release \
  66 + -D BUILD_SHARED_LIBS=ON \
  67 + -D CMAKE_INSTALL_PREFIX=./install \
  68 + -D SHERPA_ONNX_ENABLE_BINARY=OFF \
  69 + ..
  70 +
  71 + make -j2 install
  72 +
  73 + ls -lh install/lib
  74 + ls -lh install/include
  75 +
  76 + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then
  77 + ldd ./install/lib/libsherpa-onnx-c-api.so
  78 + echo "---"
  79 + readelf -d ./install/lib/libsherpa-onnx-c-api.so
  80 + fi
  81 +
  82 + if [[ ${{ matrix.os }} == macos-latest ]]; then
  83 + otool -L ./install/lib/libsherpa-onnx-c-api.dylib
  84 + fi
  85 +
  86 + - name: Test streaming zipformer with tokens and hotwords loaded from buffers
  87 + shell: bash
  88 + run: |
  89 + gcc -o streaming-zipformer-buffered-tokens-hotwords-c-api ./c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c \
  90 + -I ./build/install/include \
  91 + -L ./build/install/lib/ \
  92 + -l sherpa-onnx-c-api \
  93 + -l onnxruntime
  94 +
  95 + ls -lh streaming-zipformer-buffered-tokens-hotwords-c-api
  96 +
  97 + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then
  98 + ldd ./streaming-zipformer-buffered-tokens-hotwords-c-api
  99 + echo "----"
  100 + readelf -d ./streaming-zipformer-buffered-tokens-hotwords-c-api
  101 + fi
  102 +
  103 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
  104 + tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
  105 + rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
  106 + curl -SL -O https://huggingface.co/desh2608/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-small/blob/main/data/lang_bpe_500/bpe.model
  107 + cp bpe.model sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/
  108 + rm bpe.model
  109 +
  110 + printf "▁A ▁T ▁P :1.5\n▁A ▁B ▁C :3.0" > hotwords.txt
  111 +
  112 + ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17
  113 + echo "---"
  114 + ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs
  115 +
  116 + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
  117 + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
  118 +
  119 + ./streaming-zipformer-buffered-tokens-hotwords-c-api
  120 +
  121 + rm -rf sherpa-onnx-streaming-zipformer-*
@@ -48,6 +48,10 @@ target_link_libraries(telespeech-c-api sherpa-onnx-c-api) @@ -48,6 +48,10 @@ target_link_libraries(telespeech-c-api sherpa-onnx-c-api)
48 add_executable(vad-sense-voice-c-api vad-sense-voice-c-api.c) 48 add_executable(vad-sense-voice-c-api vad-sense-voice-c-api.c)
49 target_link_libraries(vad-sense-voice-c-api sherpa-onnx-c-api) 49 target_link_libraries(vad-sense-voice-c-api sherpa-onnx-c-api)
50 50
  51 +add_executable(streaming-zipformer-buffered-tokens-hotwords-c-api
  52 + streaming-zipformer-buffered-tokens-hotwords-c-api.c)
  53 +target_link_libraries(streaming-zipformer-buffered-tokens-hotwords-c-api sherpa-onnx-c-api)
  54 +
51 if(SHERPA_ONNX_HAS_ALSA) 55 if(SHERPA_ONNX_HAS_ALSA)
52 add_subdirectory(./asr-microphone-example) 56 add_subdirectory(./asr-microphone-example)
53 elseif((UNIX AND NOT APPLE) OR LINUX) 57 elseif((UNIX AND NOT APPLE) OR LINUX)
  1 +// c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +// Copyright (c) 2024 Luo Xiao
  5 +
  6 +//
  7 +// This file demonstrates how to use streaming Zipformer with sherpa-onnx's C
  8 +// and with tokens and hotwords loaded from buffered strings instead of from external
  9 +// files API.
  10 +// clang-format off
  11 +//
  12 +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
  13 +// tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
  14 +// rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
  15 +//
  16 +// clang-format on
  17 +
  18 +#include <stdio.h>
  19 +#include <stdlib.h>
  20 +#include <string.h>
  21 +
  22 +#include "sherpa-onnx/c-api/c-api.h"
  23 +
  24 +static size_t ReadFile(const char *filename, const char **buffer_out) {
  25 + FILE *file = fopen(filename, "rb");
  26 + if (file == NULL) {
  27 + fprintf(stderr, "Failed to open %s\n", filename);
  28 + return -1;
  29 + }
  30 + fseek(file, 0L, SEEK_END);
  31 + long size = ftell(file);
  32 + rewind(file);
  33 + *buffer_out = malloc(size);
  34 + if (*buffer_out == NULL) {
  35 + fclose(file);
  36 + fprintf(stderr, "Memory error\n");
  37 + return -1;
  38 + }
  39 + size_t read_bytes = fread(*buffer_out, 1, size, file);
  40 + if (read_bytes != size) {
  41 + printf("Errors occured in reading the file %s\n", filename);
  42 + free(*buffer_out);
  43 + *buffer_out = NULL;
  44 + fclose(file);
  45 + return -1;
  46 + }
  47 + fclose(file);
  48 + return read_bytes;
  49 +}
  50 +
  51 +int32_t main() {
  52 + const char *wav_filename =
  53 + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav";
  54 + const char *encoder_filename =
  55 + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/"
  56 + "encoder-epoch-99-avg-1.onnx";
  57 + const char *decoder_filename =
  58 + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/"
  59 + "decoder-epoch-99-avg-1.onnx";
  60 + const char *joiner_filename =
  61 + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/"
  62 + "joiner-epoch-99-avg-1.onnx";
  63 + const char *provider = "cpu";
  64 + const char *modeling_unit = "bpe";
  65 + const char *tokens_filename =
  66 + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt";
  67 + const char *hotwords_filename =
  68 + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/hotwords.txt";
  69 + const char *bpe_vocab =
  70 + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/"
  71 + "bpe.vocab";
  72 + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
  73 + if (wave == NULL) {
  74 + fprintf(stderr, "Failed to read %s\n", wav_filename);
  75 + return -1;
  76 + }
  77 +
  78 + // reading tokens and hotwords to buffers
  79 + const char *tokens_buf;
  80 + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf);
  81 + if (token_buf_size < 1) {
  82 + fprintf(stderr, "Please check your tokens.txt!\n");
  83 + free(tokens_buf);
  84 + return -1;
  85 + }
  86 + const char *hotwords_buf;
  87 + size_t hotwords_buf_size = ReadFile(hotwords_filename, &hotwords_buf);
  88 + if (hotwords_buf_size < 1) {
  89 + fprintf(stderr, "Please check your hotwords.txt!\n");
  90 + free(hotwords_buf);
  91 + return -1;
  92 + }
  93 +
  94 + // Zipformer config
  95 + SherpaOnnxOnlineTransducerModelConfig zipformer_config;
  96 + memset(&zipformer_config, 0, sizeof(zipformer_config));
  97 + zipformer_config.encoder = encoder_filename;
  98 + zipformer_config.decoder = decoder_filename;
  99 + zipformer_config.joiner = joiner_filename;
  100 +
  101 + // Online model config
  102 + SherpaOnnxOnlineModelConfig online_model_config;
  103 + memset(&online_model_config, 0, sizeof(online_model_config));
  104 + online_model_config.debug = 1;
  105 + online_model_config.num_threads = 1;
  106 + online_model_config.provider = provider;
  107 + online_model_config.tokens_buf = tokens_buf;
  108 + online_model_config.tokens_buf_size = token_buf_size;
  109 + online_model_config.transducer = zipformer_config;
  110 +
  111 + // Recognizer config
  112 + SherpaOnnxOnlineRecognizerConfig recognizer_config;
  113 + memset(&recognizer_config, 0, sizeof(recognizer_config));
  114 + recognizer_config.decoding_method = "modified_beam_search";
  115 + recognizer_config.model_config = online_model_config;
  116 + recognizer_config.hotwords_buf = hotwords_buf;
  117 + recognizer_config.hotwords_buf_size = hotwords_buf_size;
  118 +
  119 + SherpaOnnxOnlineRecognizer *recognizer =
  120 + SherpaOnnxCreateOnlineRecognizer(&recognizer_config);
  121 +
  122 + free(tokens_buf);
  123 + tokens_buf = NULL;
  124 + free(hotwords_buf);
  125 + hotwords_buf = NULL;
  126 +
  127 + if (recognizer == NULL) {
  128 + fprintf(stderr, "Please check your config!\n");
  129 + SherpaOnnxFreeWave(wave);
  130 + return -1;
  131 + }
  132 +
  133 + SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer);
  134 +
  135 + const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50);
  136 + int32_t segment_id = 0;
  137 +
  138 +// simulate streaming. You can choose an arbitrary N
  139 +#define N 3200
  140 +
  141 + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n",
  142 + wave->sample_rate, wave->num_samples,
  143 + (float)wave->num_samples / wave->sample_rate);
  144 +
  145 + int32_t k = 0;
  146 + while (k < wave->num_samples) {
  147 + int32_t start = k;
  148 + int32_t end =
  149 + (start + N > wave->num_samples) ? wave->num_samples : (start + N);
  150 + k += N;
  151 +
  152 + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate,
  153 + wave->samples + start, end - start);
  154 + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) {
  155 + SherpaOnnxDecodeOnlineStream(recognizer, stream);
  156 + }
  157 +
  158 + const SherpaOnnxOnlineRecognizerResult *r =
  159 + SherpaOnnxGetOnlineStreamResult(recognizer, stream);
  160 +
  161 + if (strlen(r->text)) {
  162 + SherpaOnnxPrint(display, segment_id, r->text);
  163 + }
  164 +
  165 + if (SherpaOnnxOnlineStreamIsEndpoint(recognizer, stream)) {
  166 + if (strlen(r->text)) {
  167 + ++segment_id;
  168 + }
  169 + SherpaOnnxOnlineStreamReset(recognizer, stream);
  170 + }
  171 +
  172 + SherpaOnnxDestroyOnlineRecognizerResult(r);
  173 + }
  174 +
  175 + // add some tail padding
  176 + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate
  177 + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings,
  178 + 4800);
  179 +
  180 + SherpaOnnxFreeWave(wave);
  181 +
  182 + SherpaOnnxOnlineStreamInputFinished(stream);
  183 + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) {
  184 + SherpaOnnxDecodeOnlineStream(recognizer, stream);
  185 + }
  186 +
  187 + const SherpaOnnxOnlineRecognizerResult *r =
  188 + SherpaOnnxGetOnlineStreamResult(recognizer, stream);
  189 +
  190 + if (strlen(r->text)) {
  191 + SherpaOnnxPrint(display, segment_id, r->text);
  192 + }
  193 +
  194 + SherpaOnnxDestroyOnlineRecognizerResult(r);
  195 +
  196 + SherpaOnnxDestroyDisplay(display);
  197 + SherpaOnnxDestroyOnlineStream(stream);
  198 + SherpaOnnxDestroyOnlineRecognizer(recognizer);
  199 + fprintf(stderr, "\n");
  200 +
  201 + return 0;
  202 +}
@@ -73,6 +73,12 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( @@ -73,6 +73,12 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
73 73
74 recognizer_config.model_config.tokens = 74 recognizer_config.model_config.tokens =
75 SHERPA_ONNX_OR(config->model_config.tokens, ""); 75 SHERPA_ONNX_OR(config->model_config.tokens, "");
  76 + if (config->model_config.tokens_buf &&
  77 + config->model_config.tokens_buf_size > 0) {
  78 + recognizer_config.model_config.tokens_buf = std::string(
  79 + config->model_config.tokens_buf, config->model_config.tokens_buf_size);
  80 + }
  81 +
76 recognizer_config.model_config.num_threads = 82 recognizer_config.model_config.num_threads =
77 SHERPA_ONNX_OR(config->model_config.num_threads, 1); 83 SHERPA_ONNX_OR(config->model_config.num_threads, 1);
78 recognizer_config.model_config.provider_config.provider = 84 recognizer_config.model_config.provider_config.provider =
@@ -120,6 +126,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( @@ -120,6 +126,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
120 recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); 126 recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
121 recognizer_config.hotwords_score = 127 recognizer_config.hotwords_score =
122 SHERPA_ONNX_OR(config->hotwords_score, 1.5); 128 SHERPA_ONNX_OR(config->hotwords_score, 1.5);
  129 + if (config->hotwords_buf && config->hotwords_buf_size > 0) {
  130 + recognizer_config.hotwords_buf =
  131 + std::string(config->hotwords_buf, config->hotwords_buf_size);
  132 + }
123 133
124 recognizer_config.blank_penalty = config->blank_penalty; 134 recognizer_config.blank_penalty = config->blank_penalty;
125 135
@@ -88,6 +88,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { @@ -88,6 +88,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
88 // - cjkchar+bpe 88 // - cjkchar+bpe
89 const char *modeling_unit; 89 const char *modeling_unit;
90 const char *bpe_vocab; 90 const char *bpe_vocab;
  91 + /// if non-null, loading the tokens from the buffered string directly in
  92 + /// prioriy
  93 + const char *tokens_buf;
  94 + /// byte size excluding the tailing '\0'
  95 + int32_t tokens_buf_size;
91 } SherpaOnnxOnlineModelConfig; 96 } SherpaOnnxOnlineModelConfig;
92 97
93 /// It expects 16 kHz 16-bit single channel wave format. 98 /// It expects 16 kHz 16-bit single channel wave format.
@@ -147,6 +152,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { @@ -147,6 +152,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
147 const char *rule_fsts; 152 const char *rule_fsts;
148 const char *rule_fars; 153 const char *rule_fars;
149 float blank_penalty; 154 float blank_penalty;
  155 +
  156 + /// if non-nullptr, loading the hotwords from the buffered string directly in
  157 + const char *hotwords_buf;
  158 + /// byte size excluding the tailing '\0'
  159 + int32_t hotwords_buf_size;
150 } SherpaOnnxOnlineRecognizerConfig; 160 } SherpaOnnxOnlineRecognizerConfig;
151 161
152 SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult { 162 SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {
@@ -56,8 +56,19 @@ bool OnlineModelConfig::Validate() const { @@ -56,8 +56,19 @@ bool OnlineModelConfig::Validate() const {
56 return false; 56 return false;
57 } 57 }
58 58
59 - if (!FileExists(tokens)) {  
60 - SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str()); 59 + if (!tokens_buf.empty() && FileExists(tokens)) {
  60 + SHERPA_ONNX_LOGE(
  61 + "you can not provide a tokens_buf and a tokens file: '%s', "
  62 + "at the same time, which is confusing",
  63 + tokens.c_str());
  64 + return false;
  65 + }
  66 +
  67 + if (tokens_buf.empty() && !FileExists(tokens)) {
  68 + SHERPA_ONNX_LOGE(
  69 + "tokens: '%s' does not exist, you should provide "
  70 + "either a tokens buffer or a tokens file",
  71 + tokens.c_str());
61 return false; 72 return false;
62 } 73 }
63 74
@@ -45,6 +45,11 @@ struct OnlineModelConfig { @@ -45,6 +45,11 @@ struct OnlineModelConfig {
45 std::string modeling_unit = "cjkchar"; 45 std::string modeling_unit = "cjkchar";
46 std::string bpe_vocab; 46 std::string bpe_vocab;
47 47
  48 + /// if tokens_buf is non-empty,
  49 + /// the tokens will be loaded from the buffered string instead of from the
  50 + /// ${tokens} file
  51 + std::string tokens_buf;
  52 +
48 OnlineModelConfig() = default; 53 OnlineModelConfig() = default;
49 OnlineModelConfig(const OnlineTransducerModelConfig &transducer, 54 OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
50 const OnlineParaformerModelConfig &paraformer, 55 const OnlineParaformerModelConfig &paraformer,
@@ -53,8 +58,7 @@ struct OnlineModelConfig { @@ -53,8 +58,7 @@ struct OnlineModelConfig {
53 const OnlineNeMoCtcModelConfig &nemo_ctc, 58 const OnlineNeMoCtcModelConfig &nemo_ctc,
54 const ProviderConfig &provider_config, 59 const ProviderConfig &provider_config,
55 const std::string &tokens, int32_t num_threads, 60 const std::string &tokens, int32_t num_threads,
56 - int32_t warm_up, bool debug,  
57 - const std::string &model_type, 61 + int32_t warm_up, bool debug, const std::string &model_type,
58 const std::string &modeling_unit, 62 const std::string &modeling_unit,
59 const std::string &bpe_vocab) 63 const std::string &bpe_vocab)
60 : transducer(transducer), 64 : transducer(transducer),
@@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
83 : OnlineRecognizerImpl(config), 83 : OnlineRecognizerImpl(config),
84 config_(config), 84 config_(config),
85 model_(OnlineTransducerModel::Create(config.model_config)), 85 model_(OnlineTransducerModel::Create(config.model_config)),
86 - sym_(config.model_config.tokens),  
87 endpoint_(config_.endpoint_config) { 86 endpoint_(config_.endpoint_config) {
  87 + if (!config.model_config.tokens_buf.empty()) {
  88 + sym_ = SymbolTable(config.model_config.tokens_buf, false);
  89 + } else {
  90 + /// assuming tokens_buf and tokens are guaranteed not being both empty
  91 + sym_ = SymbolTable(config.model_config.tokens, true);
  92 + }
  93 +
88 if (sym_.Contains("<unk>")) { 94 if (sym_.Contains("<unk>")) {
89 unk_id_ = sym_["<unk>"]; 95 unk_id_ = sym_["<unk>"];
90 } 96 }
@@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
97 config_.model_config.bpe_vocab); 103 config_.model_config.bpe_vocab);
98 } 104 }
99 105
100 - if (!config_.hotwords_file.empty()) { 106 + if (!config_.hotwords_buf.empty()) {
  107 + InitHotwordsFromBufStr();
  108 + } else if (!config_.hotwords_file.empty()) {
101 InitHotwords(); 109 InitHotwords();
102 } 110 }
103 111
@@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
108 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 116 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
109 model_.get(), lm_.get(), config_.max_active_paths, 117 model_.get(), lm_.get(), config_.max_active_paths,
110 config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, 118 config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
111 - config_.blank_penalty,  
112 - config_.temperature_scale); 119 + config_.blank_penalty, config_.temperature_scale);
113 120
114 } else if (config.decoding_method == "greedy_search") { 121 } else if (config.decoding_method == "greedy_search") {
115 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( 122 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
@@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
158 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 165 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
159 model_.get(), lm_.get(), config_.max_active_paths, 166 model_.get(), lm_.get(), config_.max_active_paths,
160 config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, 167 config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_,
161 - config_.blank_penalty,  
162 - config_.temperature_scale); 168 + config_.blank_penalty, config_.temperature_scale);
163 169
164 } else if (config.decoding_method == "greedy_search") { 170 } else if (config.decoding_method == "greedy_search") {
165 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( 171 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
@@ -446,6 +452,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -446,6 +452,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
446 } 452 }
447 #endif 453 #endif
448 454
  455 + void InitHotwordsFromBufStr() {
  456 + // each line in hotwords_file contains space-separated words
  457 +
  458 + std::istringstream iss(config_.hotwords_buf);
  459 + if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_,
  460 + bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
  461 + SHERPA_ONNX_LOGE(
  462 + "Failed to encode some hotwords, skip them already, see logs above "
  463 + "for details.");
  464 + }
  465 + hotwords_graph_ = std::make_shared<ContextGraph>(
  466 + hotwords_, config_.hotwords_score, boost_scores_);
  467 + }
  468 +
449 void InitOnlineStream(OnlineStream *stream) const { 469 void InitOnlineStream(OnlineStream *stream) const {
450 auto r = decoder_->GetEmptyResult(); 470 auto r = decoder_->GetEmptyResult();
451 471
@@ -44,10 +44,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -44,10 +44,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
44 const OnlineRecognizerConfig &config) 44 const OnlineRecognizerConfig &config)
45 : OnlineRecognizerImpl(config), 45 : OnlineRecognizerImpl(config),
46 config_(config), 46 config_(config),
47 - symbol_table_(config.model_config.tokens),  
48 endpoint_(config_.endpoint_config), 47 endpoint_(config_.endpoint_config),
49 model_( 48 model_(
50 std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) { 49 std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
  50 + if (!config.model_config.tokens_buf.empty()) {
  51 + symbol_table_ = SymbolTable(config.model_config.tokens_buf, false);
  52 + } else {
  53 + /// assuming tokens_buf and tokens are guaranteed not being both empty
  54 + symbol_table_ = SymbolTable(config.model_config.tokens, true);
  55 + }
  56 +
51 if (config.decoding_method == "greedy_search") { 57 if (config.decoding_method == "greedy_search") {
52 decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>( 58 decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
53 model_.get(), config_.blank_penalty); 59 model_.get(), config_.blank_penalty);
@@ -106,6 +106,11 @@ struct OnlineRecognizerConfig { @@ -106,6 +106,11 @@ struct OnlineRecognizerConfig {
106 // If there are multiple FST archives, they are applied from left to right. 106 // If there are multiple FST archives, they are applied from left to right.
107 std::string rule_fars; 107 std::string rule_fars;
108 108
  109 + /// used only for modified_beam_search, if hotwords_buf is non-empty,
  110 + /// the hotwords will be loaded from the buffered string instead of from
  111 + /// ${hotwords_file}
  112 + std::string hotwords_buf;
  113 +
109 OnlineRecognizerConfig() = default; 114 OnlineRecognizerConfig() = default;
110 115
111 OnlineRecognizerConfig( 116 OnlineRecognizerConfig(
@@ -20,9 +20,14 @@ @@ -20,9 +20,14 @@
20 20
21 namespace sherpa_onnx { 21 namespace sherpa_onnx {
22 22
23 -SymbolTable::SymbolTable(const std::string &filename) {  
24 - std::ifstream is(filename);  
25 - Init(is); 23 +SymbolTable::SymbolTable(const std::string &filename, bool is_file) {
  24 + if (is_file) {
  25 + std::ifstream is(filename);
  26 + Init(is);
  27 + } else {
  28 + std::istringstream iss(filename);
  29 + Init(iss);
  30 + }
26 } 31 }
27 32
28 #if __ANDROID_API__ >= 9 33 #if __ANDROID_API__ >= 9
@@ -19,13 +19,13 @@ namespace sherpa_onnx { @@ -19,13 +19,13 @@ namespace sherpa_onnx {
19 class SymbolTable { 19 class SymbolTable {
20 public: 20 public:
21 SymbolTable() = default; 21 SymbolTable() = default;
22 - /// Construct a symbol table from a file. 22 + /// Construct a symbol table from a file or from a buffered string.
23 /// Each line in the file contains two fields: 23 /// Each line in the file contains two fields:
24 /// 24 ///
25 /// sym ID 25 /// sym ID
26 /// 26 ///
27 /// Fields are separated by space(s). 27 /// Fields are separated by space(s).
28 - explicit SymbolTable(const std::string &filename); 28 + explicit SymbolTable(const std::string &filename, bool is_file = true);
29 29
30 #if __ANDROID_API__ >= 9 30 #if __ANDROID_API__ >= 9
31 SymbolTable(AAssetManager *mgr, const std::string &filename); 31 SymbolTable(AAssetManager *mgr, const std::string &filename);