Wei Kang
Committed by GitHub

Encode hotwords in C++ side (#828)

* Encode hotwords in C++ side
Showing 43 changed files with 713 additions and 101 deletions
@@ -8,6 +8,8 @@ log() { @@ -8,6 +8,8 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
11 echo "EXE is $EXE" 13 echo "EXE is $EXE"
12 echo "PATH: $PATH" 14 echo "PATH: $PATH"
13 15
@@ -8,6 +8,8 @@ log() { @@ -8,6 +8,8 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
11 echo "EXE is $EXE" 13 echo "EXE is $EXE"
12 echo "PATH: $PATH" 14 echo "PATH: $PATH"
13 15
@@ -8,6 +8,8 @@ log() { @@ -8,6 +8,8 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
11 echo "EXE is $EXE" 13 echo "EXE is $EXE"
12 echo "PATH: $PATH" 14 echo "PATH: $PATH"
13 15
@@ -8,6 +8,8 @@ log() { @@ -8,6 +8,8 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
11 echo "EXE is $EXE" 13 echo "EXE is $EXE"
12 echo "PATH: $PATH" 14 echo "PATH: $PATH"
13 15
@@ -8,6 +8,8 @@ log() { @@ -8,6 +8,8 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
11 echo "EXE is $EXE" 13 echo "EXE is $EXE"
12 echo "PATH: $PATH" 14 echo "PATH: $PATH"
13 15
@@ -8,6 +8,8 @@ log() { @@ -8,6 +8,8 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
11 echo "EXE is $EXE" 13 echo "EXE is $EXE"
12 echo "PATH: $PATH" 14 echo "PATH: $PATH"
13 15
@@ -8,6 +8,8 @@ log() { @@ -8,6 +8,8 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
11 echo "EXE is $EXE" 13 echo "EXE is $EXE"
12 echo "PATH: $PATH" 14 echo "PATH: $PATH"
13 15
@@ -8,6 +8,8 @@ log() { @@ -8,6 +8,8 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
11 log "test online NeMo CTC" 13 log "test online NeMo CTC"
12 14
13 url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2 15 url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
@@ -8,6 +8,8 @@ log() { @@ -8,6 +8,8 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +export GIT_CLONE_PROTECTION_ACTIVE=false
  12 +
11 echo "EXE is $EXE" 13 echo "EXE is $EXE"
12 echo "PATH: $PATH" 14 echo "PATH: $PATH"
13 15
@@ -234,6 +234,7 @@ endif() @@ -234,6 +234,7 @@ endif()
234 include(kaldi-native-fbank) 234 include(kaldi-native-fbank)
235 include(kaldi-decoder) 235 include(kaldi-decoder)
236 include(onnxruntime) 236 include(onnxruntime)
  237 +include(simple-sentencepiece)
237 set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR}) 238 set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR})
238 message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}") 239 message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}")
239 240
@@ -126,7 +126,7 @@ echo "Generate xcframework" @@ -126,7 +126,7 @@ echo "Generate xcframework"
126 126
127 mkdir -p "build/simulator/lib" 127 mkdir -p "build/simulator/lib"
128 for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \ 128 for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \
129 - libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a; do 129 + libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a libssentencepiece_core.a; do
130 lipo -create build/simulator_arm64/lib/${f} \ 130 lipo -create build/simulator_arm64/lib/${f} \
131 build/simulator_x86_64/lib/${f} \ 131 build/simulator_x86_64/lib/${f} \
132 -output build/simulator/lib/${f} 132 -output build/simulator/lib/${f}
@@ -140,7 +140,8 @@ libtool -static -o build/simulator/sherpa-onnx.a \ @@ -140,7 +140,8 @@ libtool -static -o build/simulator/sherpa-onnx.a \
140 build/simulator/lib/libsherpa-onnx-core.a \ 140 build/simulator/lib/libsherpa-onnx-core.a \
141 build/simulator/lib/libsherpa-onnx-fst.a \ 141 build/simulator/lib/libsherpa-onnx-fst.a \
142 build/simulator/lib/libsherpa-onnx-kaldifst-core.a \ 142 build/simulator/lib/libsherpa-onnx-kaldifst-core.a \
143 - build/simulator/lib/libkaldi-decoder-core.a 143 + build/simulator/lib/libkaldi-decoder-core.a \
  144 + build/simulator/lib/libssentencepiece_core.a
144 145
145 libtool -static -o build/os64/sherpa-onnx.a \ 146 libtool -static -o build/os64/sherpa-onnx.a \
146 build/os64/lib/libkaldi-native-fbank-core.a \ 147 build/os64/lib/libkaldi-native-fbank-core.a \
@@ -148,7 +149,8 @@ libtool -static -o build/os64/sherpa-onnx.a \ @@ -148,7 +149,8 @@ libtool -static -o build/os64/sherpa-onnx.a \
148 build/os64/lib/libsherpa-onnx-core.a \ 149 build/os64/lib/libsherpa-onnx-core.a \
149 build/os64/lib/libsherpa-onnx-fst.a \ 150 build/os64/lib/libsherpa-onnx-fst.a \
150 build/os64/lib/libsherpa-onnx-kaldifst-core.a \ 151 build/os64/lib/libsherpa-onnx-kaldifst-core.a \
151 - build/os64/lib/libkaldi-decoder-core.a 152 + build/os64/lib/libkaldi-decoder-core.a \
  153 + build/os64/lib/libssentencepiece_core.a
152 154
153 rm -rf sherpa-onnx.xcframework 155 rm -rf sherpa-onnx.xcframework
154 156
@@ -129,7 +129,7 @@ echo "Generate xcframework" @@ -129,7 +129,7 @@ echo "Generate xcframework"
129 129
130 mkdir -p "build/simulator/lib" 130 mkdir -p "build/simulator/lib"
131 for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \ 131 for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \
132 - libsherpa-onnx-fstfar.a \ 132 + libsherpa-onnx-fstfar.a libssentencepiece_core.a \
133 libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a \ 133 libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a \
134 libucd.a libpiper_phonemize.a libespeak-ng.a; do 134 libucd.a libpiper_phonemize.a libespeak-ng.a; do
135 lipo -create build/simulator_arm64/lib/${f} \ 135 lipo -create build/simulator_arm64/lib/${f} \
@@ -150,6 +150,7 @@ libtool -static -o build/simulator/sherpa-onnx.a \ @@ -150,6 +150,7 @@ libtool -static -o build/simulator/sherpa-onnx.a \
150 build/simulator/lib/libucd.a \ 150 build/simulator/lib/libucd.a \
151 build/simulator/lib/libpiper_phonemize.a \ 151 build/simulator/lib/libpiper_phonemize.a \
152 build/simulator/lib/libespeak-ng.a \ 152 build/simulator/lib/libespeak-ng.a \
  153 + build/simulator/lib/libssentencepiece_core.a
153 154
154 libtool -static -o build/os64/sherpa-onnx.a \ 155 libtool -static -o build/os64/sherpa-onnx.a \
155 build/os64/lib/libkaldi-native-fbank-core.a \ 156 build/os64/lib/libkaldi-native-fbank-core.a \
@@ -162,6 +163,7 @@ libtool -static -o build/os64/sherpa-onnx.a \ @@ -162,6 +163,7 @@ libtool -static -o build/os64/sherpa-onnx.a \
162 build/os64/lib/libucd.a \ 163 build/os64/lib/libucd.a \
163 build/os64/lib/libpiper_phonemize.a \ 164 build/os64/lib/libpiper_phonemize.a \
164 build/os64/lib/libespeak-ng.a \ 165 build/os64/lib/libespeak-ng.a \
  166 + build/os64/lib/libssentencepiece_core.a
165 167
166 168
167 rm -rf sherpa-onnx.xcframework 169 rm -rf sherpa-onnx.xcframework
@@ -33,4 +33,5 @@ libtool -static -o ./install/lib/libsherpa-onnx.a \ @@ -33,4 +33,5 @@ libtool -static -o ./install/lib/libsherpa-onnx.a \
33 ./install/lib/libkaldi-decoder-core.a \ 33 ./install/lib/libkaldi-decoder-core.a \
34 ./install/lib/libucd.a \ 34 ./install/lib/libucd.a \
35 ./install/lib/libpiper_phonemize.a \ 35 ./install/lib/libpiper_phonemize.a \
36 - ./install/lib/libespeak-ng.a 36 + ./install/lib/libespeak-ng.a \
  37 + ./install/lib/libssentencepiece_core.a
  1 +function(download_simple_sentencepiece)
  2 + include(FetchContent)
  3 +
  4 + set(simple-sentencepiece_URL "https://github.com/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz")
  5 + set(simple-sentencepiece_URL2 "https://hub.nauu.cf/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz")
  6 + set(simple-sentencepiece_HASH "SHA256=1748a822060a35baa9f6609f84efc8eb54dc0e74b9ece3d82367b7119fdc75af")
  7 +
  8 + # If you don't have access to the Internet,
  9 + # please pre-download simple-sentencepiece
  10 + set(possible_file_locations
  11 + $ENV{HOME}/Downloads/simple-sentencepiece-0.7.tar.gz
  12 + ${CMAKE_SOURCE_DIR}/simple-sentencepiece-0.7.tar.gz
  13 + ${CMAKE_BINARY_DIR}/simple-sentencepiece-0.7.tar.gz
  14 + /tmp/simple-sentencepiece-0.7.tar.gz
  15 + /star-fj/fangjun/download/github/simple-sentencepiece-0.7.tar.gz
  16 + )
  17 +
  18 + foreach(f IN LISTS possible_file_locations)
  19 + if(EXISTS ${f})
  20 + set(simple-sentencepiece_URL "${f}")
  21 + file(TO_CMAKE_PATH "${simple-sentencepiece_URL}" simple-sentencepiece_URL)
  22 + message(STATUS "Found local downloaded simple-sentencepiece: ${simple-sentencepiece_URL}")
  23 + set(simple-sentencepiece_URL2)
  24 + break()
  25 + endif()
  26 + endforeach()
  27 +
  28 + set(SBPE_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
  29 + set(SBPE_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
  30 +
  31 + FetchContent_Declare(simple-sentencepiece
  32 + URL
  33 + ${simple-sentencepiece_URL}
  34 + ${simple-sentencepiece_URL2}
  35 + URL_HASH
  36 + ${simple-sentencepiece_HASH}
  37 + )
  38 +
  39 + FetchContent_GetProperties(simple-sentencepiece)
  40 + if(NOT simple-sentencepiece_POPULATED)
  41 + message(STATUS "Downloading simple-sentencepiece ${simple-sentencepiece_URL}")
  42 + FetchContent_Populate(simple-sentencepiece)
  43 + endif()
  44 + message(STATUS "simple-sentencepiece is downloaded to ${simple-sentencepiece_SOURCE_DIR}")
  45 + add_subdirectory(${simple-sentencepiece_SOURCE_DIR} ${simple-sentencepiece_BINARY_DIR} EXCLUDE_FROM_ALL)
  46 +
  47 + target_include_directories(ssentencepiece_core
  48 + PUBLIC
  49 + ${simple-sentencepiece_SOURCE_DIR}/
  50 + )
  51 +
  52 + if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
  53 + install(TARGETS ssentencepiece_core DESTINATION ..)
  54 + else()
  55 + install(TARGETS ssentencepiece_core DESTINATION lib)
  56 + endif()
  57 +
  58 + if(WIN32 AND BUILD_SHARED_LIBS)
  59 + install(TARGETS ssentencepiece_core DESTINATION bin)
  60 + endif()
  61 +endfunction()
  62 +
  63 +download_simple_sentencepiece()
@@ -60,7 +60,7 @@ function testSpeakerEmbeddingExtractor() { @@ -60,7 +60,7 @@ function testSpeakerEmbeddingExtractor() {
60 function testOnlineAsr() { 60 function testOnlineAsr() {
61 if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then 61 if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
62 git lfs install 62 git lfs install
63 - git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 63 + GIT_CLONE_PROTECTION_ACTIVE=false git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
64 fi 64 fi
65 65
66 if [ ! -f ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt ]; then 66 if [ ! -f ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt ]; then
@@ -18,6 +18,7 @@ @@ -18,6 +18,7 @@
18 piper_phonemize.lib; 18 piper_phonemize.lib;
19 espeak-ng.lib; 19 espeak-ng.lib;
20 ucd.lib; 20 ucd.lib;
  21 + ssentencepiece_core.lib;
21 </SherpaOnnxLibraries> 22 </SherpaOnnxLibraries>
22 </PropertyGroup> 23 </PropertyGroup>
23 <ItemDefinitionGroup> 24 <ItemDefinitionGroup>
@@ -18,6 +18,7 @@ @@ -18,6 +18,7 @@
18 piper_phonemize.lib; 18 piper_phonemize.lib;
19 espeak-ng.lib; 19 espeak-ng.lib;
20 ucd.lib; 20 ucd.lib;
  21 + ssentencepiece_core.lib;
21 </SherpaOnnxLibraries> 22 </SherpaOnnxLibraries>
22 </PropertyGroup> 23 </PropertyGroup>
23 <ItemDefinitionGroup> 24 <ItemDefinitionGroup>
@@ -18,6 +18,7 @@ @@ -18,6 +18,7 @@
18 piper_phonemize.lib; 18 piper_phonemize.lib;
19 espeak-ng.lib; 19 espeak-ng.lib;
20 ucd.lib; 20 ucd.lib;
  21 + ssentencepiece_core.lib;
21 </SherpaOnnxLibraries> 22 </SherpaOnnxLibraries>
22 </PropertyGroup> 23 </PropertyGroup>
23 <ItemDefinitionGroup> 24 <ItemDefinitionGroup>
@@ -110,11 +110,9 @@ def get_args(): @@ -110,11 +110,9 @@ def get_args():
110 type=str, 110 type=str,
111 default="", 111 default="",
112 help=""" 112 help="""
113 - The file containing hotwords, one words/phrases per line, and for each  
114 - phrase the bpe/cjkchar are separated by a space. For example:  
115 -  
116 - ▁HE LL O ▁WORLD  
117 -  你 好 世 界 113 + The file containing hotwords, one word/phrase per line, like
  114 + HELLO WORLD
  115 + 你好世界
118 """, 116 """,
119 ) 117 )
120 118
@@ -129,6 +127,28 @@ def get_args(): @@ -129,6 +127,28 @@ def get_args():
129 ) 127 )
130 128
131 parser.add_argument( 129 parser.add_argument(
  130 + "--modeling-unit",
  131 + type=str,
  132 + default="",
  133 + help="""
  134 + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
  135 + Used only when hotwords-file is given.
  136 + """,
  137 + )
  138 +
  139 + parser.add_argument(
  140 + "--bpe-vocab",
  141 + type=str,
  142 + default="",
  143 + help="""
  144 + The path to the bpe vocabulary, the bpe vocabulary is generated by
  145 + sentencepiece; you can also export the bpe vocabulary from a bpe model
  146 + with `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
  147 + and modeling-unit is bpe or cjkchar+bpe.
  148 + """,
  149 + )
  150 +
  151 + parser.add_argument(
132 "--encoder", 152 "--encoder",
133 default="", 153 default="",
134 type=str, 154 type=str,
@@ -347,6 +367,8 @@ def main(): @@ -347,6 +367,8 @@ def main():
347 decoding_method=args.decoding_method, 367 decoding_method=args.decoding_method,
348 hotwords_file=args.hotwords_file, 368 hotwords_file=args.hotwords_file,
349 hotwords_score=args.hotwords_score, 369 hotwords_score=args.hotwords_score,
  370 + modeling_unit=args.modeling_unit,
  371 + bpe_vocab=args.bpe_vocab,
350 blank_penalty=args.blank_penalty, 372 blank_penalty=args.blank_penalty,
351 debug=args.debug, 373 debug=args.debug,
352 ) 374 )
@@ -198,11 +198,9 @@ def get_args(): @@ -198,11 +198,9 @@ def get_args():
198 type=str, 198 type=str,
199 default="", 199 default="",
200 help=""" 200 help="""
201 - The file containing hotwords, one words/phrases per line, and for each  
202 - phrase the bpe/cjkchar are separated by a space. For example:  
203 -  
204 - ▁HE LL O ▁WORLD  
205 -  你 好 世 界 201 + The file containing hotwords, one word/phrase per line, like
  202 + HELLO WORLD
  203 + 你好世界
206 """, 204 """,
207 ) 205 )
208 206
@@ -217,6 +215,28 @@ def get_args(): @@ -217,6 +215,28 @@ def get_args():
217 ) 215 )
218 216
219 parser.add_argument( 217 parser.add_argument(
  218 + "--modeling-unit",
  219 + type=str,
  220 + default="",
  221 + help="""
  222 + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
  223 + Used only when hotwords-file is given.
  224 + """,
  225 + )
  226 +
  227 + parser.add_argument(
  228 + "--bpe-vocab",
  229 + type=str,
  230 + default="",
  231 + help="""
  232 + The path to the bpe vocabulary, the bpe vocabulary is generated by
  233 + sentencepiece; you can also export the bpe vocabulary from a bpe model
  234 + with `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
  235 + and modeling-unit is bpe or cjkchar+bpe.
  236 + """,
  237 + )
  238 +
  239 + parser.add_argument(
220 "--blank-penalty", 240 "--blank-penalty",
221 type=float, 241 type=float,
222 default=0.0, 242 default=0.0,
@@ -302,6 +322,8 @@ def main(): @@ -302,6 +322,8 @@ def main():
302 lm_scale=args.lm_scale, 322 lm_scale=args.lm_scale,
303 hotwords_file=args.hotwords_file, 323 hotwords_file=args.hotwords_file,
304 hotwords_score=args.hotwords_score, 324 hotwords_score=args.hotwords_score,
  325 + modeling_unit=args.modeling_unit,
  326 + bpe_vocab=args.bpe_vocab,
305 blank_penalty=args.blank_penalty, 327 blank_penalty=args.blank_penalty,
306 ) 328 )
307 elif args.zipformer2_ctc: 329 elif args.zipformer2_ctc:
  1 +#!/usr/bin/env python3
  2 +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
  3 +#
  4 +# See ../../../../LICENSE for clarification regarding multiple authors
  5 +#
  6 +# Licensed under the Apache License, Version 2.0 (the "License");
  7 +# you may not use this file except in compliance with the License.
  8 +# You may obtain a copy of the License at
  9 +#
  10 +# http://www.apache.org/licenses/LICENSE-2.0
  11 +#
  12 +# Unless required by applicable law or agreed to in writing, software
  13 +# distributed under the License is distributed on an "AS IS" BASIS,
  14 +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 +# See the License for the specific language governing permissions and
  16 +# limitations under the License.
  17 +
  18 +
  19 +# You can install sentencepiece via:
  20 +#
  21 +# pip install sentencepiece
  22 +#
  23 +# Due to an issue reported in
  24 +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
  25 +#
  26 +# Please install a version >=0.1.96
  27 +
  28 +import argparse
  29 +from typing import Dict
  30 +
  31 +try:
  32 + import sentencepiece as spm
  33 +except ImportError:
  34 + print('Please run')
  35 + print(' pip install sentencepiece')
  36 + print('before you continue')
  37 + raise
  38 +
  39 +
  40 +def get_args():
  41 + parser = argparse.ArgumentParser()
  42 + parser.add_argument(
  43 + "--bpe-model",
  44 + type=str,
  45 + help="The path to the bpe model.",
  46 + )
  47 +
  48 + return parser.parse_args()
  49 +
  50 +
  51 +def main():
  52 + args = get_args()
  53 + model_file = args.bpe_model
  54 +
  55 + vocab_file = model_file.replace(".model", ".vocab")
  56 +
  57 + sp = spm.SentencePieceProcessor()
  58 + sp.Load(model_file)
  59 + vocabs = [sp.IdToPiece(id) for id in range(sp.GetPieceSize())]
  60 + with open(vocab_file, "w") as vfile:
  61 + for v in vocabs:
  62 + id = sp.PieceToId(v)
  63 + vfile.write(f"{v}\t{sp.GetScore(id)}\n")
  64 + print(f"Vocabulary file is written to {vocab_file}")
  65 +
  66 +
  67 +if __name__ == "__main__":
  68 + main()
@@ -165,6 +165,7 @@ endif() @@ -165,6 +165,7 @@ endif()
165 target_link_libraries(sherpa-onnx-core 165 target_link_libraries(sherpa-onnx-core
166 kaldi-native-fbank-core 166 kaldi-native-fbank-core
167 kaldi-decoder-core 167 kaldi-decoder-core
  168 + ssentencepiece_core
168 ) 169 )
169 170
170 if(SHERPA_ONNX_ENABLE_GPU) 171 if(SHERPA_ONNX_ENABLE_GPU)
@@ -491,6 +492,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) @@ -491,6 +492,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
491 pad-sequence-test.cc 492 pad-sequence-test.cc
492 slice-test.cc 493 slice-test.cc
493 stack-test.cc 494 stack-test.cc
  495 + text2token-test.cc
494 transpose-test.cc 496 transpose-test.cc
495 unbind-test.cc 497 unbind-test.cc
496 utfcpp-test.cc 498 utfcpp-test.cc
@@ -35,6 +35,17 @@ void OfflineModelConfig::Register(ParseOptions *po) { @@ -35,6 +35,17 @@ void OfflineModelConfig::Register(ParseOptions *po) {
35 "Valid values are: transducer, paraformer, nemo_ctc, whisper, " 35 "Valid values are: transducer, paraformer, nemo_ctc, whisper, "
36 "tdnn, zipformer2_ctc" 36 "tdnn, zipformer2_ctc"
37 "All other values lead to loading the model twice."); 37 "All other values lead to loading the model twice.");
  38 + po->Register("modeling-unit", &modeling_unit,
  39 + "The modeling unit of the model, commonly used units are bpe, "
  40 + "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
  41 + "hotwords are provided, since we need it to encode the hotwords into "
  42 + "a token sequence.");
  43 + po->Register("bpe-vocab", &bpe_vocab,
  44 + "The vocabulary generated by google's sentencepiece program. "
  45 + "It is a file with two columns: one is the token, the other is "
  46 + "the log probability. You can get it from the directory where "
  47 + "your bpe model is generated. Only used when hotwords are provided "
  48 + "and the modeling unit is bpe or cjkchar+bpe.");
38 } 49 }
39 50
40 bool OfflineModelConfig::Validate() const { 51 bool OfflineModelConfig::Validate() const {
@@ -48,6 +59,14 @@ bool OfflineModelConfig::Validate() const { @@ -48,6 +59,14 @@ bool OfflineModelConfig::Validate() const {
48 return false; 59 return false;
49 } 60 }
50 61
  62 + if (!modeling_unit.empty() &&
  63 + (modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
  64 + if (!FileExists(bpe_vocab)) {
  65 + SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
  66 + return false;
  67 + }
  68 + }
  69 +
51 if (!paraformer.model.empty()) { 70 if (!paraformer.model.empty()) {
52 return paraformer.Validate(); 71 return paraformer.Validate();
53 } 72 }
@@ -90,7 +109,9 @@ std::string OfflineModelConfig::ToString() const { @@ -90,7 +109,9 @@ std::string OfflineModelConfig::ToString() const {
90 os << "num_threads=" << num_threads << ", "; 109 os << "num_threads=" << num_threads << ", ";
91 os << "debug=" << (debug ? "True" : "False") << ", "; 110 os << "debug=" << (debug ? "True" : "False") << ", ";
92 os << "provider=\"" << provider << "\", "; 111 os << "provider=\"" << provider << "\", ";
93 - os << "model_type=\"" << model_type << "\")"; 112 + os << "model_type=\"" << model_type << "\", ";
  113 + os << "modeling_unit=\"" << modeling_unit << "\", ";
  114 + os << "bpe_vocab=\"" << bpe_vocab << "\")";
94 115
95 return os.str(); 116 return os.str();
96 } 117 }
@@ -41,6 +41,9 @@ struct OfflineModelConfig { @@ -41,6 +41,9 @@ struct OfflineModelConfig {
41 // All other values are invalid and lead to loading the model twice. 41 // All other values are invalid and lead to loading the model twice.
42 std::string model_type; 42 std::string model_type;
43 43
  44 + std::string modeling_unit = "cjkchar";
  45 + std::string bpe_vocab;
  46 +
44 OfflineModelConfig() = default; 47 OfflineModelConfig() = default;
45 OfflineModelConfig(const OfflineTransducerModelConfig &transducer, 48 OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
46 const OfflineParaformerModelConfig &paraformer, 49 const OfflineParaformerModelConfig &paraformer,
@@ -50,7 +53,9 @@ struct OfflineModelConfig { @@ -50,7 +53,9 @@ struct OfflineModelConfig {
50 const OfflineZipformerCtcModelConfig &zipformer_ctc, 53 const OfflineZipformerCtcModelConfig &zipformer_ctc,
51 const OfflineWenetCtcModelConfig &wenet_ctc, 54 const OfflineWenetCtcModelConfig &wenet_ctc,
52 const std::string &tokens, int32_t num_threads, bool debug, 55 const std::string &tokens, int32_t num_threads, bool debug,
53 - const std::string &provider, const std::string &model_type) 56 + const std::string &provider, const std::string &model_type,
  57 + const std::string &modeling_unit,
  58 + const std::string &bpe_vocab)
54 : transducer(transducer), 59 : transducer(transducer),
55 paraformer(paraformer), 60 paraformer(paraformer),
56 nemo_ctc(nemo_ctc), 61 nemo_ctc(nemo_ctc),
@@ -62,7 +67,9 @@ struct OfflineModelConfig { @@ -62,7 +67,9 @@ struct OfflineModelConfig {
62 num_threads(num_threads), 67 num_threads(num_threads),
63 debug(debug), 68 debug(debug),
64 provider(provider), 69 provider(provider),
65 - model_type(model_type) {} 70 + model_type(model_type),
  71 + modeling_unit(modeling_unit),
  72 + bpe_vocab(bpe_vocab) {}
66 73
67 void Register(ParseOptions *po); 74 void Register(ParseOptions *po);
68 bool Validate() const; 75 bool Validate() const;
@@ -31,6 +31,7 @@ @@ -31,6 +31,7 @@
31 #include "sherpa-onnx/csrc/pad-sequence.h" 31 #include "sherpa-onnx/csrc/pad-sequence.h"
32 #include "sherpa-onnx/csrc/symbol-table.h" 32 #include "sherpa-onnx/csrc/symbol-table.h"
33 #include "sherpa-onnx/csrc/utils.h" 33 #include "sherpa-onnx/csrc/utils.h"
  34 +#include "ssentencepiece/csrc/ssentencepiece.h"
34 35
35 namespace sherpa_onnx { 36 namespace sherpa_onnx {
36 37
@@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
76 : config_(config), 77 : config_(config),
77 symbol_table_(config_.model_config.tokens), 78 symbol_table_(config_.model_config.tokens),
78 model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { 79 model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
79 - if (!config_.hotwords_file.empty()) {  
80 - InitHotwords();  
81 - }  
82 if (config_.decoding_method == "greedy_search") { 80 if (config_.decoding_method == "greedy_search") {
83 decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( 81 decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
84 model_.get(), config_.blank_penalty); 82 model_.get(), config_.blank_penalty);
@@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
87 lm_ = OfflineLM::Create(config.lm_config); 85 lm_ = OfflineLM::Create(config.lm_config);
88 } 86 }
89 87
  88 + if (!config_.model_config.bpe_vocab.empty()) {
  89 + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
  90 + config_.model_config.bpe_vocab);
  91 + }
  92 +
  93 + if (!config_.hotwords_file.empty()) {
  94 + InitHotwords();
  95 + }
  96 +
90 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( 97 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
91 model_.get(), lm_.get(), config_.max_active_paths, 98 model_.get(), lm_.get(), config_.max_active_paths,
92 config_.lm_config.scale, config_.blank_penalty); 99 config_.lm_config.scale, config_.blank_penalty);
@@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
112 lm_ = OfflineLM::Create(mgr, config.lm_config); 119 lm_ = OfflineLM::Create(mgr, config.lm_config);
113 } 120 }
114 121
  122 + if (!config_.model_config.bpe_vocab.empty()) {
  123 + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
  124 + std::istringstream iss(std::string(buf.begin(), buf.end()));
  125 + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
  126 + }
  127 +
  128 + if (!config_.hotwords_file.empty()) {
  129 + InitHotwords(mgr);
  130 + }
  131 +
115 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( 132 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
116 model_.get(), lm_.get(), config_.max_active_paths, 133 model_.get(), lm_.get(), config_.max_active_paths,
117 config_.lm_config.scale, config_.blank_penalty); 134 config_.lm_config.scale, config_.blank_penalty);
@@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
128 auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); 145 auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
129 std::istringstream is(hws); 146 std::istringstream is(hws);
130 std::vector<std::vector<int32_t>> current; 147 std::vector<std::vector<int32_t>> current;
131 - if (!EncodeHotwords(is, symbol_table_, &current)) { 148 + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
  149 + bpe_encoder_.get(), &current)) {
132 SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", 150 SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
133 hotwords.c_str()); 151 hotwords.c_str());
134 } 152 }
@@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
207 exit(-1); 225 exit(-1);
208 } 226 }
209 227
210 - if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {  
211 - SHERPA_ONNX_LOGE("Encode hotwords failed."); 228 + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
  229 + bpe_encoder_.get(), &hotwords_)) {
  230 + SHERPA_ONNX_LOGE(
  231 + "Failed to encode some hotwords, skipping them; see logs above "
  232 + "for details.");
  233 + }
  234 + hotwords_graph_ =
  235 + std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
  236 + }
  237 +
  238 +#if __ANDROID_API__ >= 9
  239 + void InitHotwords(AAssetManager *mgr) {
  240 + // each line in hotwords_file contains space-separated words
  241 +
  242 + auto buf = ReadFile(mgr, config_.hotwords_file);
  243 +
  244 + std::istringstream is(std::string(buf.begin(), buf.end()));
  245 +
  246 + if (!is) {
  247 + SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
  248 + config_.hotwords_file.c_str());
212 exit(-1); 249 exit(-1);
213 } 250 }
  251 +
  252 + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
  253 + bpe_encoder_.get(), &hotwords_)) {
  254 + SHERPA_ONNX_LOGE(
  255 + "Failed to encode some hotwords, skipping them; see logs above "
  256 + "for details.");
  257 + }
214 hotwords_graph_ = 258 hotwords_graph_ =
215 std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); 259 std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
216 } 260 }
  261 +#endif
217 262
218 private: 263 private:
219 OfflineRecognizerConfig config_; 264 OfflineRecognizerConfig config_;
220 SymbolTable symbol_table_; 265 SymbolTable symbol_table_;
221 std::vector<std::vector<int32_t>> hotwords_; 266 std::vector<std::vector<int32_t>> hotwords_;
222 ContextGraphPtr hotwords_graph_; 267 ContextGraphPtr hotwords_graph_;
  268 + std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
223 std::unique_ptr<OfflineTransducerModel> model_; 269 std::unique_ptr<OfflineTransducerModel> model_;
224 std::unique_ptr<OfflineTransducerDecoder> decoder_; 270 std::unique_ptr<OfflineTransducerDecoder> decoder_;
225 std::unique_ptr<OfflineLM> lm_; 271 std::unique_ptr<OfflineLM> lm_;
@@ -37,10 +37,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { @@ -37,10 +37,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
37 37
38 po->Register( 38 po->Register(
39 "hotwords-file", &hotwords_file, 39 "hotwords-file", &hotwords_file,
40 - "The file containing hotwords, one words/phrases per line, and for each"  
41 - "phrase the bpe/cjkchar are separated by a space. For example: "  
42 - "▁HE LL O ▁WORLD"  
43 - "你 好 世 界"); 40 + "The file containing hotwords, one word/phrase per line. For example: "
  41 + "HELLO WORLD"
  42 + "你好世界");
44 43
45 po->Register("hotwords-score", &hotwords_score, 44 po->Register("hotwords-score", &hotwords_score,
46 "The bonus score for each token in context word/phrase. " 45 "The bonus score for each token in context word/phrase. "
@@ -32,6 +32,19 @@ void OnlineModelConfig::Register(ParseOptions *po) { @@ -32,6 +32,19 @@ void OnlineModelConfig::Register(ParseOptions *po) {
32 po->Register("provider", &provider, 32 po->Register("provider", &provider,
33 "Specify a provider to use: cpu, cuda, coreml"); 33 "Specify a provider to use: cpu, cuda, coreml");
34 34
  35 + po->Register("modeling-unit", &modeling_unit,
  36 + "The modeling unit of the model, commonly used units are bpe, "
  37 + "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
  38 + "hotwords are provided, we need it to encode the hotwords into "
  39 + "token sequence.");
  40 +
  41 + po->Register("bpe-vocab", &bpe_vocab,
  42 + "The vocabulary generated by google's sentencepiece program. "
  43 + "It is a file has two columns, one is the token, the other is "
  44 + "the log probability, you can get it from the directory where "
  45 + "your bpe model is generated. Only used when hotwords provided "
  46 + "and the modeling unit is bpe or cjkchar+bpe");
  47 +
35 po->Register("model-type", &model_type, 48 po->Register("model-type", &model_type,
36 "Specify it to reduce model initialization time. " 49 "Specify it to reduce model initialization time. "
37 "Valid values are: conformer, lstm, zipformer, zipformer2, " 50 "Valid values are: conformer, lstm, zipformer, zipformer2, "
@@ -50,6 +63,14 @@ bool OnlineModelConfig::Validate() const { @@ -50,6 +63,14 @@ bool OnlineModelConfig::Validate() const {
50 return false; 63 return false;
51 } 64 }
52 65
  66 + if (!modeling_unit.empty() &&
  67 + (modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
  68 + if (!FileExists(bpe_vocab)) {
  69 + SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
  70 + return false;
  71 + }
  72 + }
  73 +
53 if (!paraformer.encoder.empty()) { 74 if (!paraformer.encoder.empty()) {
54 return paraformer.Validate(); 75 return paraformer.Validate();
55 } 76 }
@@ -83,7 +104,9 @@ std::string OnlineModelConfig::ToString() const { @@ -83,7 +104,9 @@ std::string OnlineModelConfig::ToString() const {
83 os << "warm_up=" << warm_up << ", "; 104 os << "warm_up=" << warm_up << ", ";
84 os << "debug=" << (debug ? "True" : "False") << ", "; 105 os << "debug=" << (debug ? "True" : "False") << ", ";
85 os << "provider=\"" << provider << "\", "; 106 os << "provider=\"" << provider << "\", ";
86 - os << "model_type=\"" << model_type << "\")"; 107 + os << "model_type=\"" << model_type << "\", ";
  108 + os << "modeling_unit=\"" << modeling_unit << "\", ";
  109 + os << "bpe_vocab=\"" << bpe_vocab << "\")";
87 110
88 return os.str(); 111 return os.str();
89 } 112 }
@@ -37,6 +37,13 @@ struct OnlineModelConfig { @@ -37,6 +37,13 @@ struct OnlineModelConfig {
37 // All other values are invalid and lead to loading the model twice. 37 // All other values are invalid and lead to loading the model twice.
38 std::string model_type; 38 std::string model_type;
39 39
  40 + // Valid values:
  41 + // - cjkchar
  42 + // - bpe
  43 + // - cjkchar+bpe
  44 + std::string modeling_unit = "cjkchar";
  45 + std::string bpe_vocab;
  46 +
40 OnlineModelConfig() = default; 47 OnlineModelConfig() = default;
41 OnlineModelConfig(const OnlineTransducerModelConfig &transducer, 48 OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
42 const OnlineParaformerModelConfig &paraformer, 49 const OnlineParaformerModelConfig &paraformer,
@@ -45,7 +52,9 @@ struct OnlineModelConfig { @@ -45,7 +52,9 @@ struct OnlineModelConfig {
45 const OnlineNeMoCtcModelConfig &nemo_ctc, 52 const OnlineNeMoCtcModelConfig &nemo_ctc,
46 const std::string &tokens, int32_t num_threads, 53 const std::string &tokens, int32_t num_threads,
47 int32_t warm_up, bool debug, const std::string &provider, 54 int32_t warm_up, bool debug, const std::string &provider,
48 - const std::string &model_type) 55 + const std::string &model_type,
  56 + const std::string &modeling_unit,
  57 + const std::string &bpe_vocab)
49 : transducer(transducer), 58 : transducer(transducer),
50 paraformer(paraformer), 59 paraformer(paraformer),
51 wenet_ctc(wenet_ctc), 60 wenet_ctc(wenet_ctc),
@@ -56,7 +65,9 @@ struct OnlineModelConfig { @@ -56,7 +65,9 @@ struct OnlineModelConfig {
56 warm_up(warm_up), 65 warm_up(warm_up),
57 debug(debug), 66 debug(debug),
58 provider(provider), 67 provider(provider),
59 - model_type(model_type) {} 68 + model_type(model_type),
  69 + modeling_unit(modeling_unit),
  70 + bpe_vocab(bpe_vocab) {}
60 71
61 void Register(ParseOptions *po); 72 void Register(ParseOptions *po);
62 bool Validate() const; 73 bool Validate() const;
@@ -15,8 +15,6 @@ @@ -15,8 +15,6 @@
15 #include <vector> 15 #include <vector>
16 16
17 #if __ANDROID_API__ >= 9 17 #if __ANDROID_API__ >= 9
18 -#include <strstream>  
19 -  
20 #include "android/asset_manager.h" 18 #include "android/asset_manager.h"
21 #include "android/asset_manager_jni.h" 19 #include "android/asset_manager_jni.h"
22 #endif 20 #endif
@@ -33,6 +31,7 @@ @@ -33,6 +31,7 @@
33 #include "sherpa-onnx/csrc/onnx-utils.h" 31 #include "sherpa-onnx/csrc/onnx-utils.h"
34 #include "sherpa-onnx/csrc/symbol-table.h" 32 #include "sherpa-onnx/csrc/symbol-table.h"
35 #include "sherpa-onnx/csrc/utils.h" 33 #include "sherpa-onnx/csrc/utils.h"
  34 +#include "ssentencepiece/csrc/ssentencepiece.h"
36 35
37 namespace sherpa_onnx { 36 namespace sherpa_onnx {
38 37
@@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
94 model_->SetFeatureDim(config.feat_config.feature_dim); 93 model_->SetFeatureDim(config.feat_config.feature_dim);
95 94
96 if (config.decoding_method == "modified_beam_search") { 95 if (config.decoding_method == "modified_beam_search") {
  96 + if (!config_.model_config.bpe_vocab.empty()) {
  97 + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
  98 + config_.model_config.bpe_vocab);
  99 + }
  100 +
97 if (!config_.hotwords_file.empty()) { 101 if (!config_.hotwords_file.empty()) {
98 InitHotwords(); 102 InitHotwords();
99 } 103 }
@@ -140,6 +144,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -140,6 +144,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
140 } 144 }
141 #endif 145 #endif
142 146
  147 + if (!config_.model_config.bpe_vocab.empty()) {
  148 + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
  149 + std::istringstream iss(std::string(buf.begin(), buf.end()));
  150 + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
  151 + }
  152 +
143 if (!config_.hotwords_file.empty()) { 153 if (!config_.hotwords_file.empty()) {
144 InitHotwords(mgr); 154 InitHotwords(mgr);
145 } 155 }
@@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
174 auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); 184 auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
175 std::istringstream is(hws); 185 std::istringstream is(hws);
176 std::vector<std::vector<int32_t>> current; 186 std::vector<std::vector<int32_t>> current;
177 - if (!EncodeHotwords(is, sym_, &current)) { 187 + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
  188 + bpe_encoder_.get(), &current)) {
178 SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", 189 SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
179 hotwords.c_str()); 190 hotwords.c_str());
180 } 191 }
@@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
363 exit(-1); 374 exit(-1);
364 } 375 }
365 376
366 - if (!EncodeHotwords(is, sym_, &hotwords_)) {  
367 - SHERPA_ONNX_LOGE("Encode hotwords failed.");  
368 - exit(-1); 377 + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
  378 + bpe_encoder_.get(), &hotwords_)) {
  379 + SHERPA_ONNX_LOGE(
  380 + "Failed to encode some hotwords, skip them already, see logs above "
  381 + "for details.");
369 } 382 }
370 hotwords_graph_ = 383 hotwords_graph_ =
371 std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); 384 std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
@@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
377 390
378 auto buf = ReadFile(mgr, config_.hotwords_file); 391 auto buf = ReadFile(mgr, config_.hotwords_file);
379 392
380 - std::istrstream is(buf.data(), buf.size()); 393 + std::istringstream is(std::string(buf.begin(), buf.end()));
381 394
382 if (!is) { 395 if (!is) {
383 SHERPA_ONNX_LOGE("Open hotwords file failed: %s", 396 SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
@@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
385 exit(-1); 398 exit(-1);
386 } 399 }
387 400
388 - if (!EncodeHotwords(is, sym_, &hotwords_)) {  
389 - SHERPA_ONNX_LOGE("Encode hotwords failed.");  
390 - exit(-1); 401 + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
  402 + bpe_encoder_.get(), &hotwords_)) {
  403 + SHERPA_ONNX_LOGE(
  404 + "Failed to encode some hotwords, skip them already, see logs above "
  405 + "for details.");
391 } 406 }
392 hotwords_graph_ = 407 hotwords_graph_ =
393 std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); 408 std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
@@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
413 OnlineRecognizerConfig config_; 428 OnlineRecognizerConfig config_;
414 std::vector<std::vector<int32_t>> hotwords_; 429 std::vector<std::vector<int32_t>> hotwords_;
415 ContextGraphPtr hotwords_graph_; 430 ContextGraphPtr hotwords_graph_;
  431 + std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
416 std::unique_ptr<OnlineTransducerModel> model_; 432 std::unique_ptr<OnlineTransducerModel> model_;
417 std::unique_ptr<OnlineLM> lm_; 433 std::unique_ptr<OnlineLM> lm_;
418 std::unique_ptr<OnlineTransducerDecoder> decoder_; 434 std::unique_ptr<OnlineTransducerDecoder> decoder_;
@@ -51,9 +51,7 @@ std::string VecToString<std::string>(const std::vector<std::string> &vec, @@ -51,9 +51,7 @@ std::string VecToString<std::string>(const std::vector<std::string> &vec,
51 std::string OnlineRecognizerResult::AsJsonString() const { 51 std::string OnlineRecognizerResult::AsJsonString() const {
52 std::ostringstream os; 52 std::ostringstream os;
53 os << "{ "; 53 os << "{ ";
54 - os << "\"text\": "  
55 - << "\"" << text << "\""  
56 - << ", "; 54 + os << "\"text\": " << "\"" << text << "\"" << ", ";
57 os << "\"tokens\": " << VecToString(tokens) << ", "; 55 os << "\"tokens\": " << VecToString(tokens) << ", ";
58 os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; 56 os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
59 os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; 57 os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
@@ -89,10 +87,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -89,10 +87,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
89 "Used only when decoding_method is modified_beam_search"); 87 "Used only when decoding_method is modified_beam_search");
90 po->Register( 88 po->Register(
91 "hotwords-file", &hotwords_file, 89 "hotwords-file", &hotwords_file,
92 - "The file containing hotwords, one words/phrases per line, and for each"  
93 - "phrase the bpe/cjkchar are separated by a space. For example: "  
94 - "▁HE LL O ▁WORLD"  
95 - "你 好 世 界"); 90 + "The file containing hotwords, one words/phrases per line, For example: "
  91 + "HELLO WORLD"
  92 + "你好世界");
96 po->Register("decoding-method", &decoding_method, 93 po->Register("decoding-method", &decoding_method,
97 "decoding method," 94 "decoding method,"
98 "now support greedy_search and modified_beam_search."); 95 "now support greedy_search and modified_beam_search.");
@@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) { @@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) {
38 std::string sym; 38 std::string sym;
39 int32_t id; 39 int32_t id;
40 while (is >> sym >> id) { 40 while (is >> sym >> id) {
41 - if (sym.size() >= 3) {  
42 - // For BPE-based models, we replace ▁ with a space  
43 - // Unicode 9601, hex 0x2581, utf8 0xe29681  
44 - const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());  
45 - if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {  
46 - sym = sym.replace(0, 3, " ");  
47 - }  
48 - }  
49 -  
50 - // for byte-level BPE  
51 - // id 0 is blank, id 1 is sos/eos, id 2 is unk  
52 - if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&  
53 - sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {  
54 - std::ostringstream os;  
55 - os << std::hex << std::uppercase << (id - 3);  
56 -  
57 - if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {  
58 - uint8_t i = id - 3;  
59 - sym = std::string(&i, &i + 1);  
60 - }  
61 - }  
62 -  
63 - assert(!sym.empty());  
64 -  
65 - // for byte bpe, after replacing ▁ with a space, whose ascii is also 0x20,  
66 - // there is a conflict between the real byte 0x20 and ▁, so we disable  
67 - // the following check.  
68 - //  
69 - // Note: Only id2sym_ matters as we use it to convert ID to symbols.  
70 #if 0 41 #if 0
71 // we disable the test here since for some multi-lingual BPE models 42 // we disable the test here since for some multi-lingual BPE models
72 // from NeMo, the same symbol can appear multiple times with different IDs. 43 // from NeMo, the same symbol can appear multiple times with different IDs.
@@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const { @@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const {
92 return os.str(); 63 return os.str();
93 } 64 }
94 65
95 -const std::string &SymbolTable::operator[](int32_t id) const {  
96 - return id2sym_.at(id); 66 +const std::string SymbolTable::operator[](int32_t id) const {
  67 + std::string sym = id2sym_.at(id);
  68 + if (sym.size() >= 3) {
  69 + // For BPE-based models, we replace ▁ with a space
  70 + // Unicode 9601, hex 0x2581, utf8 0xe29681
  71 + const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
  72 + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
  73 + sym = sym.replace(0, 3, " ");
  74 + }
  75 + }
  76 +
  77 + // for byte-level BPE
  78 + // id 0 is blank, id 1 is sos/eos, id 2 is unk
  79 + if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
  80 + sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
  81 + std::ostringstream os;
  82 + os << std::hex << std::uppercase << (id - 3);
  83 +
  84 + if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
  85 + uint8_t i = id - 3;
  86 + sym = std::string(&i, &i + 1);
  87 + }
  88 + }
  89 + return sym;
97 } 90 }
98 91
99 int32_t SymbolTable::operator[](const std::string &sym) const { 92 int32_t SymbolTable::operator[](const std::string &sym) const {
@@ -35,7 +35,7 @@ class SymbolTable { @@ -35,7 +35,7 @@ class SymbolTable {
35 std::string ToString() const; 35 std::string ToString() const;
36 36
37 /// Return the symbol corresponding to the given ID. 37 /// Return the symbol corresponding to the given ID.
38 - const std::string &operator[](int32_t id) const; 38 + const std::string operator[](int32_t id) const;
39 /// Return the ID corresponding to the given symbol. 39 /// Return the ID corresponding to the given symbol.
40 int32_t operator[](const std::string &sym) const; 40 int32_t operator[](const std::string &sym) const;
41 41
  1 +// sherpa-onnx/csrc/text2token-test.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include <fstream>
  6 +#include <sstream>
  7 +#include <string>
  8 +
  9 +#include "gtest/gtest.h"
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +#include "sherpa-onnx/csrc/utils.h"
  12 +#include "ssentencepiece/csrc/ssentencepiece.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +// Please refer to
  17 +// https://github.com/pkufool/sherpa-test-data
  18 +// to download test data for testing
  19 +static const char dir[] = "/tmp/sherpa-test-data";
  20 +
  21 +TEST(TEXT2TOKEN, TEST_cjkchar) {
  22 + std::ostringstream oss;
  23 + oss << dir << "/text2token/tokens_cn.txt";
  24 +
  25 + std::string tokens = oss.str();
  26 +
  27 + if (!std::ifstream(tokens).good()) {
  28 + SHERPA_ONNX_LOGE(
  29 + "No test data found, skipping TEST_cjkchar()."
  30 + "You can download the test data by: "
  31 + "git clone https://github.com/pkufool/sherpa-test-data.git "
  32 + "/tmp/sherpa-test-data");
  33 + return;
  34 + }
  35 +
  36 + auto sym_table = SymbolTable(tokens);
  37 +
  38 + std::string text = "世界人民大团结\n中国 V S 美国";
  39 +
  40 + std::istringstream iss(text);
  41 +
  42 + std::vector<std::vector<int32_t>> ids;
  43 +
  44 + auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids);
  45 +
  46 + std::vector<std::vector<int32_t>> expected_ids(
  47 + {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
  48 + EXPECT_EQ(ids, expected_ids);
  49 +}
  50 +
  51 +TEST(TEXT2TOKEN, TEST_bpe) {
  52 + std::ostringstream oss;
  53 + oss << dir << "/text2token/tokens_en.txt";
  54 + std::string tokens = oss.str();
  55 + oss.clear();
  56 + oss.str("");
  57 + oss << dir << "/text2token/bpe_en.vocab";
  58 + std::string bpe = oss.str();
  59 + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
  60 + SHERPA_ONNX_LOGE(
  61 + "No test data found, skipping TEST_bpe()."
  62 + "You can download the test data by: "
  63 + "git clone https://github.com/pkufool/sherpa-test-data.git "
  64 + "/tmp/sherpa-test-data");
  65 + return;
  66 + }
  67 +
  68 + auto sym_table = SymbolTable(tokens);
  69 + auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
  70 +
  71 + std::string text = "HELLO WORLD\nI LOVE YOU";
  72 +
  73 + std::istringstream iss(text);
  74 +
  75 + std::vector<std::vector<int32_t>> ids;
  76 +
  77 + auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
  78 +
  79 + std::vector<std::vector<int32_t>> expected_ids(
  80 + {{22, 58, 24, 425}, {19, 370, 47}});
  81 + EXPECT_EQ(ids, expected_ids);
  82 +}
  83 +
  84 +TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
  85 + std::ostringstream oss;
  86 + oss << dir << "/text2token/tokens_mix.txt";
  87 + std::string tokens = oss.str();
  88 + oss.clear();
  89 + oss.str("");
  90 + oss << dir << "/text2token/bpe_mix.vocab";
  91 + std::string bpe = oss.str();
  92 + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
  93 + SHERPA_ONNX_LOGE(
  94 + "No test data found, skipping TEST_cjkchar_bpe()."
  95 + "You can download the test data by: "
  96 + "git clone https://github.com/pkufool/sherpa-test-data.git "
  97 + "/tmp/sherpa-test-data");
  98 + return;
  99 + }
  100 +
  101 + auto sym_table = SymbolTable(tokens);
  102 + auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
  103 +
  104 + std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国";
  105 +
  106 + std::istringstream iss(text);
  107 +
  108 + std::vector<std::vector<int32_t>> ids;
  109 +
  110 + auto r =
  111 + EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids);
  112 +
  113 + std::vector<std::vector<int32_t>> expected_ids(
  114 + {{1368, 1392, 557, 680, 275, 178, 475},
  115 + {685, 736, 275, 178, 179, 921, 736}});
  116 + EXPECT_EQ(ids, expected_ids);
  117 +}
  118 +
  119 +TEST(TEXT2TOKEN, TEST_bbpe) {
  120 + std::ostringstream oss;
  121 + oss << dir << "/text2token/tokens_bbpe.txt";
  122 + std::string tokens = oss.str();
  123 + oss.clear();
  124 + oss.str("");
  125 + oss << dir << "/text2token/bbpe.vocab";
  126 + std::string bpe = oss.str();
  127 + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
  128 + SHERPA_ONNX_LOGE(
  129 + "No test data found, skipping TEST_bbpe()."
  130 + "You can download the test data by: "
  131 + "git clone https://github.com/pkufool/sherpa-test-data.git "
  132 + "/tmp/sherpa-test-data");
  133 + return;
  134 + }
  135 +
  136 + auto sym_table = SymbolTable(tokens);
  137 + auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
  138 +
  139 + std::string text = "频繁\n李鞑靼";
  140 +
  141 + std::istringstream iss(text);
  142 +
  143 + std::vector<std::vector<int32_t>> ids;
  144 +
  145 + auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
  146 +
  147 + std::vector<std::vector<int32_t>> expected_ids(
  148 + {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
  149 + EXPECT_EQ(ids, expected_ids);
  150 +}
  151 +
  152 +} // namespace sherpa_onnx
@@ -4,6 +4,7 @@ @@ -4,6 +4,7 @@
4 4
5 #include "sherpa-onnx/csrc/utils.h" 5 #include "sherpa-onnx/csrc/utils.h"
6 6
  7 +#include <cassert>
7 #include <iostream> 8 #include <iostream>
8 #include <sstream> 9 #include <sstream>
9 #include <string> 10 #include <string>
@@ -12,15 +13,16 @@ @@ -12,15 +13,16 @@
12 13
13 #include "sherpa-onnx/csrc/log.h" 14 #include "sherpa-onnx/csrc/log.h"
14 #include "sherpa-onnx/csrc/macros.h" 15 #include "sherpa-onnx/csrc/macros.h"
  16 +#include "sherpa-onnx/csrc/text-utils.h"
15 17
16 namespace sherpa_onnx { 18 namespace sherpa_onnx {
17 19
18 -static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, 20 +static bool EncodeBase(const std::vector<std::string> &lines,
  21 + const SymbolTable &symbol_table,
19 std::vector<std::vector<int32_t>> *ids, 22 std::vector<std::vector<int32_t>> *ids,
20 std::vector<std::string> *phrases, 23 std::vector<std::string> *phrases,
21 std::vector<float> *scores, 24 std::vector<float> *scores,
22 std::vector<float> *thresholds) { 25 std::vector<float> *thresholds) {
23 - SHERPA_ONNX_CHECK(ids != nullptr);  
24 ids->clear(); 26 ids->clear();
25 27
26 std::vector<int32_t> tmp_ids; 28 std::vector<int32_t> tmp_ids;
@@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, @@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
33 bool has_scores = false; 35 bool has_scores = false;
34 bool has_thresholds = false; 36 bool has_thresholds = false;
35 bool has_phrases = false; 37 bool has_phrases = false;
  38 + bool has_oov = false;
36 39
37 - while (std::getline(is, line)) { 40 + for (const auto &line : lines) {
38 float score = 0; 41 float score = 0;
39 float threshold = 0; 42 float threshold = 0;
40 std::string phrase = ""; 43 std::string phrase = "";
41 44
42 std::istringstream iss(line); 45 std::istringstream iss(line);
43 while (iss >> word) { 46 while (iss >> word) {
44 - if (word.size() >= 3) {  
45 - // For BPE-based models, we replace ▁ with a space  
46 - // Unicode 9601, hex 0x2581, utf8 0xe29681  
47 - const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());  
48 - if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {  
49 - word = word.replace(0, 3, " ");  
50 - }  
51 - }  
52 if (symbol_table.Contains(word)) { 47 if (symbol_table.Contains(word)) {
53 int32_t id = symbol_table[word]; 48 int32_t id = symbol_table[word];
54 tmp_ids.push_back(id); 49 tmp_ids.push_back(id);
@@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, @@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
71 "Cannot find ID for token %s at line: %s. (Hint: words on " 66 "Cannot find ID for token %s at line: %s. (Hint: words on "
72 "the same line are separated by spaces)", 67 "the same line are separated by spaces)",
73 word.c_str(), line.c_str()); 68 word.c_str(), line.c_str());
74 - return false; 69 + has_oov = true;
  70 + break;
75 } 71 }
76 } 72 }
77 } 73 }
@@ -101,12 +97,87 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, @@ -101,12 +97,87 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
101 thresholds->clear(); 97 thresholds->clear();
102 } 98 }
103 } 99 }
104 - return true; 100 + return !has_oov;
105 } 101 }
106 102
107 -bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, 103 +bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
  104 + const SymbolTable &symbol_table,
  105 + const ssentencepiece::Ssentencepiece *bpe_encoder,
108 std::vector<std::vector<int32_t>> *hotwords) { 106 std::vector<std::vector<int32_t>> *hotwords) {
109 - return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr); 107 + std::vector<std::string> lines;
  108 + std::string line;
  109 + std::string word;
  110 +
  111 + while (std::getline(is, line)) {
  112 + std::string score;
  113 + std::string phrase;
  114 +
  115 + std::ostringstream oss;
  116 + std::istringstream iss(line);
  117 + while (iss >> word) {
  118 + switch (word[0]) {
  119 + case ':': // boosting score for current keyword
  120 + score = word;
  121 + break;
  122 + default:
  123 + if (!score.empty()) {
  124 + SHERPA_ONNX_LOGE(
  125 + "Boosting score should be put after the words/phrase, given "
  126 + "%s.",
  127 + line.c_str());
  128 + return false;
  129 + }
  130 + oss << " " << word;
  131 + break;
  132 + }
  133 + }
  134 + phrase = oss.str().substr(1);
  135 + std::istringstream piss(phrase);
  136 + oss.clear();
  137 + oss.str("");
  138 + while (piss >> word) {
  139 + if (modeling_unit == "cjkchar") {
  140 + for (const auto &w : SplitUtf8(word)) {
  141 + oss << " " << w;
  142 + }
  143 + } else if (modeling_unit == "bpe") {
  144 + std::vector<std::string> bpes;
  145 + bpe_encoder->Encode(word, &bpes);
  146 + for (const auto &bpe : bpes) {
  147 + oss << " " << bpe;
  148 + }
  149 + } else {
  150 + if (modeling_unit != "cjkchar+bpe") {
  151 + SHERPA_ONNX_LOGE(
  152 + "modeling_unit should be one of bpe, cjkchar or cjkchar+bpe, "
  153 + "given "
  154 + "%s",
  155 + modeling_unit.c_str());
  156 + exit(-1);
  157 + }
  158 + for (const auto &w : SplitUtf8(word)) {
  159 + if (isalpha(w[0])) {
  160 + std::vector<std::string> bpes;
  161 + bpe_encoder->Encode(w, &bpes);
  162 + for (const auto &bpe : bpes) {
  163 + oss << " " << bpe;
  164 + }
  165 + } else {
  166 + oss << " " << w;
  167 + }
  168 + }
  169 + }
  170 + }
  171 + std::string encoded_phrase = oss.str().substr(1);
  172 + oss.clear();
  173 + oss.str("");
  174 + oss << encoded_phrase;
  175 + if (!score.empty()) {
  176 + oss << " " << score;
  177 + }
  178 + lines.push_back(oss.str());
  179 + }
  180 + return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr);
110 } 181 }
111 182
112 bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, 183 bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
@@ -114,7 +185,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, @@ -114,7 +185,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
114 std::vector<std::string> *keywords, 185 std::vector<std::string> *keywords,
115 std::vector<float> *boost_scores, 186 std::vector<float> *boost_scores,
116 std::vector<float> *threshold) { 187 std::vector<float> *threshold) {
117 - return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores, 188 + std::vector<std::string> lines;
  189 + std::string line;
  190 + while (std::getline(is, line)) {
  191 + lines.push_back(line);
  192 + }
  193 + return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores,
118 threshold); 194 threshold);
119 } 195 }
120 196
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include <vector> 8 #include <vector>
9 9
10 #include "sherpa-onnx/csrc/symbol-table.h" 10 #include "sherpa-onnx/csrc/symbol-table.h"
  11 +#include "ssentencepiece/csrc/ssentencepiece.h"
11 12
12 namespace sherpa_onnx { 13 namespace sherpa_onnx {
13 14
@@ -25,7 +26,9 @@ namespace sherpa_onnx { @@ -25,7 +26,9 @@ namespace sherpa_onnx {
25 * @return If all the symbols from ``is`` are in the symbol_table, returns true 26 * @return If all the symbols from ``is`` are in the symbol_table, returns true
26 * otherwise returns false. 27 * otherwise returns false.
27 */ 28 */
28 -bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, 29 +bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
  30 + const SymbolTable &symbol_table,
  31 + const ssentencepiece::Ssentencepiece *bpe_encoder_,
29 std::vector<std::vector<int32_t>> *hotwords_id); 32 std::vector<std::vector<int32_t>> *hotwords_id);
30 33
31 /* Encode the keywords in an input stream to be tokens ids. 34 /* Encode the keywords in an input stream to be tokens ids.
@@ -76,6 +76,18 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { @@ -76,6 +76,18 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
76 ans.model_config.model_type = p; 76 ans.model_config.model_type = p;
77 env->ReleaseStringUTFChars(s, p); 77 env->ReleaseStringUTFChars(s, p);
78 78
  79 + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;");
  80 + s = (jstring)env->GetObjectField(model_config, fid);
  81 + p = env->GetStringUTFChars(s, nullptr);
  82 + ans.model_config.modeling_unit = p;
  83 + env->ReleaseStringUTFChars(s, p);
  84 +
  85 + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
  86 + s = (jstring)env->GetObjectField(model_config, fid);
  87 + p = env->GetStringUTFChars(s, nullptr);
  88 + ans.model_config.bpe_vocab = p;
  89 + env->ReleaseStringUTFChars(s, p);
  90 +
79 // transducer 91 // transducer
80 fid = env->GetFieldID(model_config_cls, "transducer", 92 fid = env->GetFieldID(model_config_cls, "transducer",
81 "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;"); 93 "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;");
@@ -195,6 +195,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -195,6 +195,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
195 ans.model_config.model_type = p; 195 ans.model_config.model_type = p;
196 env->ReleaseStringUTFChars(s, p); 196 env->ReleaseStringUTFChars(s, p);
197 197
  198 + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;");
  199 + s = (jstring)env->GetObjectField(model_config, fid);
  200 + p = env->GetStringUTFChars(s, nullptr);
  201 + ans.model_config.modeling_unit = p;
  202 + env->ReleaseStringUTFChars(s, p);
  203 +
  204 + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
  205 + s = (jstring)env->GetObjectField(model_config, fid);
  206 + p = env->GetStringUTFChars(s, nullptr);
  207 + ans.model_config.bpe_vocab = p;
  208 + env->ReleaseStringUTFChars(s, p);
  209 +
198 //---------- rnn lm model config ---------- 210 //---------- rnn lm model config ----------
199 fid = env->GetFieldID(cls, "lmConfig", 211 fid = env->GetFieldID(cls, "lmConfig",
200 "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); 212 "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
@@ -40,6 +40,8 @@ data class OfflineModelConfig( @@ -40,6 +40,8 @@ data class OfflineModelConfig(
40 var provider: String = "cpu", 40 var provider: String = "cpu",
41 var modelType: String = "", 41 var modelType: String = "",
42 var tokens: String, 42 var tokens: String,
  43 + var modelingUnit: String = "",
  44 + var bpeVocab: String = "",
43 ) 45 )
44 46
45 data class OfflineRecognizerConfig( 47 data class OfflineRecognizerConfig(
@@ -43,6 +43,8 @@ data class OnlineModelConfig( @@ -43,6 +43,8 @@ data class OnlineModelConfig(
43 var debug: Boolean = false, 43 var debug: Boolean = false,
44 var provider: String = "cpu", 44 var provider: String = "cpu",
45 var modelType: String = "", 45 var modelType: String = "",
  46 + var modelingUnit: String = "",
  47 + var bpeVocab: String = "",
46 ) 48 )
47 49
48 data class OnlineLMConfig( 50 data class OnlineLMConfig(
@@ -36,7 +36,8 @@ void PybindOfflineModelConfig(py::module *m) { @@ -36,7 +36,8 @@ void PybindOfflineModelConfig(py::module *m) {
36 const OfflineTdnnModelConfig &, 36 const OfflineTdnnModelConfig &,
37 const OfflineZipformerCtcModelConfig &, 37 const OfflineZipformerCtcModelConfig &,
38 const OfflineWenetCtcModelConfig &, const std::string &, 38 const OfflineWenetCtcModelConfig &, const std::string &,
39 - int32_t, bool, const std::string &, const std::string &>(), 39 + int32_t, bool, const std::string &, const std::string &,
  40 + const std::string &, const std::string &>(),
40 py::arg("transducer") = OfflineTransducerModelConfig(), 41 py::arg("transducer") = OfflineTransducerModelConfig(),
41 py::arg("paraformer") = OfflineParaformerModelConfig(), 42 py::arg("paraformer") = OfflineParaformerModelConfig(),
42 py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), 43 py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
@@ -45,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) { @@ -45,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) {
45 py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), 46 py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
46 py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), 47 py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
47 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, 48 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
48 - py::arg("provider") = "cpu", py::arg("model_type") = "") 49 + py::arg("provider") = "cpu", py::arg("model_type") = "",
  50 + py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
49 .def_readwrite("transducer", &PyClass::transducer) 51 .def_readwrite("transducer", &PyClass::transducer)
50 .def_readwrite("paraformer", &PyClass::paraformer) 52 .def_readwrite("paraformer", &PyClass::paraformer)
51 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) 53 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
@@ -58,6 +60,8 @@ void PybindOfflineModelConfig(py::module *m) { @@ -58,6 +60,8 @@ void PybindOfflineModelConfig(py::module *m) {
58 .def_readwrite("debug", &PyClass::debug) 60 .def_readwrite("debug", &PyClass::debug)
59 .def_readwrite("provider", &PyClass::provider) 61 .def_readwrite("provider", &PyClass::provider)
60 .def_readwrite("model_type", &PyClass::model_type) 62 .def_readwrite("model_type", &PyClass::model_type)
  63 + .def_readwrite("modeling_unit", &PyClass::modeling_unit)
  64 + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
61 .def("validate", &PyClass::Validate) 65 .def("validate", &PyClass::Validate)
62 .def("__str__", &PyClass::ToString); 66 .def("__str__", &PyClass::ToString);
63 } 67 }
@@ -32,6 +32,7 @@ void PybindOnlineModelConfig(py::module *m) { @@ -32,6 +32,7 @@ void PybindOnlineModelConfig(py::module *m) {
32 const OnlineZipformer2CtcModelConfig &, 32 const OnlineZipformer2CtcModelConfig &,
33 const OnlineNeMoCtcModelConfig &, const std::string &, 33 const OnlineNeMoCtcModelConfig &, const std::string &,
34 int32_t, int32_t, bool, const std::string &, 34 int32_t, int32_t, bool, const std::string &,
  35 + const std::string &, const std::string &,
35 const std::string &>(), 36 const std::string &>(),
36 py::arg("transducer") = OnlineTransducerModelConfig(), 37 py::arg("transducer") = OnlineTransducerModelConfig(),
37 py::arg("paraformer") = OnlineParaformerModelConfig(), 38 py::arg("paraformer") = OnlineParaformerModelConfig(),
@@ -40,7 +41,8 @@ void PybindOnlineModelConfig(py::module *m) { @@ -40,7 +41,8 @@ void PybindOnlineModelConfig(py::module *m) {
40 py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"), 41 py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
41 py::arg("num_threads"), py::arg("warm_up") = 0, 42 py::arg("num_threads"), py::arg("warm_up") = 0,
42 py::arg("debug") = false, py::arg("provider") = "cpu", 43 py::arg("debug") = false, py::arg("provider") = "cpu",
43 - py::arg("model_type") = "") 44 + py::arg("model_type") = "", py::arg("modeling_unit") = "",
  45 + py::arg("bpe_vocab") = "")
44 .def_readwrite("transducer", &PyClass::transducer) 46 .def_readwrite("transducer", &PyClass::transducer)
45 .def_readwrite("paraformer", &PyClass::paraformer) 47 .def_readwrite("paraformer", &PyClass::paraformer)
46 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) 48 .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
@@ -51,6 +53,8 @@ void PybindOnlineModelConfig(py::module *m) { @@ -51,6 +53,8 @@ void PybindOnlineModelConfig(py::module *m) {
51 .def_readwrite("debug", &PyClass::debug) 53 .def_readwrite("debug", &PyClass::debug)
52 .def_readwrite("provider", &PyClass::provider) 54 .def_readwrite("provider", &PyClass::provider)
53 .def_readwrite("model_type", &PyClass::model_type) 55 .def_readwrite("model_type", &PyClass::model_type)
  56 + .def_readwrite("modeling_unit", &PyClass::modeling_unit)
  57 + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
54 .def("validate", &PyClass::Validate) 58 .def("validate", &PyClass::Validate)
55 .def("__str__", &PyClass::ToString); 59 .def("__str__", &PyClass::ToString);
56 } 60 }
@@ -49,6 +49,8 @@ class OfflineRecognizer(object): @@ -49,6 +49,8 @@ class OfflineRecognizer(object):
49 hotwords_file: str = "", 49 hotwords_file: str = "",
50 hotwords_score: float = 1.5, 50 hotwords_score: float = 1.5,
51 blank_penalty: float = 0.0, 51 blank_penalty: float = 0.0,
  52 + modeling_unit: str = "cjkchar",
  53 + bpe_vocab: str = "",
52 debug: bool = False, 54 debug: bool = False,
53 provider: str = "cpu", 55 provider: str = "cpu",
54 model_type: str = "transducer", 56 model_type: str = "transducer",
@@ -91,6 +93,16 @@ class OfflineRecognizer(object): @@ -91,6 +93,16 @@ class OfflineRecognizer(object):
91 hotwords_file is given with modified_beam_search as decoding method. 93 hotwords_file is given with modified_beam_search as decoding method.
92 blank_penalty: 94 blank_penalty:
93 The penalty applied on blank symbol during decoding. 95 The penalty applied on blank symbol during decoding.
  96 + modeling_unit:
  97 + The modeling unit of the model, commonly used units are bpe, cjkchar,
  98 + cjkchar+bpe, etc. Currently, it is needed only when hotwords are
  99 + provided, we need it to encode the hotwords into token sequence.
  100 + and the modeling unit is bpe or cjkchar+bpe.
  101 + bpe_vocab:
  102 + The vocabulary generated by google's sentencepiece program.
  103 + It is a file has two columns, one is the token, the other is
  104 + the log probability, you can get it from the directory where
  105 + your bpe model is generated. Only used when hotwords provided
94 debug: 106 debug:
95 True to show debug messages. 107 True to show debug messages.
96 provider: 108 provider:
@@ -107,6 +119,8 @@ class OfflineRecognizer(object): @@ -107,6 +119,8 @@ class OfflineRecognizer(object):
107 num_threads=num_threads, 119 num_threads=num_threads,
108 debug=debug, 120 debug=debug,
109 provider=provider, 121 provider=provider,
  122 + modeling_unit=modeling_unit,
  123 + bpe_vocab=bpe_vocab,
110 model_type=model_type, 124 model_type=model_type,
111 ) 125 )
112 126
@@ -58,6 +58,8 @@ class OnlineRecognizer(object): @@ -58,6 +58,8 @@ class OnlineRecognizer(object):
58 hotwords_file: str = "", 58 hotwords_file: str = "",
59 provider: str = "cpu", 59 provider: str = "cpu",
60 model_type: str = "", 60 model_type: str = "",
  61 + modeling_unit: str = "cjkchar",
  62 + bpe_vocab: str = "",
61 lm: str = "", 63 lm: str = "",
62 lm_scale: float = 0.1, 64 lm_scale: float = 0.1,
63 temperature_scale: float = 2.0, 65 temperature_scale: float = 2.0,
@@ -136,6 +138,16 @@ class OnlineRecognizer(object): @@ -136,6 +138,16 @@ class OnlineRecognizer(object):
136 model_type: 138 model_type:
137 Online transducer model type. Valid values are: conformer, lstm, 139 Online transducer model type. Valid values are: conformer, lstm,
138 zipformer, zipformer2. All other values lead to loading the model twice. 140 zipformer, zipformer2. All other values lead to loading the model twice.
  141 + modeling_unit:
  142 + The modeling unit of the model, commonly used units are bpe, cjkchar,
  143 + cjkchar+bpe, etc. Currently, it is needed only when hotwords are
  144 + provided, we need it to encode the hotwords into token sequence.
  145 + bpe_vocab:
  146 + The vocabulary generated by google's sentencepiece program.
  147 + It is a file has two columns, one is the token, the other is
  148 + the log probability, you can get it from the directory where
  149 + your bpe model is generated. Only used when hotwords provided
  150 + and the modeling unit is bpe or cjkchar+bpe.
139 """ 151 """
140 self = cls.__new__(cls) 152 self = cls.__new__(cls)
141 _assert_file_exists(tokens) 153 _assert_file_exists(tokens)
@@ -157,6 +169,8 @@ class OnlineRecognizer(object): @@ -157,6 +169,8 @@ class OnlineRecognizer(object):
157 num_threads=num_threads, 169 num_threads=num_threads,
158 provider=provider, 170 provider=provider,
159 model_type=model_type, 171 model_type=model_type,
  172 + modeling_unit=modeling_unit,
  173 + bpe_vocab=bpe_vocab,
160 debug=debug, 174 debug=debug,
161 ) 175 )
162 176