正在显示
17 个修改的文件
包含
190 行增加
和
44 行删除
| 1 | -name: c-api-test-loading-tokens-hotwords-from-memory | 1 | +name: c-api-from-memory |
| 2 | 2 | ||
| 3 | on: | 3 | on: |
| 4 | push: | 4 | push: |
| @@ -7,7 +7,7 @@ on: | @@ -7,7 +7,7 @@ on: | ||
| 7 | tags: | 7 | tags: |
| 8 | - 'v[0-9]+.[0-9]+.[0-9]+*' | 8 | - 'v[0-9]+.[0-9]+.[0-9]+*' |
| 9 | paths: | 9 | paths: |
| 10 | - - '.github/workflows/c-api.yaml' | 10 | + - '.github/workflows/c-api-from-buffer.yaml' |
| 11 | - 'CMakeLists.txt' | 11 | - 'CMakeLists.txt' |
| 12 | - 'cmake/**' | 12 | - 'cmake/**' |
| 13 | - 'sherpa-onnx/csrc/*' | 13 | - 'sherpa-onnx/csrc/*' |
| @@ -18,7 +18,7 @@ on: | @@ -18,7 +18,7 @@ on: | ||
| 18 | branches: | 18 | branches: |
| 19 | - master | 19 | - master |
| 20 | paths: | 20 | paths: |
| 21 | - - '.github/workflows/c-api.yaml' | 21 | + - '.github/workflows/c-api-from-buffer.yaml' |
| 22 | - 'CMakeLists.txt' | 22 | - 'CMakeLists.txt' |
| 23 | - 'cmake/**' | 23 | - 'cmake/**' |
| 24 | - 'sherpa-onnx/csrc/*' | 24 | - 'sherpa-onnx/csrc/*' |
| @@ -29,11 +29,11 @@ on: | @@ -29,11 +29,11 @@ on: | ||
| 29 | workflow_dispatch: | 29 | workflow_dispatch: |
| 30 | 30 | ||
| 31 | concurrency: | 31 | concurrency: |
| 32 | - group: c-api-${{ github.ref }} | 32 | + group: c-api-from-buffer-${{ github.ref }} |
| 33 | cancel-in-progress: true | 33 | cancel-in-progress: true |
| 34 | 34 | ||
| 35 | jobs: | 35 | jobs: |
| 36 | - c_api: | 36 | + c_api_from_buffer: |
| 37 | name: ${{ matrix.os }} | 37 | name: ${{ matrix.os }} |
| 38 | runs-on: ${{ matrix.os }} | 38 | runs-on: ${{ matrix.os }} |
| 39 | strategy: | 39 | strategy: |
| @@ -106,8 +106,9 @@ jobs: | @@ -106,8 +106,9 @@ jobs: | ||
| 106 | curl -SL -O https://huggingface.co/desh2608/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-small/blob/main/data/lang_bpe_500/bpe.model | 106 | curl -SL -O https://huggingface.co/desh2608/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-small/blob/main/data/lang_bpe_500/bpe.model |
| 107 | cp bpe.model sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/ | 107 | cp bpe.model sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/ |
| 108 | rm bpe.model | 108 | rm bpe.model |
| 109 | - | 109 | + |
| 110 | printf "▁A ▁T ▁P :1.5\n▁A ▁B ▁C :3.0" > hotwords.txt | 110 | printf "▁A ▁T ▁P :1.5\n▁A ▁B ▁C :3.0" > hotwords.txt |
| 111 | + mv hotwords.txt ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 | ||
| 111 | 112 | ||
| 112 | ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 | 113 | ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 |
| 113 | echo "---" | 114 | echo "---" |
| @@ -115,7 +116,7 @@ jobs: | @@ -115,7 +116,7 @@ jobs: | ||
| 115 | 116 | ||
| 116 | export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH | 117 | export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH |
| 117 | export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH | 118 | export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH |
| 118 | - | 119 | + |
| 119 | ./streaming-zipformer-buffered-tokens-hotwords-c-api | 120 | ./streaming-zipformer-buffered-tokens-hotwords-c-api |
| 120 | - | 121 | + |
| 121 | rm -rf sherpa-onnx-streaming-zipformer-* | 122 | rm -rf sherpa-onnx-streaming-zipformer-* |
| @@ -5,8 +5,8 @@ | @@ -5,8 +5,8 @@ | ||
| 5 | 5 | ||
| 6 | // | 6 | // |
| 7 | // This file demonstrates how to use streaming Zipformer with sherpa-onnx's C | 7 | // This file demonstrates how to use streaming Zipformer with sherpa-onnx's C |
| 8 | -// and with tokens and hotwords loaded from buffered strings instead of from external | ||
| 9 | -// files API. | 8 | +// and with tokens and hotwords loaded from buffered strings instead of from |
| 9 | +// external files API. | ||
| 10 | // clang-format off | 10 | // clang-format off |
| 11 | // | 11 | // |
| 12 | // wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 | 12 | // wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 |
| @@ -22,7 +22,7 @@ | @@ -22,7 +22,7 @@ | ||
| 22 | #include "sherpa-onnx/c-api/c-api.h" | 22 | #include "sherpa-onnx/c-api/c-api.h" |
| 23 | 23 | ||
| 24 | static size_t ReadFile(const char *filename, const char **buffer_out) { | 24 | static size_t ReadFile(const char *filename, const char **buffer_out) { |
| 25 | - FILE *file = fopen(filename, "rb"); | 25 | + FILE *file = fopen(filename, "r"); |
| 26 | if (file == NULL) { | 26 | if (file == NULL) { |
| 27 | fprintf(stderr, "Failed to open %s\n", filename); | 27 | fprintf(stderr, "Failed to open %s\n", filename); |
| 28 | return -1; | 28 | return -1; |
| @@ -39,7 +39,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | @@ -39,7 +39,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | ||
| 39 | size_t read_bytes = fread(*buffer_out, 1, size, file); | 39 | size_t read_bytes = fread(*buffer_out, 1, size, file); |
| 40 | if (read_bytes != size) { | 40 | if (read_bytes != size) { |
| 41 | printf("Errors occured in reading the file %s\n", filename); | 41 | printf("Errors occured in reading the file %s\n", filename); |
| 42 | - free(*buffer_out); | 42 | + free((void *)*buffer_out); |
| 43 | *buffer_out = NULL; | 43 | *buffer_out = NULL; |
| 44 | fclose(file); | 44 | fclose(file); |
| 45 | return -1; | 45 | return -1; |
| @@ -80,14 +80,14 @@ int32_t main() { | @@ -80,14 +80,14 @@ int32_t main() { | ||
| 80 | size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); | 80 | size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); |
| 81 | if (token_buf_size < 1) { | 81 | if (token_buf_size < 1) { |
| 82 | fprintf(stderr, "Please check your tokens.txt!\n"); | 82 | fprintf(stderr, "Please check your tokens.txt!\n"); |
| 83 | - free(tokens_buf); | 83 | + free((void *)tokens_buf); |
| 84 | return -1; | 84 | return -1; |
| 85 | } | 85 | } |
| 86 | const char *hotwords_buf; | 86 | const char *hotwords_buf; |
| 87 | size_t hotwords_buf_size = ReadFile(hotwords_filename, &hotwords_buf); | 87 | size_t hotwords_buf_size = ReadFile(hotwords_filename, &hotwords_buf); |
| 88 | if (hotwords_buf_size < 1) { | 88 | if (hotwords_buf_size < 1) { |
| 89 | fprintf(stderr, "Please check your hotwords.txt!\n"); | 89 | fprintf(stderr, "Please check your hotwords.txt!\n"); |
| 90 | - free(hotwords_buf); | 90 | + free((void *)hotwords_buf); |
| 91 | return -1; | 91 | return -1; |
| 92 | } | 92 | } |
| 93 | 93 | ||
| @@ -119,9 +119,9 @@ int32_t main() { | @@ -119,9 +119,9 @@ int32_t main() { | ||
| 119 | SherpaOnnxOnlineRecognizer *recognizer = | 119 | SherpaOnnxOnlineRecognizer *recognizer = |
| 120 | SherpaOnnxCreateOnlineRecognizer(&recognizer_config); | 120 | SherpaOnnxCreateOnlineRecognizer(&recognizer_config); |
| 121 | 121 | ||
| 122 | - free(tokens_buf); | 122 | + free((void *)tokens_buf); |
| 123 | tokens_buf = NULL; | 123 | tokens_buf = NULL; |
| 124 | - free(hotwords_buf); | 124 | + free((void *)hotwords_buf); |
| 125 | hotwords_buf = NULL; | 125 | hotwords_buf = NULL; |
| 126 | 126 | ||
| 127 | if (recognizer == NULL) { | 127 | if (recognizer == NULL) { |
| @@ -199,4 +199,4 @@ int32_t main() { | @@ -199,4 +199,4 @@ int32_t main() { | ||
| 199 | fprintf(stderr, "\n"); | 199 | fprintf(stderr, "\n"); |
| 200 | 200 | ||
| 201 | return 0; | 201 | return 0; |
| 202 | -} | ||
| 202 | +} |
| @@ -234,6 +234,11 @@ final class SherpaOnnxOnlineModelConfig extends Struct { | @@ -234,6 +234,11 @@ final class SherpaOnnxOnlineModelConfig extends Struct { | ||
| 234 | external Pointer<Utf8> modelingUnit; | 234 | external Pointer<Utf8> modelingUnit; |
| 235 | 235 | ||
| 236 | external Pointer<Utf8> bpeVocab; | 236 | external Pointer<Utf8> bpeVocab; |
| 237 | + | ||
| 238 | + external Pointer<Utf8> tokensBuf; | ||
| 239 | + | ||
| 240 | + @Int32() | ||
| 241 | + external int tokensBufSize; | ||
| 237 | } | 242 | } |
| 238 | 243 | ||
| 239 | final class SherpaOnnxOnlineCtcFstDecoderConfig extends Struct { | 244 | final class SherpaOnnxOnlineCtcFstDecoderConfig extends Struct { |
| @@ -275,6 +280,11 @@ final class SherpaOnnxOnlineRecognizerConfig extends Struct { | @@ -275,6 +280,11 @@ final class SherpaOnnxOnlineRecognizerConfig extends Struct { | ||
| 275 | 280 | ||
| 276 | @Float() | 281 | @Float() |
| 277 | external double blankPenalty; | 282 | external double blankPenalty; |
| 283 | + | ||
| 284 | + external Pointer<Utf8> hotwordsBuf; | ||
| 285 | + | ||
| 286 | + @Int32() | ||
| 287 | + external int hotwordsBufSize; | ||
| 278 | } | 288 | } |
| 279 | 289 | ||
| 280 | final class SherpaOnnxSileroVadModelConfig extends Struct { | 290 | final class SherpaOnnxSileroVadModelConfig extends Struct { |
| @@ -22,6 +22,8 @@ namespace SherpaOnnx | @@ -22,6 +22,8 @@ namespace SherpaOnnx | ||
| 22 | ModelType = ""; | 22 | ModelType = ""; |
| 23 | ModelingUnit = "cjkchar"; | 23 | ModelingUnit = "cjkchar"; |
| 24 | BpeVocab = ""; | 24 | BpeVocab = ""; |
| 25 | + TokensBuf = ""; | ||
| 26 | + TokensBufSize = 0; | ||
| 25 | } | 27 | } |
| 26 | 28 | ||
| 27 | public OnlineTransducerModelConfig Transducer; | 29 | public OnlineTransducerModelConfig Transducer; |
| @@ -48,6 +50,11 @@ namespace SherpaOnnx | @@ -48,6 +50,11 @@ namespace SherpaOnnx | ||
| 48 | 50 | ||
| 49 | [MarshalAs(UnmanagedType.LPStr)] | 51 | [MarshalAs(UnmanagedType.LPStr)] |
| 50 | public string BpeVocab; | 52 | public string BpeVocab; |
| 53 | + | ||
| 54 | + [MarshalAs(UnmanagedType.LPStr)] | ||
| 55 | + public string TokensBuf; | ||
| 56 | + | ||
| 57 | + public int TokensBufSize; | ||
| 51 | } | 58 | } |
| 52 | 59 | ||
| 53 | -} | ||
| 60 | +} |
| @@ -26,6 +26,8 @@ namespace SherpaOnnx | @@ -26,6 +26,8 @@ namespace SherpaOnnx | ||
| 26 | RuleFsts = ""; | 26 | RuleFsts = ""; |
| 27 | RuleFars = ""; | 27 | RuleFars = ""; |
| 28 | BlankPenalty = 0.0F; | 28 | BlankPenalty = 0.0F; |
| 29 | + HotwordsBuf = ""; | ||
| 30 | + HotwordsBufSize = 0; | ||
| 29 | } | 31 | } |
| 30 | public FeatureConfig FeatConfig; | 32 | public FeatureConfig FeatConfig; |
| 31 | public OnlineModelConfig ModelConfig; | 33 | public OnlineModelConfig ModelConfig; |
| @@ -72,5 +74,10 @@ namespace SherpaOnnx | @@ -72,5 +74,10 @@ namespace SherpaOnnx | ||
| 72 | public string RuleFars; | 74 | public string RuleFars; |
| 73 | 75 | ||
| 74 | public float BlankPenalty; | 76 | public float BlankPenalty; |
| 77 | + | ||
| 78 | + [MarshalAs(UnmanagedType.LPStr)] | ||
| 79 | + public string HotwordsBuf; | ||
| 80 | + | ||
| 81 | + public int HotwordsBufSize; | ||
| 75 | } | 82 | } |
| 76 | } | 83 | } |
| @@ -89,6 +89,8 @@ type OnlineModelConfig struct { | @@ -89,6 +89,8 @@ type OnlineModelConfig struct { | ||
| 89 | ModelType string // Optional. You can specify it for faster model initialization | 89 | ModelType string // Optional. You can specify it for faster model initialization |
| 90 | ModelingUnit string // Optional. cjkchar, bpe, cjkchar+bpe | 90 | ModelingUnit string // Optional. cjkchar, bpe, cjkchar+bpe |
| 91 | BpeVocab string // Optional. | 91 | BpeVocab string // Optional. |
| 92 | + TokensBuf string // Optional. | ||
| 93 | + TokensBufSize int // Optional. | ||
| 92 | } | 94 | } |
| 93 | 95 | ||
| 94 | // Configuration for the feature extractor | 96 | // Configuration for the feature extractor |
| @@ -133,6 +135,8 @@ type OnlineRecognizerConfig struct { | @@ -133,6 +135,8 @@ type OnlineRecognizerConfig struct { | ||
| 133 | CtcFstDecoderConfig OnlineCtcFstDecoderConfig | 135 | CtcFstDecoderConfig OnlineCtcFstDecoderConfig |
| 134 | RuleFsts string | 136 | RuleFsts string |
| 135 | RuleFars string | 137 | RuleFars string |
| 138 | + HotwordsBuf string | ||
| 139 | + HotwordsBufSize int | ||
| 136 | } | 140 | } |
| 137 | 141 | ||
| 138 | // It contains the recognition result for a online stream. | 142 | // It contains the recognition result for a online stream. |
| @@ -184,6 +188,11 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { | @@ -184,6 +188,11 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { | ||
| 184 | c.model_config.tokens = C.CString(config.ModelConfig.Tokens) | 188 | c.model_config.tokens = C.CString(config.ModelConfig.Tokens) |
| 185 | defer C.free(unsafe.Pointer(c.model_config.tokens)) | 189 | defer C.free(unsafe.Pointer(c.model_config.tokens)) |
| 186 | 190 | ||
| 191 | + c.model_config.tokens_buf = C.CString(config.ModelConfig.TokensBuf) | ||
| 192 | + defer C.free(unsafe.Pointer(c.model_config.tokens_buf)) | ||
| 193 | + | ||
| 194 | + c.model_config.tokens_buf_size = C.int(config.ModelConfig.TokensBufSize) | ||
| 195 | + | ||
| 187 | c.model_config.num_threads = C.int(config.ModelConfig.NumThreads) | 196 | c.model_config.num_threads = C.int(config.ModelConfig.NumThreads) |
| 188 | 197 | ||
| 189 | c.model_config.provider = C.CString(config.ModelConfig.Provider) | 198 | c.model_config.provider = C.CString(config.ModelConfig.Provider) |
| @@ -212,6 +221,11 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { | @@ -212,6 +221,11 @@ func NewOnlineRecognizer(config *OnlineRecognizerConfig) *OnlineRecognizer { | ||
| 212 | c.hotwords_file = C.CString(config.HotwordsFile) | 221 | c.hotwords_file = C.CString(config.HotwordsFile) |
| 213 | defer C.free(unsafe.Pointer(c.hotwords_file)) | 222 | defer C.free(unsafe.Pointer(c.hotwords_file)) |
| 214 | 223 | ||
| 224 | + c.hotwords_buf = C.CString(config.HotwordsBuf) | ||
| 225 | + defer C.free(unsafe.Pointer(c.hotwords_buf)) | ||
| 226 | + | ||
| 227 | + c.hotwords_buf_size = C.int(config.HotwordsBufSize) | ||
| 228 | + | ||
| 215 | c.hotwords_score = C.float(config.HotwordsScore) | 229 | c.hotwords_score = C.float(config.HotwordsScore) |
| 216 | c.blank_penalty = C.float(config.BlankPenalty) | 230 | c.blank_penalty = C.float(config.BlankPenalty) |
| 217 | 231 |
| @@ -120,6 +120,8 @@ SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj) { | @@ -120,6 +120,8 @@ SherpaOnnxOnlineModelConfig GetOnlineModelConfig(Napi::Object obj) { | ||
| 120 | SHERPA_ONNX_ASSIGN_ATTR_STR(model_type, modelType); | 120 | SHERPA_ONNX_ASSIGN_ATTR_STR(model_type, modelType); |
| 121 | SHERPA_ONNX_ASSIGN_ATTR_STR(modeling_unit, modelingUnit); | 121 | SHERPA_ONNX_ASSIGN_ATTR_STR(modeling_unit, modelingUnit); |
| 122 | SHERPA_ONNX_ASSIGN_ATTR_STR(bpe_vocab, bpeVocab); | 122 | SHERPA_ONNX_ASSIGN_ATTR_STR(bpe_vocab, bpeVocab); |
| 123 | + SHERPA_ONNX_ASSIGN_ATTR_STR(tokens_buf, tokensBuf); | ||
| 124 | + SHERPA_ONNX_ASSIGN_ATTR_INT32(tokens_buf_size, tokensBufSize); | ||
| 123 | 125 | ||
| 124 | return c; | 126 | return c; |
| 125 | } | 127 | } |
| @@ -192,6 +194,8 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper( | @@ -192,6 +194,8 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper( | ||
| 192 | SHERPA_ONNX_ASSIGN_ATTR_STR(rule_fsts, ruleFsts); | 194 | SHERPA_ONNX_ASSIGN_ATTR_STR(rule_fsts, ruleFsts); |
| 193 | SHERPA_ONNX_ASSIGN_ATTR_STR(rule_fars, ruleFars); | 195 | SHERPA_ONNX_ASSIGN_ATTR_STR(rule_fars, ruleFars); |
| 194 | SHERPA_ONNX_ASSIGN_ATTR_FLOAT(blank_penalty, blankPenalty); | 196 | SHERPA_ONNX_ASSIGN_ATTR_FLOAT(blank_penalty, blankPenalty); |
| 197 | + SHERPA_ONNX_ASSIGN_ATTR_STR(hotwords_buf, hotwordsBuf); | ||
| 198 | + SHERPA_ONNX_ASSIGN_ATTR_INT32(hotwords_buf_size, hotwordsBufSize); | ||
| 195 | 199 | ||
| 196 | c.ctc_fst_decoder_config = GetCtcFstDecoderConfig(o); | 200 | c.ctc_fst_decoder_config = GetCtcFstDecoderConfig(o); |
| 197 | 201 | ||
| @@ -241,6 +245,10 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper( | @@ -241,6 +245,10 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper( | ||
| 241 | delete[] c.model_config.bpe_vocab; | 245 | delete[] c.model_config.bpe_vocab; |
| 242 | } | 246 | } |
| 243 | 247 | ||
| 248 | + if (c.model_config.tokens_buf) { | ||
| 249 | + delete[] c.model_config.tokens_buf; | ||
| 250 | + } | ||
| 251 | + | ||
| 244 | if (c.decoding_method) { | 252 | if (c.decoding_method) { |
| 245 | delete[] c.decoding_method; | 253 | delete[] c.decoding_method; |
| 246 | } | 254 | } |
| @@ -257,6 +265,10 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper( | @@ -257,6 +265,10 @@ static Napi::External<SherpaOnnxOnlineRecognizer> CreateOnlineRecognizerWrapper( | ||
| 257 | delete[] c.rule_fars; | 265 | delete[] c.rule_fars; |
| 258 | } | 266 | } |
| 259 | 267 | ||
| 268 | + if (c.hotwords_buf) { | ||
| 269 | + delete[] c.hotwords_buf; | ||
| 270 | + } | ||
| 271 | + | ||
| 260 | if (c.ctc_fst_decoder_config.graph) { | 272 | if (c.ctc_fst_decoder_config.graph) { |
| 261 | delete[] c.ctc_fst_decoder_config.graph; | 273 | delete[] c.ctc_fst_decoder_config.graph; |
| 262 | } | 274 | } |
| @@ -91,7 +91,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { | @@ -91,7 +91,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { | ||
| 91 | /// if non-null, loading the tokens from the buffered string directly in | 91 | /// if non-null, loading the tokens from the buffered string directly in |
| 92 | /// prioriy | 92 | /// prioriy |
| 93 | const char *tokens_buf; | 93 | const char *tokens_buf; |
| 94 | - /// byte size excluding the tailing '\0' | 94 | + /// byte size excluding the trailing '\0' |
| 95 | int32_t tokens_buf_size; | 95 | int32_t tokens_buf_size; |
| 96 | } SherpaOnnxOnlineModelConfig; | 96 | } SherpaOnnxOnlineModelConfig; |
| 97 | 97 |
| @@ -4,6 +4,8 @@ | @@ -4,6 +4,8 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/offline-stream.h" | 5 | #include "sherpa-onnx/csrc/offline-stream.h" |
| 6 | 6 | ||
| 7 | +#include <math.h> | ||
| 8 | + | ||
| 7 | #include <algorithm> | 9 | #include <algorithm> |
| 8 | #include <cassert> | 10 | #include <cassert> |
| 9 | #include <cmath> | 11 | #include <cmath> |
| @@ -245,7 +247,7 @@ class OfflineStream::Impl { | @@ -245,7 +247,7 @@ class OfflineStream::Impl { | ||
| 245 | for (int32_t i = 0; i != n; ++i) { | 247 | for (int32_t i = 0; i != n; ++i) { |
| 246 | float x = p[i]; | 248 | float x = p[i]; |
| 247 | x = (x > amin) ? x : amin; | 249 | x = (x > amin) ? x : amin; |
| 248 | - x = std::log10f(x) * multiplier; | 250 | + x = log10f(x) * multiplier; |
| 249 | 251 | ||
| 250 | max_x = (x > max_x) ? x : max_x; | 252 | max_x = (x > max_x) ? x : max_x; |
| 251 | p[i] = x; | 253 | p[i] = x; |
| @@ -372,7 +372,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -372,7 +372,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 372 | // segment is incremented only when the last | 372 | // segment is incremented only when the last |
| 373 | // result is not empty, contains non-blanks and longer than context_size) | 373 | // result is not empty, contains non-blanks and longer than context_size) |
| 374 | const auto &r = s->GetResult(); | 374 | const auto &r = s->GetResult(); |
| 375 | - if (!r.tokens.empty() && r.tokens.back() != 0 && r.tokens.size() > context_size) { | 375 | + if (!r.tokens.empty() && r.tokens.back() != 0 && |
| 376 | + r.tokens.size() > context_size) { | ||
| 376 | s->GetCurrentSegment() += 1; | 377 | s->GetCurrentSegment() += 1; |
| 377 | } | 378 | } |
| 378 | } | 379 | } |
| @@ -392,7 +393,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -392,7 +393,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 392 | // if last result is not empty, then | 393 | // if last result is not empty, then |
| 393 | // preserve last tokens as the context for next result | 394 | // preserve last tokens as the context for next result |
| 394 | if (static_cast<int32_t>(last_result.tokens.size()) > context_size) { | 395 | if (static_cast<int32_t>(last_result.tokens.size()) > context_size) { |
| 395 | - std::vector<int64_t> context(last_result.tokens.end() - context_size, last_result.tokens.end()); | 396 | + std::vector<int64_t> context(last_result.tokens.end() - context_size, |
| 397 | + last_result.tokens.end()); | ||
| 396 | 398 | ||
| 397 | Hypotheses context_hyp({{context, 0}}); | 399 | Hypotheses context_hyp({{context, 0}}); |
| 398 | r.hyps = std::move(context_hyp); | 400 | r.hyps = std::move(context_hyp); |
| @@ -145,6 +145,8 @@ type | @@ -145,6 +145,8 @@ type | ||
| 145 | ModelType: AnsiString; | 145 | ModelType: AnsiString; |
| 146 | ModelingUnit: AnsiString; | 146 | ModelingUnit: AnsiString; |
| 147 | BpeVocab: AnsiString; | 147 | BpeVocab: AnsiString; |
| 148 | + TokensBuf: AnsiString; | ||
| 149 | + TokensBufSize: Integer; | ||
| 148 | function ToString: AnsiString; | 150 | function ToString: AnsiString; |
| 149 | class operator Initialize({$IFDEF FPC}var{$ELSE}out{$ENDIF} Dest: TSherpaOnnxOnlineModelConfig); | 151 | class operator Initialize({$IFDEF FPC}var{$ELSE}out{$ENDIF} Dest: TSherpaOnnxOnlineModelConfig); |
| 150 | end; | 152 | end; |
| @@ -178,6 +180,8 @@ type | @@ -178,6 +180,8 @@ type | ||
| 178 | RuleFsts: AnsiString; | 180 | RuleFsts: AnsiString; |
| 179 | RuleFars: AnsiString; | 181 | RuleFars: AnsiString; |
| 180 | BlankPenalty: Single; | 182 | BlankPenalty: Single; |
| 183 | + HotwordsBuf: AnsiString; | ||
| 184 | + HotwordsBufSize: Integer; | ||
| 181 | function ToString: AnsiString; | 185 | function ToString: AnsiString; |
| 182 | class operator Initialize({$IFDEF FPC}var{$ELSE}out{$ENDIF} Dest: TSherpaOnnxOnlineRecognizerConfig); | 186 | class operator Initialize({$IFDEF FPC}var{$ELSE}out{$ENDIF} Dest: TSherpaOnnxOnlineRecognizerConfig); |
| 183 | end; | 187 | end; |
| @@ -490,6 +494,8 @@ type | @@ -490,6 +494,8 @@ type | ||
| 490 | ModelType: PAnsiChar; | 494 | ModelType: PAnsiChar; |
| 491 | ModelingUnit: PAnsiChar; | 495 | ModelingUnit: PAnsiChar; |
| 492 | BpeVocab: PAnsiChar; | 496 | BpeVocab: PAnsiChar; |
| 497 | + TokensBuf: PAnsiChar; | ||
| 498 | + TokensBufSize: cint32; | ||
| 493 | end; | 499 | end; |
| 494 | SherpaOnnxFeatureConfig = record | 500 | SherpaOnnxFeatureConfig = record |
| 495 | SampleRate: cint32; | 501 | SampleRate: cint32; |
| @@ -514,6 +520,8 @@ type | @@ -514,6 +520,8 @@ type | ||
| 514 | RuleFsts: PAnsiChar; | 520 | RuleFsts: PAnsiChar; |
| 515 | RuleFars: PAnsiChar; | 521 | RuleFars: PAnsiChar; |
| 516 | BlankPenalty: cfloat; | 522 | BlankPenalty: cfloat; |
| 523 | + HotwordsBuf: PAnsiChar; | ||
| 524 | + HotwordsBufSize: cint32; | ||
| 517 | end; | 525 | end; |
| 518 | 526 | ||
| 519 | PSherpaOnnxOnlineRecognizerConfig = ^SherpaOnnxOnlineRecognizerConfig; | 527 | PSherpaOnnxOnlineRecognizerConfig = ^SherpaOnnxOnlineRecognizerConfig; |
| @@ -4,6 +4,8 @@ | @@ -4,6 +4,8 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/python/csrc/online-punctuation.h" | 5 | #include "sherpa-onnx/python/csrc/online-punctuation.h" |
| 6 | 6 | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 7 | #include "sherpa-onnx/csrc/online-punctuation.h" | 9 | #include "sherpa-onnx/csrc/online-punctuation.h" |
| 8 | 10 | ||
| 9 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| @@ -12,9 +14,11 @@ static void PybindOnlinePunctuationModelConfig(py::module *m) { | @@ -12,9 +14,11 @@ static void PybindOnlinePunctuationModelConfig(py::module *m) { | ||
| 12 | using PyClass = OnlinePunctuationModelConfig; | 14 | using PyClass = OnlinePunctuationModelConfig; |
| 13 | py::class_<PyClass>(*m, "OnlinePunctuationModelConfig") | 15 | py::class_<PyClass>(*m, "OnlinePunctuationModelConfig") |
| 14 | .def(py::init<>()) | 16 | .def(py::init<>()) |
| 15 | - .def(py::init<const std::string &, const std::string &, int32_t, bool, const std::string &>(), | ||
| 16 | - py::arg("cnn_bilstm"), py::arg("bpe_vocab"), py::arg("num_threads") = 1, | ||
| 17 | - py::arg("debug") = false, py::arg("provider") = "cpu") | 17 | + .def(py::init<const std::string &, const std::string &, int32_t, bool, |
| 18 | + const std::string &>(), | ||
| 19 | + py::arg("cnn_bilstm"), py::arg("bpe_vocab"), | ||
| 20 | + py::arg("num_threads") = 1, py::arg("debug") = false, | ||
| 21 | + py::arg("provider") = "cpu") | ||
| 18 | .def_readwrite("cnn_bilstm", &PyClass::cnn_bilstm) | 22 | .def_readwrite("cnn_bilstm", &PyClass::cnn_bilstm) |
| 19 | .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) | 23 | .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) |
| 20 | .def_readwrite("num_threads", &PyClass::num_threads) | 24 | .def_readwrite("num_threads", &PyClass::num_threads) |
| @@ -30,7 +34,8 @@ static void PybindOnlinePunctuationConfig(py::module *m) { | @@ -30,7 +34,8 @@ static void PybindOnlinePunctuationConfig(py::module *m) { | ||
| 30 | 34 | ||
| 31 | py::class_<PyClass>(*m, "OnlinePunctuationConfig") | 35 | py::class_<PyClass>(*m, "OnlinePunctuationConfig") |
| 32 | .def(py::init<>()) | 36 | .def(py::init<>()) |
| 33 | - .def(py::init<const OnlinePunctuationModelConfig &>(), py::arg("model_config")) | 37 | + .def(py::init<const OnlinePunctuationModelConfig &>(), |
| 38 | + py::arg("model_config")) | ||
| 34 | .def_readwrite("model_config", &PyClass::model) | 39 | .def_readwrite("model_config", &PyClass::model) |
| 35 | .def("validate", &PyClass::Validate) | 40 | .def("validate", &PyClass::Validate) |
| 36 | .def("__str__", &PyClass::ToString); | 41 | .def("__str__", &PyClass::ToString); |
| @@ -43,8 +48,8 @@ void PybindOnlinePunctuation(py::module *m) { | @@ -43,8 +48,8 @@ void PybindOnlinePunctuation(py::module *m) { | ||
| 43 | py::class_<PyClass>(*m, "OnlinePunctuation") | 48 | py::class_<PyClass>(*m, "OnlinePunctuation") |
| 44 | .def(py::init<const OnlinePunctuationConfig &>(), py::arg("config"), | 49 | .def(py::init<const OnlinePunctuationConfig &>(), py::arg("config"), |
| 45 | py::call_guard<py::gil_scoped_release>()) | 50 | py::call_guard<py::gil_scoped_release>()) |
| 46 | - .def("add_punctuation_with_case", &PyClass::AddPunctuationWithCase, py::arg("text"), | ||
| 47 | - py::call_guard<py::gil_scoped_release>()); | 51 | + .def("add_punctuation_with_case", &PyClass::AddPunctuationWithCase, |
| 52 | + py::arg("text"), py::call_guard<py::gil_scoped_release>()); | ||
| 48 | } | 53 | } |
| 49 | 54 | ||
| 50 | } // namespace sherpa_onnx | 55 | } // namespace sherpa_onnx |
| @@ -90,7 +90,9 @@ func sherpaOnnxOnlineModelConfig( | @@ -90,7 +90,9 @@ func sherpaOnnxOnlineModelConfig( | ||
| 90 | debug: Int = 0, | 90 | debug: Int = 0, |
| 91 | modelType: String = "", | 91 | modelType: String = "", |
| 92 | modelingUnit: String = "cjkchar", | 92 | modelingUnit: String = "cjkchar", |
| 93 | - bpeVocab: String = "" | 93 | + bpeVocab: String = "", |
| 94 | + tokensBuf: String = "", | ||
| 95 | + tokensBufSize: Int = 0 | ||
| 94 | ) -> SherpaOnnxOnlineModelConfig { | 96 | ) -> SherpaOnnxOnlineModelConfig { |
| 95 | return SherpaOnnxOnlineModelConfig( | 97 | return SherpaOnnxOnlineModelConfig( |
| 96 | transducer: transducer, | 98 | transducer: transducer, |
| @@ -102,7 +104,9 @@ func sherpaOnnxOnlineModelConfig( | @@ -102,7 +104,9 @@ func sherpaOnnxOnlineModelConfig( | ||
| 102 | debug: Int32(debug), | 104 | debug: Int32(debug), |
| 103 | model_type: toCPointer(modelType), | 105 | model_type: toCPointer(modelType), |
| 104 | modeling_unit: toCPointer(modelingUnit), | 106 | modeling_unit: toCPointer(modelingUnit), |
| 105 | - bpe_vocab: toCPointer(bpeVocab) | 107 | + bpe_vocab: toCPointer(bpeVocab), |
| 108 | + tokens_buf: toCPointer(tokensBuf), | ||
| 109 | + tokens_buf_size: Int32(tokensBufSize) | ||
| 106 | ) | 110 | ) |
| 107 | } | 111 | } |
| 108 | 112 | ||
| @@ -138,7 +142,9 @@ func sherpaOnnxOnlineRecognizerConfig( | @@ -138,7 +142,9 @@ func sherpaOnnxOnlineRecognizerConfig( | ||
| 138 | ctcFstDecoderConfig: SherpaOnnxOnlineCtcFstDecoderConfig = sherpaOnnxOnlineCtcFstDecoderConfig(), | 142 | ctcFstDecoderConfig: SherpaOnnxOnlineCtcFstDecoderConfig = sherpaOnnxOnlineCtcFstDecoderConfig(), |
| 139 | ruleFsts: String = "", | 143 | ruleFsts: String = "", |
| 140 | ruleFars: String = "", | 144 | ruleFars: String = "", |
| 141 | - blankPenalty: Float = 0.0 | 145 | + blankPenalty: Float = 0.0, |
| 146 | + hotwordsBuf: String = "", | ||
| 147 | + hotwordsBufSize: Int = 0 | ||
| 142 | ) -> SherpaOnnxOnlineRecognizerConfig { | 148 | ) -> SherpaOnnxOnlineRecognizerConfig { |
| 143 | return SherpaOnnxOnlineRecognizerConfig( | 149 | return SherpaOnnxOnlineRecognizerConfig( |
| 144 | feat_config: featConfig, | 150 | feat_config: featConfig, |
| @@ -154,7 +160,9 @@ func sherpaOnnxOnlineRecognizerConfig( | @@ -154,7 +160,9 @@ func sherpaOnnxOnlineRecognizerConfig( | ||
| 154 | ctc_fst_decoder_config: ctcFstDecoderConfig, | 160 | ctc_fst_decoder_config: ctcFstDecoderConfig, |
| 155 | rule_fsts: toCPointer(ruleFsts), | 161 | rule_fsts: toCPointer(ruleFsts), |
| 156 | rule_fars: toCPointer(ruleFars), | 162 | rule_fars: toCPointer(ruleFars), |
| 157 | - blank_penalty: blankPenalty | 163 | + blank_penalty: blankPenalty, |
| 164 | + hotwords_buf: toCPointer(hotwordsBuf), | ||
| 165 | + hotwords_buf_size: Int32(hotwordsBufSize) | ||
| 158 | ) | 166 | ) |
| 159 | } | 167 | } |
| 160 | 168 |
| @@ -155,6 +155,14 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | @@ -155,6 +155,14 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | ||
| 155 | }; | 155 | }; |
| 156 | } | 156 | } |
| 157 | 157 | ||
| 158 | + if (!('tokensBuf' in config)) { | ||
| 159 | + config.tokensBuf = ''; | ||
| 160 | + } | ||
| 161 | + | ||
| 162 | + if (!('tokensBufSize' in config)) { | ||
| 163 | + config.tokensBufSize = 0; | ||
| 164 | + } | ||
| 165 | + | ||
| 158 | const transducer = | 166 | const transducer = |
| 159 | initSherpaOnnxOnlineTransducerModelConfig(config.transducer, Module); | 167 | initSherpaOnnxOnlineTransducerModelConfig(config.transducer, Module); |
| 160 | 168 | ||
| @@ -164,7 +172,7 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | @@ -164,7 +172,7 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | ||
| 164 | const ctc = initSherpaOnnxOnlineZipformer2CtcModelConfig( | 172 | const ctc = initSherpaOnnxOnlineZipformer2CtcModelConfig( |
| 165 | config.zipformer2Ctc, Module); | 173 | config.zipformer2Ctc, Module); |
| 166 | 174 | ||
| 167 | - const len = transducer.len + paraformer.len + ctc.len + 7 * 4; | 175 | + const len = transducer.len + paraformer.len + ctc.len + 9 * 4; |
| 168 | const ptr = Module._malloc(len); | 176 | const ptr = Module._malloc(len); |
| 169 | 177 | ||
| 170 | let offset = 0; | 178 | let offset = 0; |
| @@ -182,9 +190,10 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | @@ -182,9 +190,10 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | ||
| 182 | const modelTypeLen = Module.lengthBytesUTF8(config.modelType || '') + 1; | 190 | const modelTypeLen = Module.lengthBytesUTF8(config.modelType || '') + 1; |
| 183 | const modelingUnitLen = Module.lengthBytesUTF8(config.modelingUnit || '') + 1; | 191 | const modelingUnitLen = Module.lengthBytesUTF8(config.modelingUnit || '') + 1; |
| 184 | const bpeVocabLen = Module.lengthBytesUTF8(config.bpeVocab || '') + 1; | 192 | const bpeVocabLen = Module.lengthBytesUTF8(config.bpeVocab || '') + 1; |
| 193 | + const tokensBufLen = Module.lengthBytesUTF8(config.tokensBuf || '') + 1; | ||
| 185 | 194 | ||
| 186 | - const bufferLen = | ||
| 187 | - tokensLen + providerLen + modelTypeLen + modelingUnitLen + bpeVocabLen; | 195 | + const bufferLen = tokensLen + providerLen + modelTypeLen + modelingUnitLen + |
| 196 | + bpeVocabLen + tokensBufLen; | ||
| 188 | const buffer = Module._malloc(bufferLen); | 197 | const buffer = Module._malloc(bufferLen); |
| 189 | 198 | ||
| 190 | offset = 0; | 199 | offset = 0; |
| @@ -204,6 +213,9 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | @@ -204,6 +213,9 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | ||
| 204 | Module.stringToUTF8(config.bpeVocab || '', buffer + offset, bpeVocabLen); | 213 | Module.stringToUTF8(config.bpeVocab || '', buffer + offset, bpeVocabLen); |
| 205 | offset += bpeVocabLen; | 214 | offset += bpeVocabLen; |
| 206 | 215 | ||
| 216 | + Module.stringToUTF8(config.tokensBuf || '', buffer + offset, tokensBufLen); | ||
| 217 | + offset += tokensBufLen; | ||
| 218 | + | ||
| 207 | offset = transducer.len + paraformer.len + ctc.len; | 219 | offset = transducer.len + paraformer.len + ctc.len; |
| 208 | Module.setValue(ptr + offset, buffer, 'i8*'); // tokens | 220 | Module.setValue(ptr + offset, buffer, 'i8*'); // tokens |
| 209 | offset += 4; | 221 | offset += 4; |
| @@ -232,6 +244,16 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | @@ -232,6 +244,16 @@ function initSherpaOnnxOnlineModelConfig(config, Module) { | ||
| 232 | 'i8*'); // bpeVocab | 244 | 'i8*'); // bpeVocab |
| 233 | offset += 4; | 245 | offset += 4; |
| 234 | 246 | ||
| 247 | + Module.setValue( | ||
| 248 | + ptr + offset, | ||
| 249 | + buffer + tokensLen + providerLen + modelTypeLen + modelingUnitLen + | ||
| 250 | + bpeVocabLen, | ||
| 251 | + 'i8*'); // tokens_buf | ||
| 252 | + offset += 4; | ||
| 253 | + | ||
| 254 | + Module.setValue(ptr + offset, config.tokensBufSize || 0, 'i32'); | ||
| 255 | + offset += 4; | ||
| 256 | + | ||
| 235 | return { | 257 | return { |
| 236 | buffer: buffer, ptr: ptr, len: len, transducer: transducer, | 258 | buffer: buffer, ptr: ptr, len: len, transducer: transducer, |
| 237 | paraformer: paraformer, ctc: ctc | 259 | paraformer: paraformer, ctc: ctc |
| @@ -275,12 +297,20 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { | @@ -275,12 +297,20 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { | ||
| 275 | }; | 297 | }; |
| 276 | } | 298 | } |
| 277 | 299 | ||
| 300 | + if (!('hotwordsBuf' in config)) { | ||
| 301 | + config.hotwordsBuf = ''; | ||
| 302 | + } | ||
| 303 | + | ||
| 304 | + if (!('hotwordsBufSize' in config)) { | ||
| 305 | + config.hotwordsBufSize = 0; | ||
| 306 | + } | ||
| 307 | + | ||
| 278 | const feat = initSherpaOnnxFeatureConfig(config.featConfig, Module); | 308 | const feat = initSherpaOnnxFeatureConfig(config.featConfig, Module); |
| 279 | const model = initSherpaOnnxOnlineModelConfig(config.modelConfig, Module); | 309 | const model = initSherpaOnnxOnlineModelConfig(config.modelConfig, Module); |
| 280 | const ctcFstDecoder = initSherpaOnnxOnlineCtcFstDecoderConfig( | 310 | const ctcFstDecoder = initSherpaOnnxOnlineCtcFstDecoderConfig( |
| 281 | config.ctcFstDecoderConfig, Module) | 311 | config.ctcFstDecoderConfig, Module) |
| 282 | 312 | ||
| 283 | - const len = feat.len + model.len + 8 * 4 + ctcFstDecoder.len + 3 * 4; | 313 | + const len = feat.len + model.len + 8 * 4 + ctcFstDecoder.len + 5 * 4; |
| 284 | const ptr = Module._malloc(len); | 314 | const ptr = Module._malloc(len); |
| 285 | 315 | ||
| 286 | let offset = 0; | 316 | let offset = 0; |
| @@ -295,8 +325,9 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { | @@ -295,8 +325,9 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { | ||
| 295 | const hotwordsFileLen = Module.lengthBytesUTF8(config.hotwordsFile || '') + 1; | 325 | const hotwordsFileLen = Module.lengthBytesUTF8(config.hotwordsFile || '') + 1; |
| 296 | const ruleFstsFileLen = Module.lengthBytesUTF8(config.ruleFsts || '') + 1; | 326 | const ruleFstsFileLen = Module.lengthBytesUTF8(config.ruleFsts || '') + 1; |
| 297 | const ruleFarsFileLen = Module.lengthBytesUTF8(config.ruleFars || '') + 1; | 327 | const ruleFarsFileLen = Module.lengthBytesUTF8(config.ruleFars || '') + 1; |
| 298 | - const bufferLen = | ||
| 299 | - decodingMethodLen + hotwordsFileLen + ruleFstsFileLen + ruleFarsFileLen; | 328 | + const hotwordsBufLen = Module.lengthBytesUTF8(config.hotwordsBuf || '') + 1; |
| 329 | + const bufferLen = decodingMethodLen + hotwordsFileLen + ruleFstsFileLen + | ||
| 330 | + ruleFarsFileLen + hotwordsBufLen; | ||
| 300 | const buffer = Module._malloc(bufferLen); | 331 | const buffer = Module._malloc(bufferLen); |
| 301 | 332 | ||
| 302 | offset = 0; | 333 | offset = 0; |
| @@ -314,6 +345,10 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { | @@ -314,6 +345,10 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { | ||
| 314 | Module.stringToUTF8(config.ruleFars || '', buffer + offset, ruleFarsFileLen); | 345 | Module.stringToUTF8(config.ruleFars || '', buffer + offset, ruleFarsFileLen); |
| 315 | offset += ruleFarsFileLen; | 346 | offset += ruleFarsFileLen; |
| 316 | 347 | ||
| 348 | + Module.stringToUTF8( | ||
| 349 | + config.hotwordsBuf || '', buffer + offset, hotwordsBufLen); | ||
| 350 | + offset += hotwordsBufLen; | ||
| 351 | + | ||
| 317 | offset = feat.len + model.len; | 352 | offset = feat.len + model.len; |
| 318 | Module.setValue(ptr + offset, buffer, 'i8*'); // decoding method | 353 | Module.setValue(ptr + offset, buffer, 'i8*'); // decoding method |
| 319 | offset += 4; | 354 | offset += 4; |
| @@ -354,6 +389,16 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { | @@ -354,6 +389,16 @@ function initSherpaOnnxOnlineRecognizerConfig(config, Module) { | ||
| 354 | Module.setValue(ptr + offset, config.blankPenalty || 0, 'float'); | 389 | Module.setValue(ptr + offset, config.blankPenalty || 0, 'float'); |
| 355 | offset += 4; | 390 | offset += 4; |
| 356 | 391 | ||
| 392 | + Module.setValue( | ||
| 393 | + ptr + offset, | ||
| 394 | + buffer + decodingMethodLen + hotwordsFileLen + ruleFstsFileLen + | ||
| 395 | + ruleFarsFileLen, | ||
| 396 | + 'i8*'); | ||
| 397 | + offset += 4; | ||
| 398 | + | ||
| 399 | + Module.setValue(ptr + offset, config.hotwordsBufSize || 0, 'i32'); | ||
| 400 | + offset += 4; | ||
| 401 | + | ||
| 357 | return { | 402 | return { |
| 358 | buffer: buffer, ptr: ptr, len: len, feat: feat, model: model, | 403 | buffer: buffer, ptr: ptr, len: len, feat: feat, model: model, |
| 359 | ctcFstDecoder: ctcFstDecoder | 404 | ctcFstDecoder: ctcFstDecoder |
| @@ -19,14 +19,14 @@ static_assert(sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) == 1 * 4, ""); | @@ -19,14 +19,14 @@ static_assert(sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) == 1 * 4, ""); | ||
| 19 | static_assert(sizeof(SherpaOnnxOnlineModelConfig) == | 19 | static_assert(sizeof(SherpaOnnxOnlineModelConfig) == |
| 20 | sizeof(SherpaOnnxOnlineTransducerModelConfig) + | 20 | sizeof(SherpaOnnxOnlineTransducerModelConfig) + |
| 21 | sizeof(SherpaOnnxOnlineParaformerModelConfig) + | 21 | sizeof(SherpaOnnxOnlineParaformerModelConfig) + |
| 22 | - sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 7 * 4, | 22 | + sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 9 * 4, |
| 23 | ""); | 23 | ""); |
| 24 | static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, ""); | 24 | static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, ""); |
| 25 | static_assert(sizeof(SherpaOnnxOnlineCtcFstDecoderConfig) == 2 * 4, ""); | 25 | static_assert(sizeof(SherpaOnnxOnlineCtcFstDecoderConfig) == 2 * 4, ""); |
| 26 | static_assert(sizeof(SherpaOnnxOnlineRecognizerConfig) == | 26 | static_assert(sizeof(SherpaOnnxOnlineRecognizerConfig) == |
| 27 | sizeof(SherpaOnnxFeatureConfig) + | 27 | sizeof(SherpaOnnxFeatureConfig) + |
| 28 | sizeof(SherpaOnnxOnlineModelConfig) + 8 * 4 + | 28 | sizeof(SherpaOnnxOnlineModelConfig) + 8 * 4 + |
| 29 | - sizeof(SherpaOnnxOnlineCtcFstDecoderConfig) + 3 * 4, | 29 | + sizeof(SherpaOnnxOnlineCtcFstDecoderConfig) + 5 * 4, |
| 30 | ""); | 30 | ""); |
| 31 | 31 | ||
| 32 | void MyPrint(SherpaOnnxOnlineRecognizerConfig *config) { | 32 | void MyPrint(SherpaOnnxOnlineRecognizerConfig *config) { |
| @@ -54,6 +54,9 @@ void MyPrint(SherpaOnnxOnlineRecognizerConfig *config) { | @@ -54,6 +54,9 @@ void MyPrint(SherpaOnnxOnlineRecognizerConfig *config) { | ||
| 54 | fprintf(stdout, "model type: %s\n", model_config->model_type); | 54 | fprintf(stdout, "model type: %s\n", model_config->model_type); |
| 55 | fprintf(stdout, "modeling unit: %s\n", model_config->modeling_unit); | 55 | fprintf(stdout, "modeling unit: %s\n", model_config->modeling_unit); |
| 56 | fprintf(stdout, "bpe vocab: %s\n", model_config->bpe_vocab); | 56 | fprintf(stdout, "bpe vocab: %s\n", model_config->bpe_vocab); |
| 57 | + fprintf(stdout, "tokens_buf: %s\n", | ||
| 58 | + model_config->tokens_buf ? model_config->tokens_buf : ""); | ||
| 59 | + fprintf(stdout, "tokens_buf_size: %d\n", model_config->tokens_buf_size); | ||
| 57 | 60 | ||
| 58 | fprintf(stdout, "----------feat config----------\n"); | 61 | fprintf(stdout, "----------feat config----------\n"); |
| 59 | fprintf(stdout, "sample rate: %d\n", feat->sample_rate); | 62 | fprintf(stdout, "sample rate: %d\n", feat->sample_rate); |
| @@ -62,12 +62,20 @@ function initSherpaOnnxOnlineTransducerModelConfig(config, Module) { | @@ -62,12 +62,20 @@ function initSherpaOnnxOnlineTransducerModelConfig(config, Module) { | ||
| 62 | 62 | ||
| 63 | // The user should free the returned pointers | 63 | // The user should free the returned pointers |
| 64 | function initModelConfig(config, Module) { | 64 | function initModelConfig(config, Module) { |
| 65 | + if (!('tokensBuf' in config)) { | ||
| 66 | + config.tokensBuf = ''; | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + if (!('tokensBufSize' in config)) { | ||
| 70 | + config.tokensBufSize = 0; | ||
| 71 | + } | ||
| 72 | + | ||
| 65 | const transducer = | 73 | const transducer = |
| 66 | initSherpaOnnxOnlineTransducerModelConfig(config.transducer, Module); | 74 | initSherpaOnnxOnlineTransducerModelConfig(config.transducer, Module); |
| 67 | const paraformer_len = 2 * 4 | 75 | const paraformer_len = 2 * 4 |
| 68 | const ctc_len = 1 * 4 | 76 | const ctc_len = 1 * 4 |
| 69 | 77 | ||
| 70 | - const len = transducer.len + paraformer_len + ctc_len + 7 * 4; | 78 | + const len = transducer.len + paraformer_len + ctc_len + 9 * 4; |
| 71 | const ptr = Module._malloc(len); | 79 | const ptr = Module._malloc(len); |
| 72 | Module.HEAPU8.fill(0, ptr, ptr + len); | 80 | Module.HEAPU8.fill(0, ptr, ptr + len); |
| 73 | 81 | ||
| @@ -79,8 +87,9 @@ function initModelConfig(config, Module) { | @@ -79,8 +87,9 @@ function initModelConfig(config, Module) { | ||
| 79 | const modelTypeLen = Module.lengthBytesUTF8(config.modelType || '') + 1; | 87 | const modelTypeLen = Module.lengthBytesUTF8(config.modelType || '') + 1; |
| 80 | const modelingUnitLen = Module.lengthBytesUTF8(config.modelingUnit || '') + 1; | 88 | const modelingUnitLen = Module.lengthBytesUTF8(config.modelingUnit || '') + 1; |
| 81 | const bpeVocabLen = Module.lengthBytesUTF8(config.bpeVocab || '') + 1; | 89 | const bpeVocabLen = Module.lengthBytesUTF8(config.bpeVocab || '') + 1; |
| 82 | - const bufferLen = | ||
| 83 | - tokensLen + providerLen + modelTypeLen + modelingUnitLen + bpeVocabLen; | 90 | + const tokensBufLen = Module.lengthBytesUTF8(config.tokensBuf || '') + 1; |
| 91 | + const bufferLen = tokensLen + providerLen + modelTypeLen + modelingUnitLen + | ||
| 92 | + bpeVocabLen + tokensBufLen; | ||
| 84 | const buffer = Module._malloc(bufferLen); | 93 | const buffer = Module._malloc(bufferLen); |
| 85 | 94 | ||
| 86 | offset = 0; | 95 | offset = 0; |
| @@ -100,6 +109,9 @@ function initModelConfig(config, Module) { | @@ -100,6 +109,9 @@ function initModelConfig(config, Module) { | ||
| 100 | Module.stringToUTF8(config.bpeVocab || '', buffer + offset, bpeVocabLen); | 109 | Module.stringToUTF8(config.bpeVocab || '', buffer + offset, bpeVocabLen); |
| 101 | offset += bpeVocabLen; | 110 | offset += bpeVocabLen; |
| 102 | 111 | ||
| 112 | + Module.stringToUTF8(config.tokensBuf || '', buffer + offset, tokensBufLen); | ||
| 113 | + offset += tokensBufLen; | ||
| 114 | + | ||
| 103 | offset = transducer.len + paraformer_len + ctc_len; | 115 | offset = transducer.len + paraformer_len + ctc_len; |
| 104 | Module.setValue(ptr + offset, buffer, 'i8*'); // tokens | 116 | Module.setValue(ptr + offset, buffer, 'i8*'); // tokens |
| 105 | offset += 4; | 117 | offset += 4; |
| @@ -128,6 +140,16 @@ function initModelConfig(config, Module) { | @@ -128,6 +140,16 @@ function initModelConfig(config, Module) { | ||
| 128 | 'i8*'); // bpeVocab | 140 | 'i8*'); // bpeVocab |
| 129 | offset += 4; | 141 | offset += 4; |
| 130 | 142 | ||
| 143 | + Module.setValue( | ||
| 144 | + ptr + offset, | ||
| 145 | + buffer + tokensLen + providerLen + modelTypeLen + modelingUnitLen + | ||
| 146 | + bpeVocabLen, | ||
| 147 | + 'i8*'); // tokens_buf | ||
| 148 | + offset += 4; | ||
| 149 | + | ||
| 150 | + Module.setValue(ptr + offset, config.tokensBufSize || 0, 'i32'); | ||
| 151 | + offset += 4; | ||
| 152 | + | ||
| 131 | return { | 153 | return { |
| 132 | buffer: buffer, ptr: ptr, len: len, transducer: transducer | 154 | buffer: buffer, ptr: ptr, len: len, transducer: transducer |
| 133 | } | 155 | } |
| @@ -19,7 +19,7 @@ static_assert(sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) == 1 * 4, ""); | @@ -19,7 +19,7 @@ static_assert(sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) == 1 * 4, ""); | ||
| 19 | static_assert(sizeof(SherpaOnnxOnlineModelConfig) == | 19 | static_assert(sizeof(SherpaOnnxOnlineModelConfig) == |
| 20 | sizeof(SherpaOnnxOnlineTransducerModelConfig) + | 20 | sizeof(SherpaOnnxOnlineTransducerModelConfig) + |
| 21 | sizeof(SherpaOnnxOnlineParaformerModelConfig) + | 21 | sizeof(SherpaOnnxOnlineParaformerModelConfig) + |
| 22 | - sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 7 * 4, | 22 | + sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 9 * 4, |
| 23 | ""); | 23 | ""); |
| 24 | static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, ""); | 24 | static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, ""); |
| 25 | static_assert(sizeof(SherpaOnnxKeywordSpotterConfig) == | 25 | static_assert(sizeof(SherpaOnnxKeywordSpotterConfig) == |
-
请 注册 或 登录 后发表评论