Committed by
GitHub
Encode hotwords in C++ side (#828)
* Encode hotwords in C++ side
正在显示
43 个修改的文件
包含
713 行增加
和
101 行删除
| @@ -8,6 +8,8 @@ log() { | @@ -8,6 +8,8 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 12 | + | ||
| 11 | echo "EXE is $EXE" | 13 | echo "EXE is $EXE" |
| 12 | echo "PATH: $PATH" | 14 | echo "PATH: $PATH" |
| 13 | 15 |
| @@ -8,6 +8,8 @@ log() { | @@ -8,6 +8,8 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 12 | + | ||
| 11 | echo "EXE is $EXE" | 13 | echo "EXE is $EXE" |
| 12 | echo "PATH: $PATH" | 14 | echo "PATH: $PATH" |
| 13 | 15 |
| @@ -8,6 +8,8 @@ log() { | @@ -8,6 +8,8 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 12 | + | ||
| 11 | echo "EXE is $EXE" | 13 | echo "EXE is $EXE" |
| 12 | echo "PATH: $PATH" | 14 | echo "PATH: $PATH" |
| 13 | 15 |
| @@ -8,6 +8,8 @@ log() { | @@ -8,6 +8,8 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 12 | + | ||
| 11 | echo "EXE is $EXE" | 13 | echo "EXE is $EXE" |
| 12 | echo "PATH: $PATH" | 14 | echo "PATH: $PATH" |
| 13 | 15 |
| @@ -8,6 +8,8 @@ log() { | @@ -8,6 +8,8 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 12 | + | ||
| 11 | echo "EXE is $EXE" | 13 | echo "EXE is $EXE" |
| 12 | echo "PATH: $PATH" | 14 | echo "PATH: $PATH" |
| 13 | 15 |
| @@ -8,6 +8,8 @@ log() { | @@ -8,6 +8,8 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 12 | + | ||
| 11 | echo "EXE is $EXE" | 13 | echo "EXE is $EXE" |
| 12 | echo "PATH: $PATH" | 14 | echo "PATH: $PATH" |
| 13 | 15 |
| @@ -8,6 +8,8 @@ log() { | @@ -8,6 +8,8 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 12 | + | ||
| 11 | echo "EXE is $EXE" | 13 | echo "EXE is $EXE" |
| 12 | echo "PATH: $PATH" | 14 | echo "PATH: $PATH" |
| 13 | 15 |
| @@ -8,6 +8,8 @@ log() { | @@ -8,6 +8,8 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 12 | + | ||
| 11 | log "test online NeMo CTC" | 13 | log "test online NeMo CTC" |
| 12 | 14 | ||
| 13 | url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2 | 15 | url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2 |
| @@ -8,6 +8,8 @@ log() { | @@ -8,6 +8,8 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +export GIT_CLONE_PROTECTION_ACTIVE=false | ||
| 12 | + | ||
| 11 | echo "EXE is $EXE" | 13 | echo "EXE is $EXE" |
| 12 | echo "PATH: $PATH" | 14 | echo "PATH: $PATH" |
| 13 | 15 |
| @@ -234,6 +234,7 @@ endif() | @@ -234,6 +234,7 @@ endif() | ||
| 234 | include(kaldi-native-fbank) | 234 | include(kaldi-native-fbank) |
| 235 | include(kaldi-decoder) | 235 | include(kaldi-decoder) |
| 236 | include(onnxruntime) | 236 | include(onnxruntime) |
| 237 | +include(simple-sentencepiece) | ||
| 237 | set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR}) | 238 | set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR}) |
| 238 | message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}") | 239 | message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}") |
| 239 | 240 |
| @@ -126,7 +126,7 @@ echo "Generate xcframework" | @@ -126,7 +126,7 @@ echo "Generate xcframework" | ||
| 126 | 126 | ||
| 127 | mkdir -p "build/simulator/lib" | 127 | mkdir -p "build/simulator/lib" |
| 128 | for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \ | 128 | for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \ |
| 129 | - libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a; do | 129 | + libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a libssentencepiece_core.a; do |
| 130 | lipo -create build/simulator_arm64/lib/${f} \ | 130 | lipo -create build/simulator_arm64/lib/${f} \ |
| 131 | build/simulator_x86_64/lib/${f} \ | 131 | build/simulator_x86_64/lib/${f} \ |
| 132 | -output build/simulator/lib/${f} | 132 | -output build/simulator/lib/${f} |
| @@ -140,7 +140,8 @@ libtool -static -o build/simulator/sherpa-onnx.a \ | @@ -140,7 +140,8 @@ libtool -static -o build/simulator/sherpa-onnx.a \ | ||
| 140 | build/simulator/lib/libsherpa-onnx-core.a \ | 140 | build/simulator/lib/libsherpa-onnx-core.a \ |
| 141 | build/simulator/lib/libsherpa-onnx-fst.a \ | 141 | build/simulator/lib/libsherpa-onnx-fst.a \ |
| 142 | build/simulator/lib/libsherpa-onnx-kaldifst-core.a \ | 142 | build/simulator/lib/libsherpa-onnx-kaldifst-core.a \ |
| 143 | - build/simulator/lib/libkaldi-decoder-core.a | 143 | + build/simulator/lib/libkaldi-decoder-core.a \ |
| 144 | + build/simulator/lib/libssentencepiece_core.a | ||
| 144 | 145 | ||
| 145 | libtool -static -o build/os64/sherpa-onnx.a \ | 146 | libtool -static -o build/os64/sherpa-onnx.a \ |
| 146 | build/os64/lib/libkaldi-native-fbank-core.a \ | 147 | build/os64/lib/libkaldi-native-fbank-core.a \ |
| @@ -148,7 +149,8 @@ libtool -static -o build/os64/sherpa-onnx.a \ | @@ -148,7 +149,8 @@ libtool -static -o build/os64/sherpa-onnx.a \ | ||
| 148 | build/os64/lib/libsherpa-onnx-core.a \ | 149 | build/os64/lib/libsherpa-onnx-core.a \ |
| 149 | build/os64/lib/libsherpa-onnx-fst.a \ | 150 | build/os64/lib/libsherpa-onnx-fst.a \ |
| 150 | build/os64/lib/libsherpa-onnx-kaldifst-core.a \ | 151 | build/os64/lib/libsherpa-onnx-kaldifst-core.a \ |
| 151 | - build/os64/lib/libkaldi-decoder-core.a | 152 | + build/os64/lib/libkaldi-decoder-core.a \ |
| 153 | + build/os64/lib/libssentencepiece_core.a | ||
| 152 | 154 | ||
| 153 | rm -rf sherpa-onnx.xcframework | 155 | rm -rf sherpa-onnx.xcframework |
| 154 | 156 |
| @@ -129,7 +129,7 @@ echo "Generate xcframework" | @@ -129,7 +129,7 @@ echo "Generate xcframework" | ||
| 129 | 129 | ||
| 130 | mkdir -p "build/simulator/lib" | 130 | mkdir -p "build/simulator/lib" |
| 131 | for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \ | 131 | for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \ |
| 132 | - libsherpa-onnx-fstfar.a \ | 132 | + libsherpa-onnx-fstfar.a libssentencepiece_core.a \ |
| 133 | libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a \ | 133 | libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a \ |
| 134 | libucd.a libpiper_phonemize.a libespeak-ng.a; do | 134 | libucd.a libpiper_phonemize.a libespeak-ng.a; do |
| 135 | lipo -create build/simulator_arm64/lib/${f} \ | 135 | lipo -create build/simulator_arm64/lib/${f} \ |
| @@ -150,6 +150,7 @@ libtool -static -o build/simulator/sherpa-onnx.a \ | @@ -150,6 +150,7 @@ libtool -static -o build/simulator/sherpa-onnx.a \ | ||
| 150 | build/simulator/lib/libucd.a \ | 150 | build/simulator/lib/libucd.a \ |
| 151 | build/simulator/lib/libpiper_phonemize.a \ | 151 | build/simulator/lib/libpiper_phonemize.a \ |
| 152 | build/simulator/lib/libespeak-ng.a \ | 152 | build/simulator/lib/libespeak-ng.a \ |
| 153 | + build/simulator/lib/libssentencepiece_core.a | ||
| 153 | 154 | ||
| 154 | libtool -static -o build/os64/sherpa-onnx.a \ | 155 | libtool -static -o build/os64/sherpa-onnx.a \ |
| 155 | build/os64/lib/libkaldi-native-fbank-core.a \ | 156 | build/os64/lib/libkaldi-native-fbank-core.a \ |
| @@ -162,6 +163,7 @@ libtool -static -o build/os64/sherpa-onnx.a \ | @@ -162,6 +163,7 @@ libtool -static -o build/os64/sherpa-onnx.a \ | ||
| 162 | build/os64/lib/libucd.a \ | 163 | build/os64/lib/libucd.a \ |
| 163 | build/os64/lib/libpiper_phonemize.a \ | 164 | build/os64/lib/libpiper_phonemize.a \ |
| 164 | build/os64/lib/libespeak-ng.a \ | 165 | build/os64/lib/libespeak-ng.a \ |
| 166 | + build/os64/lib/libssentencepiece_core.a | ||
| 165 | 167 | ||
| 166 | 168 | ||
| 167 | rm -rf sherpa-onnx.xcframework | 169 | rm -rf sherpa-onnx.xcframework |
| @@ -33,4 +33,5 @@ libtool -static -o ./install/lib/libsherpa-onnx.a \ | @@ -33,4 +33,5 @@ libtool -static -o ./install/lib/libsherpa-onnx.a \ | ||
| 33 | ./install/lib/libkaldi-decoder-core.a \ | 33 | ./install/lib/libkaldi-decoder-core.a \ |
| 34 | ./install/lib/libucd.a \ | 34 | ./install/lib/libucd.a \ |
| 35 | ./install/lib/libpiper_phonemize.a \ | 35 | ./install/lib/libpiper_phonemize.a \ |
| 36 | - ./install/lib/libespeak-ng.a | 36 | + ./install/lib/libespeak-ng.a \ |
| 37 | + ./install/lib/libssentencepiece_core.a |
cmake/simple-sentencepiece.cmake
0 → 100644
| 1 | +function(download_simple_sentencepiece) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + set(simple-sentencepiece_URL "https://github.com/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz") | ||
| 5 | + set(simple-sentencepiece_URL2 "https://hub.nauu.cf/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz") | ||
| 6 | + set(simple-sentencepiece_HASH "SHA256=1748a822060a35baa9f6609f84efc8eb54dc0e74b9ece3d82367b7119fdc75af") | ||
| 7 | + | ||
| 8 | + # If you don't have access to the Internet, | ||
| 9 | + # please pre-download simple-sentencepiece | ||
| 10 | + set(possible_file_locations | ||
| 11 | + $ENV{HOME}/Downloads/simple-sentencepiece-0.7.tar.gz | ||
| 12 | + ${CMAKE_SOURCE_DIR}/simple-sentencepiece-0.7.tar.gz | ||
| 13 | + ${CMAKE_BINARY_DIR}/simple-sentencepiece-0.7.tar.gz | ||
| 14 | + /tmp/simple-sentencepiece-0.7.tar.gz | ||
| 15 | + /star-fj/fangjun/download/github/simple-sentencepiece-0.7.tar.gz | ||
| 16 | + ) | ||
| 17 | + | ||
| 18 | + foreach(f IN LISTS possible_file_locations) | ||
| 19 | + if(EXISTS ${f}) | ||
| 20 | + set(simple-sentencepiece_URL "${f}") | ||
| 21 | + file(TO_CMAKE_PATH "${simple-sentencepiece_URL}" simple-sentencepiece_URL) | ||
| 22 | + message(STATUS "Found local downloaded simple-sentencepiece: ${simple-sentencepiece_URL}") | ||
| 23 | + set(simple-sentencepiece_URL2) | ||
| 24 | + break() | ||
| 25 | + endif() | ||
| 26 | + endforeach() | ||
| 27 | + | ||
| 28 | + set(SBPE_ENABLE_TESTS OFF CACHE BOOL "" FORCE) | ||
| 29 | + set(SBPE_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | ||
| 30 | + | ||
| 31 | + FetchContent_Declare(simple-sentencepiece | ||
| 32 | + URL | ||
| 33 | + ${simple-sentencepiece_URL} | ||
| 34 | + ${simple-sentencepiece_URL2} | ||
| 35 | + URL_HASH | ||
| 36 | + ${simple-sentencepiece_HASH} | ||
| 37 | + ) | ||
| 38 | + | ||
| 39 | + FetchContent_GetProperties(simple-sentencepiece) | ||
| 40 | + if(NOT simple-sentencepiece_POPULATED) | ||
| 41 | + message(STATUS "Downloading simple-sentencepiece ${simple-sentencepiece_URL}") | ||
| 42 | + FetchContent_Populate(simple-sentencepiece) | ||
| 43 | + endif() | ||
| 44 | + message(STATUS "simple-sentencepiece is downloaded to ${simple-sentencepiece_SOURCE_DIR}") | ||
| 45 | + add_subdirectory(${simple-sentencepiece_SOURCE_DIR} ${simple-sentencepiece_BINARY_DIR} EXCLUDE_FROM_ALL) | ||
| 46 | + | ||
| 47 | + target_include_directories(ssentencepiece_core | ||
| 48 | + PUBLIC | ||
| 49 | + ${simple-sentencepiece_SOURCE_DIR}/ | ||
| 50 | + ) | ||
| 51 | + | ||
| 52 | + if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32) | ||
| 53 | + install(TARGETS ssentencepiece_core DESTINATION ..) | ||
| 54 | + else() | ||
| 55 | + install(TARGETS ssentencepiece_core DESTINATION lib) | ||
| 56 | + endif() | ||
| 57 | + | ||
| 58 | + if(WIN32 AND BUILD_SHARED_LIBS) | ||
| 59 | + install(TARGETS ssentencepiece_core DESTINATION bin) | ||
| 60 | + endif() | ||
| 61 | +endfunction() | ||
| 62 | + | ||
| 63 | +download_simple_sentencepiece() |
| @@ -60,7 +60,7 @@ function testSpeakerEmbeddingExtractor() { | @@ -60,7 +60,7 @@ function testSpeakerEmbeddingExtractor() { | ||
| 60 | function testOnlineAsr() { | 60 | function testOnlineAsr() { |
| 61 | if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then | 61 | if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then |
| 62 | git lfs install | 62 | git lfs install |
| 63 | - git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 | 63 | + GIT_CLONE_PROTECTION_ACTIVE=false git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 |
| 64 | fi | 64 | fi |
| 65 | 65 | ||
| 66 | if [ ! -f ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt ]; then | 66 | if [ ! -f ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt ]; then |
| @@ -18,6 +18,7 @@ | @@ -18,6 +18,7 @@ | ||
| 18 | piper_phonemize.lib; | 18 | piper_phonemize.lib; |
| 19 | espeak-ng.lib; | 19 | espeak-ng.lib; |
| 20 | ucd.lib; | 20 | ucd.lib; |
| 21 | + ssentencepiece_core.lib; | ||
| 21 | </SherpaOnnxLibraries> | 22 | </SherpaOnnxLibraries> |
| 22 | </PropertyGroup> | 23 | </PropertyGroup> |
| 23 | <ItemDefinitionGroup> | 24 | <ItemDefinitionGroup> |
| @@ -18,6 +18,7 @@ | @@ -18,6 +18,7 @@ | ||
| 18 | piper_phonemize.lib; | 18 | piper_phonemize.lib; |
| 19 | espeak-ng.lib; | 19 | espeak-ng.lib; |
| 20 | ucd.lib; | 20 | ucd.lib; |
| 21 | + ssentencepiece_core.lib; | ||
| 21 | </SherpaOnnxLibraries> | 22 | </SherpaOnnxLibraries> |
| 22 | </PropertyGroup> | 23 | </PropertyGroup> |
| 23 | <ItemDefinitionGroup> | 24 | <ItemDefinitionGroup> |
| @@ -18,6 +18,7 @@ | @@ -18,6 +18,7 @@ | ||
| 18 | piper_phonemize.lib; | 18 | piper_phonemize.lib; |
| 19 | espeak-ng.lib; | 19 | espeak-ng.lib; |
| 20 | ucd.lib; | 20 | ucd.lib; |
| 21 | + ssentencepiece_core.lib; | ||
| 21 | </SherpaOnnxLibraries> | 22 | </SherpaOnnxLibraries> |
| 22 | </PropertyGroup> | 23 | </PropertyGroup> |
| 23 | <ItemDefinitionGroup> | 24 | <ItemDefinitionGroup> |
| @@ -110,11 +110,9 @@ def get_args(): | @@ -110,11 +110,9 @@ def get_args(): | ||
| 110 | type=str, | 110 | type=str, |
| 111 | default="", | 111 | default="", |
| 112 | help=""" | 112 | help=""" |
| 113 | - The file containing hotwords, one words/phrases per line, and for each | ||
| 114 | - phrase the bpe/cjkchar are separated by a space. For example: | ||
| 115 | - | ||
| 116 | - ▁HE LL O ▁WORLD | ||
| 117 | - 你 好 世 界 | 113 | + The file containing hotwords, one words/phrases per line, like |
| 114 | + HELLO WORLD | ||
| 115 | + 你好世界 | ||
| 118 | """, | 116 | """, |
| 119 | ) | 117 | ) |
| 120 | 118 | ||
| @@ -129,6 +127,28 @@ def get_args(): | @@ -129,6 +127,28 @@ def get_args(): | ||
| 129 | ) | 127 | ) |
| 130 | 128 | ||
| 131 | parser.add_argument( | 129 | parser.add_argument( |
| 130 | + "--modeling-unit", | ||
| 131 | + type=str, | ||
| 132 | + default="", | ||
| 133 | + help=""" | ||
| 134 | + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. | ||
| 135 | + Used only when hotwords-file is given. | ||
| 136 | + """, | ||
| 137 | + ) | ||
| 138 | + | ||
| 139 | + parser.add_argument( | ||
| 140 | + "--bpe-vocab", | ||
| 141 | + type=str, | ||
| 142 | + default="", | ||
| 143 | + help=""" | ||
| 144 | + The path to the bpe vocabulary, the bpe vocabulary is generated by | ||
| 145 | + sentencepiece, you can also export the bpe vocabulary through a bpe model | ||
| 146 | + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given | ||
| 147 | + and modeling-unit is bpe or cjkchar+bpe. | ||
| 148 | + """, | ||
| 149 | + ) | ||
| 150 | + | ||
| 151 | + parser.add_argument( | ||
| 132 | "--encoder", | 152 | "--encoder", |
| 133 | default="", | 153 | default="", |
| 134 | type=str, | 154 | type=str, |
| @@ -347,6 +367,8 @@ def main(): | @@ -347,6 +367,8 @@ def main(): | ||
| 347 | decoding_method=args.decoding_method, | 367 | decoding_method=args.decoding_method, |
| 348 | hotwords_file=args.hotwords_file, | 368 | hotwords_file=args.hotwords_file, |
| 349 | hotwords_score=args.hotwords_score, | 369 | hotwords_score=args.hotwords_score, |
| 370 | + modeling_unit=args.modeling_unit, | ||
| 371 | + bpe_vocab=args.bpe_vocab, | ||
| 350 | blank_penalty=args.blank_penalty, | 372 | blank_penalty=args.blank_penalty, |
| 351 | debug=args.debug, | 373 | debug=args.debug, |
| 352 | ) | 374 | ) |
| @@ -198,11 +198,9 @@ def get_args(): | @@ -198,11 +198,9 @@ def get_args(): | ||
| 198 | type=str, | 198 | type=str, |
| 199 | default="", | 199 | default="", |
| 200 | help=""" | 200 | help=""" |
| 201 | - The file containing hotwords, one words/phrases per line, and for each | ||
| 202 | - phrase the bpe/cjkchar are separated by a space. For example: | ||
| 203 | - | ||
| 204 | - ▁HE LL O ▁WORLD | ||
| 205 | - 你 好 世 界 | 201 | + The file containing hotwords, one words/phrases per line, like |
| 202 | + HELLO WORLD | ||
| 203 | + 你好世界 | ||
| 206 | """, | 204 | """, |
| 207 | ) | 205 | ) |
| 208 | 206 | ||
| @@ -217,6 +215,28 @@ def get_args(): | @@ -217,6 +215,28 @@ def get_args(): | ||
| 217 | ) | 215 | ) |
| 218 | 216 | ||
| 219 | parser.add_argument( | 217 | parser.add_argument( |
| 218 | + "--modeling-unit", | ||
| 219 | + type=str, | ||
| 220 | + default="", | ||
| 221 | + help=""" | ||
| 222 | + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. | ||
| 223 | + Used only when hotwords-file is given. | ||
| 224 | + """, | ||
| 225 | + ) | ||
| 226 | + | ||
| 227 | + parser.add_argument( | ||
| 228 | + "--bpe-vocab", | ||
| 229 | + type=str, | ||
| 230 | + default="", | ||
| 231 | + help=""" | ||
| 232 | + The path to the bpe vocabulary, the bpe vocabulary is generated by | ||
| 233 | + sentencepiece, you can also export the bpe vocabulary through a bpe model | ||
| 234 | + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given | ||
| 235 | + and modeling-unit is bpe or cjkchar+bpe. | ||
| 236 | + """, | ||
| 237 | + ) | ||
| 238 | + | ||
| 239 | + parser.add_argument( | ||
| 220 | "--blank-penalty", | 240 | "--blank-penalty", |
| 221 | type=float, | 241 | type=float, |
| 222 | default=0.0, | 242 | default=0.0, |
| @@ -302,6 +322,8 @@ def main(): | @@ -302,6 +322,8 @@ def main(): | ||
| 302 | lm_scale=args.lm_scale, | 322 | lm_scale=args.lm_scale, |
| 303 | hotwords_file=args.hotwords_file, | 323 | hotwords_file=args.hotwords_file, |
| 304 | hotwords_score=args.hotwords_score, | 324 | hotwords_score=args.hotwords_score, |
| 325 | + modeling_unit=args.modeling_unit, | ||
| 326 | + bpe_vocab=args.bpe_vocab, | ||
| 305 | blank_penalty=args.blank_penalty, | 327 | blank_penalty=args.blank_penalty, |
| 306 | ) | 328 | ) |
| 307 | elif args.zipformer2_ctc: | 329 | elif args.zipformer2_ctc: |
scripts/export_bpe_vocab.py
0 → 100755
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang) | ||
| 3 | +# | ||
| 4 | +# See ../../../../LICENSE for clarification regarding multiple authors | ||
| 5 | +# | ||
| 6 | +# Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 7 | +# you may not use this file except in compliance with the License. | ||
| 8 | +# You may obtain a copy of the License at | ||
| 9 | +# | ||
| 10 | +# http://www.apache.org/licenses/LICENSE-2.0 | ||
| 11 | +# | ||
| 12 | +# Unless required by applicable law or agreed to in writing, software | ||
| 13 | +# distributed under the License is distributed on an "AS IS" BASIS, | ||
| 14 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 15 | +# See the License for the specific language governing permissions and | ||
| 16 | +# limitations under the License. | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +# You can install sentencepiece via: | ||
| 20 | +# | ||
| 21 | +# pip install sentencepiece | ||
| 22 | +# | ||
| 23 | +# Due to an issue reported in | ||
| 24 | +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 | ||
| 25 | +# | ||
| 26 | +# Please install a version >=0.1.96 | ||
| 27 | + | ||
| 28 | +import argparse | ||
| 29 | +from typing import Dict | ||
| 30 | + | ||
| 31 | +try: | ||
| 32 | + import sentencepiece as spm | ||
| 33 | +except ImportError: | ||
| 34 | + print('Please run') | ||
| 35 | + print(' pip install sentencepiece') | ||
| 36 | + print('before you continue') | ||
| 37 | + raise | ||
| 38 | + | ||
| 39 | + | ||
| 40 | +def get_args(): | ||
| 41 | + parser = argparse.ArgumentParser() | ||
| 42 | + parser.add_argument( | ||
| 43 | + "--bpe-model", | ||
| 44 | + type=str, | ||
| 45 | + help="The path to the bpe model.", | ||
| 46 | + ) | ||
| 47 | + | ||
| 48 | + return parser.parse_args() | ||
| 49 | + | ||
| 50 | + | ||
| 51 | +def main(): | ||
| 52 | + args = get_args() | ||
| 53 | + model_file = args.bpe_model | ||
| 54 | + | ||
| 55 | + vocab_file = model_file.replace(".model", ".vocab") | ||
| 56 | + | ||
| 57 | + sp = spm.SentencePieceProcessor() | ||
| 58 | + sp.Load(model_file) | ||
| 59 | + vocabs = [sp.IdToPiece(id) for id in range(sp.GetPieceSize())] | ||
| 60 | + with open(vocab_file, "w") as vfile: | ||
| 61 | + for v in vocabs: | ||
| 62 | + id = sp.PieceToId(v) | ||
| 63 | + vfile.write(f"{v}\t{sp.GetScore(id)}\n") | ||
| 64 | + print(f"Vocabulary file is written to {vocab_file}") | ||
| 65 | + | ||
| 66 | + | ||
| 67 | +if __name__ == "__main__": | ||
| 68 | + main() |
| @@ -165,6 +165,7 @@ endif() | @@ -165,6 +165,7 @@ endif() | ||
| 165 | target_link_libraries(sherpa-onnx-core | 165 | target_link_libraries(sherpa-onnx-core |
| 166 | kaldi-native-fbank-core | 166 | kaldi-native-fbank-core |
| 167 | kaldi-decoder-core | 167 | kaldi-decoder-core |
| 168 | + ssentencepiece_core | ||
| 168 | ) | 169 | ) |
| 169 | 170 | ||
| 170 | if(SHERPA_ONNX_ENABLE_GPU) | 171 | if(SHERPA_ONNX_ENABLE_GPU) |
| @@ -491,6 +492,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) | @@ -491,6 +492,7 @@ if(SHERPA_ONNX_ENABLE_TESTS) | ||
| 491 | pad-sequence-test.cc | 492 | pad-sequence-test.cc |
| 492 | slice-test.cc | 493 | slice-test.cc |
| 493 | stack-test.cc | 494 | stack-test.cc |
| 495 | + text2token-test.cc | ||
| 494 | transpose-test.cc | 496 | transpose-test.cc |
| 495 | unbind-test.cc | 497 | unbind-test.cc |
| 496 | utfcpp-test.cc | 498 | utfcpp-test.cc |
| @@ -35,6 +35,17 @@ void OfflineModelConfig::Register(ParseOptions *po) { | @@ -35,6 +35,17 @@ void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 35 | "Valid values are: transducer, paraformer, nemo_ctc, whisper, " | 35 | "Valid values are: transducer, paraformer, nemo_ctc, whisper, " |
| 36 | "tdnn, zipformer2_ctc" | 36 | "tdnn, zipformer2_ctc" |
| 37 | "All other values lead to loading the model twice."); | 37 | "All other values lead to loading the model twice."); |
| 38 | + po->Register("modeling-unit", &modeling_unit, | ||
| 39 | + "The modeling unit of the model, commonly used units are bpe, " | ||
| 40 | + "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when " | ||
| 41 | + "hotwords are provided, we need it to encode the hotwords into " | ||
| 42 | + "token sequence."); | ||
| 43 | + po->Register("bpe-vocab", &bpe_vocab, | ||
| 44 | + "The vocabulary generated by google's sentencepiece program. " | ||
| 45 | + "It is a file has two columns, one is the token, the other is " | ||
| 46 | + "the log probability, you can get it from the directory where " | ||
| 47 | + "your bpe model is generated. Only used when hotwords provided " | ||
| 48 | + "and the modeling unit is bpe or cjkchar+bpe"); | ||
| 38 | } | 49 | } |
| 39 | 50 | ||
| 40 | bool OfflineModelConfig::Validate() const { | 51 | bool OfflineModelConfig::Validate() const { |
| @@ -48,6 +59,14 @@ bool OfflineModelConfig::Validate() const { | @@ -48,6 +59,14 @@ bool OfflineModelConfig::Validate() const { | ||
| 48 | return false; | 59 | return false; |
| 49 | } | 60 | } |
| 50 | 61 | ||
| 62 | + if (!modeling_unit.empty() && | ||
| 63 | + (modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) { | ||
| 64 | + if (!FileExists(bpe_vocab)) { | ||
| 65 | + SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str()); | ||
| 66 | + return false; | ||
| 67 | + } | ||
| 68 | + } | ||
| 69 | + | ||
| 51 | if (!paraformer.model.empty()) { | 70 | if (!paraformer.model.empty()) { |
| 52 | return paraformer.Validate(); | 71 | return paraformer.Validate(); |
| 53 | } | 72 | } |
| @@ -90,7 +109,9 @@ std::string OfflineModelConfig::ToString() const { | @@ -90,7 +109,9 @@ std::string OfflineModelConfig::ToString() const { | ||
| 90 | os << "num_threads=" << num_threads << ", "; | 109 | os << "num_threads=" << num_threads << ", "; |
| 91 | os << "debug=" << (debug ? "True" : "False") << ", "; | 110 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| 92 | os << "provider=\"" << provider << "\", "; | 111 | os << "provider=\"" << provider << "\", "; |
| 93 | - os << "model_type=\"" << model_type << "\")"; | 112 | + os << "model_type=\"" << model_type << "\", "; |
| 113 | + os << "modeling_unit=\"" << modeling_unit << "\", "; | ||
| 114 | + os << "bpe_vocab=\"" << bpe_vocab << "\")"; | ||
| 94 | 115 | ||
| 95 | return os.str(); | 116 | return os.str(); |
| 96 | } | 117 | } |
| @@ -41,6 +41,9 @@ struct OfflineModelConfig { | @@ -41,6 +41,9 @@ struct OfflineModelConfig { | ||
| 41 | // All other values are invalid and lead to loading the model twice. | 41 | // All other values are invalid and lead to loading the model twice. |
| 42 | std::string model_type; | 42 | std::string model_type; |
| 43 | 43 | ||
| 44 | + std::string modeling_unit = "cjkchar"; | ||
| 45 | + std::string bpe_vocab; | ||
| 46 | + | ||
| 44 | OfflineModelConfig() = default; | 47 | OfflineModelConfig() = default; |
| 45 | OfflineModelConfig(const OfflineTransducerModelConfig &transducer, | 48 | OfflineModelConfig(const OfflineTransducerModelConfig &transducer, |
| 46 | const OfflineParaformerModelConfig &paraformer, | 49 | const OfflineParaformerModelConfig &paraformer, |
| @@ -50,7 +53,9 @@ struct OfflineModelConfig { | @@ -50,7 +53,9 @@ struct OfflineModelConfig { | ||
| 50 | const OfflineZipformerCtcModelConfig &zipformer_ctc, | 53 | const OfflineZipformerCtcModelConfig &zipformer_ctc, |
| 51 | const OfflineWenetCtcModelConfig &wenet_ctc, | 54 | const OfflineWenetCtcModelConfig &wenet_ctc, |
| 52 | const std::string &tokens, int32_t num_threads, bool debug, | 55 | const std::string &tokens, int32_t num_threads, bool debug, |
| 53 | - const std::string &provider, const std::string &model_type) | 56 | + const std::string &provider, const std::string &model_type, |
| 57 | + const std::string &modeling_unit, | ||
| 58 | + const std::string &bpe_vocab) | ||
| 54 | : transducer(transducer), | 59 | : transducer(transducer), |
| 55 | paraformer(paraformer), | 60 | paraformer(paraformer), |
| 56 | nemo_ctc(nemo_ctc), | 61 | nemo_ctc(nemo_ctc), |
| @@ -62,7 +67,9 @@ struct OfflineModelConfig { | @@ -62,7 +67,9 @@ struct OfflineModelConfig { | ||
| 62 | num_threads(num_threads), | 67 | num_threads(num_threads), |
| 63 | debug(debug), | 68 | debug(debug), |
| 64 | provider(provider), | 69 | provider(provider), |
| 65 | - model_type(model_type) {} | 70 | + model_type(model_type), |
| 71 | + modeling_unit(modeling_unit), | ||
| 72 | + bpe_vocab(bpe_vocab) {} | ||
| 66 | 73 | ||
| 67 | void Register(ParseOptions *po); | 74 | void Register(ParseOptions *po); |
| 68 | bool Validate() const; | 75 | bool Validate() const; |
| @@ -31,6 +31,7 @@ | @@ -31,6 +31,7 @@ | ||
| 31 | #include "sherpa-onnx/csrc/pad-sequence.h" | 31 | #include "sherpa-onnx/csrc/pad-sequence.h" |
| 32 | #include "sherpa-onnx/csrc/symbol-table.h" | 32 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 33 | #include "sherpa-onnx/csrc/utils.h" | 33 | #include "sherpa-onnx/csrc/utils.h" |
| 34 | +#include "ssentencepiece/csrc/ssentencepiece.h" | ||
| 34 | 35 | ||
| 35 | namespace sherpa_onnx { | 36 | namespace sherpa_onnx { |
| 36 | 37 | ||
| @@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 76 | : config_(config), | 77 | : config_(config), |
| 77 | symbol_table_(config_.model_config.tokens), | 78 | symbol_table_(config_.model_config.tokens), |
| 78 | model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { | 79 | model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { |
| 79 | - if (!config_.hotwords_file.empty()) { | ||
| 80 | - InitHotwords(); | ||
| 81 | - } | ||
| 82 | if (config_.decoding_method == "greedy_search") { | 80 | if (config_.decoding_method == "greedy_search") { |
| 83 | decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( | 81 | decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( |
| 84 | model_.get(), config_.blank_penalty); | 82 | model_.get(), config_.blank_penalty); |
| @@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 87 | lm_ = OfflineLM::Create(config.lm_config); | 85 | lm_ = OfflineLM::Create(config.lm_config); |
| 88 | } | 86 | } |
| 89 | 87 | ||
| 88 | + if (!config_.model_config.bpe_vocab.empty()) { | ||
| 89 | + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>( | ||
| 90 | + config_.model_config.bpe_vocab); | ||
| 91 | + } | ||
| 92 | + | ||
| 93 | + if (!config_.hotwords_file.empty()) { | ||
| 94 | + InitHotwords(); | ||
| 95 | + } | ||
| 96 | + | ||
| 90 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( | 97 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( |
| 91 | model_.get(), lm_.get(), config_.max_active_paths, | 98 | model_.get(), lm_.get(), config_.max_active_paths, |
| 92 | config_.lm_config.scale, config_.blank_penalty); | 99 | config_.lm_config.scale, config_.blank_penalty); |
| @@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 112 | lm_ = OfflineLM::Create(mgr, config.lm_config); | 119 | lm_ = OfflineLM::Create(mgr, config.lm_config); |
| 113 | } | 120 | } |
| 114 | 121 | ||
| 122 | + if (!config_.model_config.bpe_vocab.empty()) { | ||
| 123 | + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab); | ||
| 124 | + std::istringstream iss(std::string(buf.begin(), buf.end())); | ||
| 125 | + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss); | ||
| 126 | + } | ||
| 127 | + | ||
| 128 | + if (!config_.hotwords_file.empty()) { | ||
| 129 | + InitHotwords(mgr); | ||
| 130 | + } | ||
| 131 | + | ||
| 115 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( | 132 | decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( |
| 116 | model_.get(), lm_.get(), config_.max_active_paths, | 133 | model_.get(), lm_.get(), config_.max_active_paths, |
| 117 | config_.lm_config.scale, config_.blank_penalty); | 134 | config_.lm_config.scale, config_.blank_penalty); |
| @@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 128 | auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); | 145 | auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); |
| 129 | std::istringstream is(hws); | 146 | std::istringstream is(hws); |
| 130 | std::vector<std::vector<int32_t>> current; | 147 | std::vector<std::vector<int32_t>> current; |
| 131 | - if (!EncodeHotwords(is, symbol_table_, ¤t)) { | 148 | + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, |
| 149 | + bpe_encoder_.get(), ¤t)) { | ||
| 132 | SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", | 150 | SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", |
| 133 | hotwords.c_str()); | 151 | hotwords.c_str()); |
| 134 | } | 152 | } |
| @@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 207 | exit(-1); | 225 | exit(-1); |
| 208 | } | 226 | } |
| 209 | 227 | ||
| 210 | - if (!EncodeHotwords(is, symbol_table_, &hotwords_)) { | ||
| 211 | - SHERPA_ONNX_LOGE("Encode hotwords failed."); | 228 | + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, |
| 229 | + bpe_encoder_.get(), &hotwords_)) { | ||
| 230 | + SHERPA_ONNX_LOGE( | ||
| 231 | + "Failed to encode some hotwords, skip them already, see logs above " | ||
| 232 | + "for details."); | ||
| 233 | + } | ||
| 234 | + hotwords_graph_ = | ||
| 235 | + std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | ||
| 236 | + } | ||
| 237 | + | ||
| 238 | +#if __ANDROID_API__ >= 9 | ||
| 239 | + void InitHotwords(AAssetManager *mgr) { | ||
| 240 | + // each line in hotwords_file contains space-separated words | ||
| 241 | + | ||
| 242 | + auto buf = ReadFile(mgr, config_.hotwords_file); | ||
| 243 | + | ||
| 244 | + std::istringstream is(std::string(buf.begin(), buf.end())); | ||
| 245 | + | ||
| 246 | + if (!is) { | ||
| 247 | + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", | ||
| 248 | + config_.hotwords_file.c_str()); | ||
| 212 | exit(-1); | 249 | exit(-1); |
| 213 | } | 250 | } |
| 251 | + | ||
| 252 | + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, | ||
| 253 | + bpe_encoder_.get(), &hotwords_)) { | ||
| 254 | + SHERPA_ONNX_LOGE( | ||
| 255 | + "Failed to encode some hotwords, skip them already, see logs above " | ||
| 256 | + "for details."); | ||
| 257 | + } | ||
| 214 | hotwords_graph_ = | 258 | hotwords_graph_ = |
| 215 | std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | 259 | std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); |
| 216 | } | 260 | } |
| 261 | +#endif | ||
| 217 | 262 | ||
| 218 | private: | 263 | private: |
| 219 | OfflineRecognizerConfig config_; | 264 | OfflineRecognizerConfig config_; |
| 220 | SymbolTable symbol_table_; | 265 | SymbolTable symbol_table_; |
| 221 | std::vector<std::vector<int32_t>> hotwords_; | 266 | std::vector<std::vector<int32_t>> hotwords_; |
| 222 | ContextGraphPtr hotwords_graph_; | 267 | ContextGraphPtr hotwords_graph_; |
| 268 | + std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_; | ||
| 223 | std::unique_ptr<OfflineTransducerModel> model_; | 269 | std::unique_ptr<OfflineTransducerModel> model_; |
| 224 | std::unique_ptr<OfflineTransducerDecoder> decoder_; | 270 | std::unique_ptr<OfflineTransducerDecoder> decoder_; |
| 225 | std::unique_ptr<OfflineLM> lm_; | 271 | std::unique_ptr<OfflineLM> lm_; |
| @@ -37,10 +37,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | @@ -37,10 +37,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | ||
| 37 | 37 | ||
| 38 | po->Register( | 38 | po->Register( |
| 39 | "hotwords-file", &hotwords_file, | 39 | "hotwords-file", &hotwords_file, |
| 40 | - "The file containing hotwords, one words/phrases per line, and for each" | ||
| 41 | - "phrase the bpe/cjkchar are separated by a space. For example: " | ||
| 42 | - "▁HE LL O ▁WORLD" | ||
| 43 | - "你 好 世 界"); | 40 | + "The file containing hotwords, one words/phrases per line, For example: " |
| 41 | + "HELLO WORLD" | ||
| 42 | + "你好世界"); | ||
| 44 | 43 | ||
| 45 | po->Register("hotwords-score", &hotwords_score, | 44 | po->Register("hotwords-score", &hotwords_score, |
| 46 | "The bonus score for each token in context word/phrase. " | 45 | "The bonus score for each token in context word/phrase. " |
| @@ -32,6 +32,19 @@ void OnlineModelConfig::Register(ParseOptions *po) { | @@ -32,6 +32,19 @@ void OnlineModelConfig::Register(ParseOptions *po) { | ||
| 32 | po->Register("provider", &provider, | 32 | po->Register("provider", &provider, |
| 33 | "Specify a provider to use: cpu, cuda, coreml"); | 33 | "Specify a provider to use: cpu, cuda, coreml"); |
| 34 | 34 | ||
| 35 | + po->Register("modeling-unit", &modeling_unit, | ||
| 36 | + "The modeling unit of the model, commonly used units are bpe, " | ||
| 37 | + "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when " | ||
| 38 | + "hotwords are provided, we need it to encode the hotwords into " | ||
| 39 | + "token sequence."); | ||
| 40 | + | ||
| 41 | + po->Register("bpe-vocab", &bpe_vocab, | ||
| 42 | + "The vocabulary generated by google's sentencepiece program. " | ||
| 43 | + "It is a file has two columns, one is the token, the other is " | ||
| 44 | + "the log probability, you can get it from the directory where " | ||
| 45 | + "your bpe model is generated. Only used when hotwords provided " | ||
| 46 | + "and the modeling unit is bpe or cjkchar+bpe"); | ||
| 47 | + | ||
| 35 | po->Register("model-type", &model_type, | 48 | po->Register("model-type", &model_type, |
| 36 | "Specify it to reduce model initialization time. " | 49 | "Specify it to reduce model initialization time. " |
| 37 | "Valid values are: conformer, lstm, zipformer, zipformer2, " | 50 | "Valid values are: conformer, lstm, zipformer, zipformer2, " |
| @@ -50,6 +63,14 @@ bool OnlineModelConfig::Validate() const { | @@ -50,6 +63,14 @@ bool OnlineModelConfig::Validate() const { | ||
| 50 | return false; | 63 | return false; |
| 51 | } | 64 | } |
| 52 | 65 | ||
| 66 | + if (!modeling_unit.empty() && | ||
| 67 | + (modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) { | ||
| 68 | + if (!FileExists(bpe_vocab)) { | ||
| 69 | + SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str()); | ||
| 70 | + return false; | ||
| 71 | + } | ||
| 72 | + } | ||
| 73 | + | ||
| 53 | if (!paraformer.encoder.empty()) { | 74 | if (!paraformer.encoder.empty()) { |
| 54 | return paraformer.Validate(); | 75 | return paraformer.Validate(); |
| 55 | } | 76 | } |
| @@ -83,7 +104,9 @@ std::string OnlineModelConfig::ToString() const { | @@ -83,7 +104,9 @@ std::string OnlineModelConfig::ToString() const { | ||
| 83 | os << "warm_up=" << warm_up << ", "; | 104 | os << "warm_up=" << warm_up << ", "; |
| 84 | os << "debug=" << (debug ? "True" : "False") << ", "; | 105 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| 85 | os << "provider=\"" << provider << "\", "; | 106 | os << "provider=\"" << provider << "\", "; |
| 86 | - os << "model_type=\"" << model_type << "\")"; | 107 | + os << "model_type=\"" << model_type << "\", "; |
| 108 | + os << "modeling_unit=\"" << modeling_unit << "\", "; | ||
| 109 | + os << "bpe_vocab=\"" << bpe_vocab << "\")"; | ||
| 87 | 110 | ||
| 88 | return os.str(); | 111 | return os.str(); |
| 89 | } | 112 | } |
| @@ -37,6 +37,13 @@ struct OnlineModelConfig { | @@ -37,6 +37,13 @@ struct OnlineModelConfig { | ||
| 37 | // All other values are invalid and lead to loading the model twice. | 37 | // All other values are invalid and lead to loading the model twice. |
| 38 | std::string model_type; | 38 | std::string model_type; |
| 39 | 39 | ||
| 40 | + // Valid values: | ||
| 41 | + // - cjkchar | ||
| 42 | + // - bpe | ||
| 43 | + // - cjkchar+bpe | ||
| 44 | + std::string modeling_unit = "cjkchar"; | ||
| 45 | + std::string bpe_vocab; | ||
| 46 | + | ||
| 40 | OnlineModelConfig() = default; | 47 | OnlineModelConfig() = default; |
| 41 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, | 48 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, |
| 42 | const OnlineParaformerModelConfig ¶former, | 49 | const OnlineParaformerModelConfig ¶former, |
| @@ -45,7 +52,9 @@ struct OnlineModelConfig { | @@ -45,7 +52,9 @@ struct OnlineModelConfig { | ||
| 45 | const OnlineNeMoCtcModelConfig &nemo_ctc, | 52 | const OnlineNeMoCtcModelConfig &nemo_ctc, |
| 46 | const std::string &tokens, int32_t num_threads, | 53 | const std::string &tokens, int32_t num_threads, |
| 47 | int32_t warm_up, bool debug, const std::string &provider, | 54 | int32_t warm_up, bool debug, const std::string &provider, |
| 48 | - const std::string &model_type) | 55 | + const std::string &model_type, |
| 56 | + const std::string &modeling_unit, | ||
| 57 | + const std::string &bpe_vocab) | ||
| 49 | : transducer(transducer), | 58 | : transducer(transducer), |
| 50 | paraformer(paraformer), | 59 | paraformer(paraformer), |
| 51 | wenet_ctc(wenet_ctc), | 60 | wenet_ctc(wenet_ctc), |
| @@ -56,7 +65,9 @@ struct OnlineModelConfig { | @@ -56,7 +65,9 @@ struct OnlineModelConfig { | ||
| 56 | warm_up(warm_up), | 65 | warm_up(warm_up), |
| 57 | debug(debug), | 66 | debug(debug), |
| 58 | provider(provider), | 67 | provider(provider), |
| 59 | - model_type(model_type) {} | 68 | + model_type(model_type), |
| 69 | + modeling_unit(modeling_unit), | ||
| 70 | + bpe_vocab(bpe_vocab) {} | ||
| 60 | 71 | ||
| 61 | void Register(ParseOptions *po); | 72 | void Register(ParseOptions *po); |
| 62 | bool Validate() const; | 73 | bool Validate() const; |
| @@ -15,8 +15,6 @@ | @@ -15,8 +15,6 @@ | ||
| 15 | #include <vector> | 15 | #include <vector> |
| 16 | 16 | ||
| 17 | #if __ANDROID_API__ >= 9 | 17 | #if __ANDROID_API__ >= 9 |
| 18 | -#include <strstream> | ||
| 19 | - | ||
| 20 | #include "android/asset_manager.h" | 18 | #include "android/asset_manager.h" |
| 21 | #include "android/asset_manager_jni.h" | 19 | #include "android/asset_manager_jni.h" |
| 22 | #endif | 20 | #endif |
| @@ -33,6 +31,7 @@ | @@ -33,6 +31,7 @@ | ||
| 33 | #include "sherpa-onnx/csrc/onnx-utils.h" | 31 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 34 | #include "sherpa-onnx/csrc/symbol-table.h" | 32 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 35 | #include "sherpa-onnx/csrc/utils.h" | 33 | #include "sherpa-onnx/csrc/utils.h" |
| 34 | +#include "ssentencepiece/csrc/ssentencepiece.h" | ||
| 36 | 35 | ||
| 37 | namespace sherpa_onnx { | 36 | namespace sherpa_onnx { |
| 38 | 37 | ||
| @@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 94 | model_->SetFeatureDim(config.feat_config.feature_dim); | 93 | model_->SetFeatureDim(config.feat_config.feature_dim); |
| 95 | 94 | ||
| 96 | if (config.decoding_method == "modified_beam_search") { | 95 | if (config.decoding_method == "modified_beam_search") { |
| 96 | + if (!config_.model_config.bpe_vocab.empty()) { | ||
| 97 | + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>( | ||
| 98 | + config_.model_config.bpe_vocab); | ||
| 99 | + } | ||
| 100 | + | ||
| 97 | if (!config_.hotwords_file.empty()) { | 101 | if (!config_.hotwords_file.empty()) { |
| 98 | InitHotwords(); | 102 | InitHotwords(); |
| 99 | } | 103 | } |
| @@ -140,6 +144,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -140,6 +144,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 140 | } | 144 | } |
| 141 | #endif | 145 | #endif |
| 142 | 146 | ||
| 147 | + if (!config_.model_config.bpe_vocab.empty()) { | ||
| 148 | + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab); | ||
| 149 | + std::istringstream iss(std::string(buf.begin(), buf.end())); | ||
| 150 | + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss); | ||
| 151 | + } | ||
| 152 | + | ||
| 143 | if (!config_.hotwords_file.empty()) { | 153 | if (!config_.hotwords_file.empty()) { |
| 144 | InitHotwords(mgr); | 154 | InitHotwords(mgr); |
| 145 | } | 155 | } |
| @@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 174 | auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); | 184 | auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); |
| 175 | std::istringstream is(hws); | 185 | std::istringstream is(hws); |
| 176 | std::vector<std::vector<int32_t>> current; | 186 | std::vector<std::vector<int32_t>> current; |
| 177 | - if (!EncodeHotwords(is, sym_, ¤t)) { | 187 | + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, |
| 188 | + bpe_encoder_.get(), ¤t)) { | ||
| 178 | SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", | 189 | SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", |
| 179 | hotwords.c_str()); | 190 | hotwords.c_str()); |
| 180 | } | 191 | } |
| @@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 363 | exit(-1); | 374 | exit(-1); |
| 364 | } | 375 | } |
| 365 | 376 | ||
| 366 | - if (!EncodeHotwords(is, sym_, &hotwords_)) { | ||
| 367 | - SHERPA_ONNX_LOGE("Encode hotwords failed."); | ||
| 368 | - exit(-1); | 377 | + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, |
| 378 | + bpe_encoder_.get(), &hotwords_)) { | ||
| 379 | + SHERPA_ONNX_LOGE( | ||
| 380 | + "Failed to encode some hotwords, skip them already, see logs above " | ||
| 381 | + "for details."); | ||
| 369 | } | 382 | } |
| 370 | hotwords_graph_ = | 383 | hotwords_graph_ = |
| 371 | std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | 384 | std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); |
| @@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 377 | 390 | ||
| 378 | auto buf = ReadFile(mgr, config_.hotwords_file); | 391 | auto buf = ReadFile(mgr, config_.hotwords_file); |
| 379 | 392 | ||
| 380 | - std::istrstream is(buf.data(), buf.size()); | 393 | + std::istringstream is(std::string(buf.begin(), buf.end())); |
| 381 | 394 | ||
| 382 | if (!is) { | 395 | if (!is) { |
| 383 | SHERPA_ONNX_LOGE("Open hotwords file failed: %s", | 396 | SHERPA_ONNX_LOGE("Open hotwords file failed: %s", |
| @@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 385 | exit(-1); | 398 | exit(-1); |
| 386 | } | 399 | } |
| 387 | 400 | ||
| 388 | - if (!EncodeHotwords(is, sym_, &hotwords_)) { | ||
| 389 | - SHERPA_ONNX_LOGE("Encode hotwords failed."); | ||
| 390 | - exit(-1); | 401 | + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, |
| 402 | + bpe_encoder_.get(), &hotwords_)) { | ||
| 403 | + SHERPA_ONNX_LOGE( | ||
| 404 | + "Failed to encode some hotwords, skip them already, see logs above " | ||
| 405 | + "for details."); | ||
| 391 | } | 406 | } |
| 392 | hotwords_graph_ = | 407 | hotwords_graph_ = |
| 393 | std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); | 408 | std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score); |
| @@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 413 | OnlineRecognizerConfig config_; | 428 | OnlineRecognizerConfig config_; |
| 414 | std::vector<std::vector<int32_t>> hotwords_; | 429 | std::vector<std::vector<int32_t>> hotwords_; |
| 415 | ContextGraphPtr hotwords_graph_; | 430 | ContextGraphPtr hotwords_graph_; |
| 431 | + std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_; | ||
| 416 | std::unique_ptr<OnlineTransducerModel> model_; | 432 | std::unique_ptr<OnlineTransducerModel> model_; |
| 417 | std::unique_ptr<OnlineLM> lm_; | 433 | std::unique_ptr<OnlineLM> lm_; |
| 418 | std::unique_ptr<OnlineTransducerDecoder> decoder_; | 434 | std::unique_ptr<OnlineTransducerDecoder> decoder_; |
| @@ -51,9 +51,7 @@ std::string VecToString<std::string>(const std::vector<std::string> &vec, | @@ -51,9 +51,7 @@ std::string VecToString<std::string>(const std::vector<std::string> &vec, | ||
| 51 | std::string OnlineRecognizerResult::AsJsonString() const { | 51 | std::string OnlineRecognizerResult::AsJsonString() const { |
| 52 | std::ostringstream os; | 52 | std::ostringstream os; |
| 53 | os << "{ "; | 53 | os << "{ "; |
| 54 | - os << "\"text\": " | ||
| 55 | - << "\"" << text << "\"" | ||
| 56 | - << ", "; | 54 | + os << "\"text\": " << "\"" << text << "\"" << ", "; |
| 57 | os << "\"tokens\": " << VecToString(tokens) << ", "; | 55 | os << "\"tokens\": " << VecToString(tokens) << ", "; |
| 58 | os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; | 56 | os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; |
| 59 | os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; | 57 | os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; |
| @@ -89,10 +87,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -89,10 +87,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 89 | "Used only when decoding_method is modified_beam_search"); | 87 | "Used only when decoding_method is modified_beam_search"); |
| 90 | po->Register( | 88 | po->Register( |
| 91 | "hotwords-file", &hotwords_file, | 89 | "hotwords-file", &hotwords_file, |
| 92 | - "The file containing hotwords, one words/phrases per line, and for each" | ||
| 93 | - "phrase the bpe/cjkchar are separated by a space. For example: " | ||
| 94 | - "▁HE LL O ▁WORLD" | ||
| 95 | - "你 好 世 界"); | 90 | + "The file containing hotwords, one words/phrases per line, For example: " |
| 91 | + "HELLO WORLD" | ||
| 92 | + "你好世界"); | ||
| 96 | po->Register("decoding-method", &decoding_method, | 93 | po->Register("decoding-method", &decoding_method, |
| 97 | "decoding method," | 94 | "decoding method," |
| 98 | "now support greedy_search and modified_beam_search."); | 95 | "now support greedy_search and modified_beam_search."); |
| @@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) { | @@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) { | ||
| 38 | std::string sym; | 38 | std::string sym; |
| 39 | int32_t id; | 39 | int32_t id; |
| 40 | while (is >> sym >> id) { | 40 | while (is >> sym >> id) { |
| 41 | - if (sym.size() >= 3) { | ||
| 42 | - // For BPE-based models, we replace ▁ with a space | ||
| 43 | - // Unicode 9601, hex 0x2581, utf8 0xe29681 | ||
| 44 | - const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str()); | ||
| 45 | - if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { | ||
| 46 | - sym = sym.replace(0, 3, " "); | ||
| 47 | - } | ||
| 48 | - } | ||
| 49 | - | ||
| 50 | - // for byte-level BPE | ||
| 51 | - // id 0 is blank, id 1 is sos/eos, id 2 is unk | ||
| 52 | - if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && | ||
| 53 | - sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { | ||
| 54 | - std::ostringstream os; | ||
| 55 | - os << std::hex << std::uppercase << (id - 3); | ||
| 56 | - | ||
| 57 | - if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) { | ||
| 58 | - uint8_t i = id - 3; | ||
| 59 | - sym = std::string(&i, &i + 1); | ||
| 60 | - } | ||
| 61 | - } | ||
| 62 | - | ||
| 63 | - assert(!sym.empty()); | ||
| 64 | - | ||
| 65 | - // for byte bpe, after replacing ▁ with a space, whose ascii is also 0x20, | ||
| 66 | - // there is a conflict between the real byte 0x20 and ▁, so we disable | ||
| 67 | - // the following check. | ||
| 68 | - // | ||
| 69 | - // Note: Only id2sym_ matters as we use it to convert ID to symbols. | ||
| 70 | #if 0 | 41 | #if 0 |
| 71 | // we disable the test here since for some multi-lingual BPE models | 42 | // we disable the test here since for some multi-lingual BPE models |
| 72 | // from NeMo, the same symbol can appear multiple times with different IDs. | 43 | // from NeMo, the same symbol can appear multiple times with different IDs. |
| @@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const { | @@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const { | ||
| 92 | return os.str(); | 63 | return os.str(); |
| 93 | } | 64 | } |
| 94 | 65 | ||
| 95 | -const std::string &SymbolTable::operator[](int32_t id) const { | ||
| 96 | - return id2sym_.at(id); | 66 | +const std::string SymbolTable::operator[](int32_t id) const { |
| 67 | + std::string sym = id2sym_.at(id); | ||
| 68 | + if (sym.size() >= 3) { | ||
| 69 | + // For BPE-based models, we replace ▁ with a space | ||
| 70 | + // Unicode 9601, hex 0x2581, utf8 0xe29681 | ||
| 71 | + const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str()); | ||
| 72 | + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { | ||
| 73 | + sym = sym.replace(0, 3, " "); | ||
| 74 | + } | ||
| 75 | + } | ||
| 76 | + | ||
| 77 | + // for byte-level BPE | ||
| 78 | + // id 0 is blank, id 1 is sos/eos, id 2 is unk | ||
| 79 | + if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && | ||
| 80 | + sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { | ||
| 81 | + std::ostringstream os; | ||
| 82 | + os << std::hex << std::uppercase << (id - 3); | ||
| 83 | + | ||
| 84 | + if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) { | ||
| 85 | + uint8_t i = id - 3; | ||
| 86 | + sym = std::string(&i, &i + 1); | ||
| 87 | + } | ||
| 88 | + } | ||
| 89 | + return sym; | ||
| 97 | } | 90 | } |
| 98 | 91 | ||
| 99 | int32_t SymbolTable::operator[](const std::string &sym) const { | 92 | int32_t SymbolTable::operator[](const std::string &sym) const { |
| @@ -35,7 +35,7 @@ class SymbolTable { | @@ -35,7 +35,7 @@ class SymbolTable { | ||
| 35 | std::string ToString() const; | 35 | std::string ToString() const; |
| 36 | 36 | ||
| 37 | /// Return the symbol corresponding to the given ID. | 37 | /// Return the symbol corresponding to the given ID. |
| 38 | - const std::string &operator[](int32_t id) const; | 38 | + const std::string operator[](int32_t id) const; |
| 39 | /// Return the ID corresponding to the given symbol. | 39 | /// Return the ID corresponding to the given symbol. |
| 40 | int32_t operator[](const std::string &sym) const; | 40 | int32_t operator[](const std::string &sym) const; |
| 41 | 41 |
sherpa-onnx/csrc/text2token-test.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/text2token-test.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include <fstream> | ||
| 6 | +#include <sstream> | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "gtest/gtest.h" | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | +#include "sherpa-onnx/csrc/utils.h" | ||
| 12 | +#include "ssentencepiece/csrc/ssentencepiece.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +// Please refer to | ||
| 17 | +// https://github.com/pkufool/sherpa-test-data | ||
| 18 | +// to download test data for testing | ||
| 19 | +static const char dir[] = "/tmp/sherpa-test-data"; | ||
| 20 | + | ||
| 21 | +TEST(TEXT2TOKEN, TEST_cjkchar) { | ||
| 22 | + std::ostringstream oss; | ||
| 23 | + oss << dir << "/text2token/tokens_cn.txt"; | ||
| 24 | + | ||
| 25 | + std::string tokens = oss.str(); | ||
| 26 | + | ||
| 27 | + if (!std::ifstream(tokens).good()) { | ||
| 28 | + SHERPA_ONNX_LOGE( | ||
| 29 | + "No test data found, skipping TEST_cjkchar()." | ||
| 30 | + "You can download the test data by: " | ||
| 31 | + "git clone https://github.com/pkufool/sherpa-test-data.git " | ||
| 32 | + "/tmp/sherpa-test-data"); | ||
| 33 | + return; | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + auto sym_table = SymbolTable(tokens); | ||
| 37 | + | ||
| 38 | + std::string text = "世界人民大团结\n中国 V S 美国"; | ||
| 39 | + | ||
| 40 | + std::istringstream iss(text); | ||
| 41 | + | ||
| 42 | + std::vector<std::vector<int32_t>> ids; | ||
| 43 | + | ||
| 44 | + auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids); | ||
| 45 | + | ||
| 46 | + std::vector<std::vector<int32_t>> expected_ids( | ||
| 47 | + {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}}); | ||
| 48 | + EXPECT_EQ(ids, expected_ids); | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +TEST(TEXT2TOKEN, TEST_bpe) { | ||
| 52 | + std::ostringstream oss; | ||
| 53 | + oss << dir << "/text2token/tokens_en.txt"; | ||
| 54 | + std::string tokens = oss.str(); | ||
| 55 | + oss.clear(); | ||
| 56 | + oss.str(""); | ||
| 57 | + oss << dir << "/text2token/bpe_en.vocab"; | ||
| 58 | + std::string bpe = oss.str(); | ||
| 59 | + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { | ||
| 60 | + SHERPA_ONNX_LOGE( | ||
| 61 | + "No test data found, skipping TEST_bpe()." | ||
| 62 | + "You can download the test data by: " | ||
| 63 | + "git clone https://github.com/pkufool/sherpa-test-data.git " | ||
| 64 | + "/tmp/sherpa-test-data"); | ||
| 65 | + return; | ||
| 66 | + } | ||
| 67 | + | ||
| 68 | + auto sym_table = SymbolTable(tokens); | ||
| 69 | + auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); | ||
| 70 | + | ||
| 71 | + std::string text = "HELLO WORLD\nI LOVE YOU"; | ||
| 72 | + | ||
| 73 | + std::istringstream iss(text); | ||
| 74 | + | ||
| 75 | + std::vector<std::vector<int32_t>> ids; | ||
| 76 | + | ||
| 77 | + auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); | ||
| 78 | + | ||
| 79 | + std::vector<std::vector<int32_t>> expected_ids( | ||
| 80 | + {{22, 58, 24, 425}, {19, 370, 47}}); | ||
| 81 | + EXPECT_EQ(ids, expected_ids); | ||
| 82 | +} | ||
| 83 | + | ||
| 84 | +TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { | ||
| 85 | + std::ostringstream oss; | ||
| 86 | + oss << dir << "/text2token/tokens_mix.txt"; | ||
| 87 | + std::string tokens = oss.str(); | ||
| 88 | + oss.clear(); | ||
| 89 | + oss.str(""); | ||
| 90 | + oss << dir << "/text2token/bpe_mix.vocab"; | ||
| 91 | + std::string bpe = oss.str(); | ||
| 92 | + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { | ||
| 93 | + SHERPA_ONNX_LOGE( | ||
| 94 | + "No test data found, skipping TEST_cjkchar_bpe()." | ||
| 95 | + "You can download the test data by: " | ||
| 96 | + "git clone https://github.com/pkufool/sherpa-test-data.git " | ||
| 97 | + "/tmp/sherpa-test-data"); | ||
| 98 | + return; | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + auto sym_table = SymbolTable(tokens); | ||
| 102 | + auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); | ||
| 103 | + | ||
| 104 | + std::string text = "世界人民 GOES TOGETHER\n中国 GOES WITH 美国"; | ||
| 105 | + | ||
| 106 | + std::istringstream iss(text); | ||
| 107 | + | ||
| 108 | + std::vector<std::vector<int32_t>> ids; | ||
| 109 | + | ||
| 110 | + auto r = | ||
| 111 | + EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), &ids); | ||
| 112 | + | ||
| 113 | + std::vector<std::vector<int32_t>> expected_ids( | ||
| 114 | + {{1368, 1392, 557, 680, 275, 178, 475}, | ||
| 115 | + {685, 736, 275, 178, 179, 921, 736}}); | ||
| 116 | + EXPECT_EQ(ids, expected_ids); | ||
| 117 | +} | ||
| 118 | + | ||
| 119 | +TEST(TEXT2TOKEN, TEST_bbpe) { | ||
| 120 | + std::ostringstream oss; | ||
| 121 | + oss << dir << "/text2token/tokens_bbpe.txt"; | ||
| 122 | + std::string tokens = oss.str(); | ||
| 123 | + oss.clear(); | ||
| 124 | + oss.str(""); | ||
| 125 | + oss << dir << "/text2token/bbpe.vocab"; | ||
| 126 | + std::string bpe = oss.str(); | ||
| 127 | + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { | ||
| 128 | + SHERPA_ONNX_LOGE( | ||
| 129 | + "No test data found, skipping TEST_bbpe()." | ||
| 130 | + "You can download the test data by: " | ||
| 131 | + "git clone https://github.com/pkufool/sherpa-test-data.git " | ||
| 132 | + "/tmp/sherpa-test-data"); | ||
| 133 | + return; | ||
| 134 | + } | ||
| 135 | + | ||
| 136 | + auto sym_table = SymbolTable(tokens); | ||
| 137 | + auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe); | ||
| 138 | + | ||
| 139 | + std::string text = "频繁\n李鞑靼"; | ||
| 140 | + | ||
| 141 | + std::istringstream iss(text); | ||
| 142 | + | ||
| 143 | + std::vector<std::vector<int32_t>> ids; | ||
| 144 | + | ||
| 145 | + auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); | ||
| 146 | + | ||
| 147 | + std::vector<std::vector<int32_t>> expected_ids( | ||
| 148 | + {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}}); | ||
| 149 | + EXPECT_EQ(ids, expected_ids); | ||
| 150 | +} | ||
| 151 | + | ||
| 152 | +} // namespace sherpa_onnx |
| @@ -4,6 +4,7 @@ | @@ -4,6 +4,7 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/utils.h" | 5 | #include "sherpa-onnx/csrc/utils.h" |
| 6 | 6 | ||
| 7 | +#include <cassert> | ||
| 7 | #include <iostream> | 8 | #include <iostream> |
| 8 | #include <sstream> | 9 | #include <sstream> |
| 9 | #include <string> | 10 | #include <string> |
| @@ -12,15 +13,16 @@ | @@ -12,15 +13,16 @@ | ||
| 12 | 13 | ||
| 13 | #include "sherpa-onnx/csrc/log.h" | 14 | #include "sherpa-onnx/csrc/log.h" |
| 14 | #include "sherpa-onnx/csrc/macros.h" | 15 | #include "sherpa-onnx/csrc/macros.h" |
| 16 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 15 | 17 | ||
| 16 | namespace sherpa_onnx { | 18 | namespace sherpa_onnx { |
| 17 | 19 | ||
| 18 | -static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, | 20 | +static bool EncodeBase(const std::vector<std::string> &lines, |
| 21 | + const SymbolTable &symbol_table, | ||
| 19 | std::vector<std::vector<int32_t>> *ids, | 22 | std::vector<std::vector<int32_t>> *ids, |
| 20 | std::vector<std::string> *phrases, | 23 | std::vector<std::string> *phrases, |
| 21 | std::vector<float> *scores, | 24 | std::vector<float> *scores, |
| 22 | std::vector<float> *thresholds) { | 25 | std::vector<float> *thresholds) { |
| 23 | - SHERPA_ONNX_CHECK(ids != nullptr); | ||
| 24 | ids->clear(); | 26 | ids->clear(); |
| 25 | 27 | ||
| 26 | std::vector<int32_t> tmp_ids; | 28 | std::vector<int32_t> tmp_ids; |
| @@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, | @@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, | ||
| 33 | bool has_scores = false; | 35 | bool has_scores = false; |
| 34 | bool has_thresholds = false; | 36 | bool has_thresholds = false; |
| 35 | bool has_phrases = false; | 37 | bool has_phrases = false; |
| 38 | + bool has_oov = false; | ||
| 36 | 39 | ||
| 37 | - while (std::getline(is, line)) { | 40 | + for (const auto &line : lines) { |
| 38 | float score = 0; | 41 | float score = 0; |
| 39 | float threshold = 0; | 42 | float threshold = 0; |
| 40 | std::string phrase = ""; | 43 | std::string phrase = ""; |
| 41 | 44 | ||
| 42 | std::istringstream iss(line); | 45 | std::istringstream iss(line); |
| 43 | while (iss >> word) { | 46 | while (iss >> word) { |
| 44 | - if (word.size() >= 3) { | ||
| 45 | - // For BPE-based models, we replace ▁ with a space | ||
| 46 | - // Unicode 9601, hex 0x2581, utf8 0xe29681 | ||
| 47 | - const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str()); | ||
| 48 | - if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { | ||
| 49 | - word = word.replace(0, 3, " "); | ||
| 50 | - } | ||
| 51 | - } | ||
| 52 | if (symbol_table.Contains(word)) { | 47 | if (symbol_table.Contains(word)) { |
| 53 | int32_t id = symbol_table[word]; | 48 | int32_t id = symbol_table[word]; |
| 54 | tmp_ids.push_back(id); | 49 | tmp_ids.push_back(id); |
| @@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, | @@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, | ||
| 71 | "Cannot find ID for token %s at line: %s. (Hint: words on " | 66 | "Cannot find ID for token %s at line: %s. (Hint: words on " |
| 72 | "the same line are separated by spaces)", | 67 | "the same line are separated by spaces)", |
| 73 | word.c_str(), line.c_str()); | 68 | word.c_str(), line.c_str()); |
| 74 | - return false; | 69 | + has_oov = true; |
| 70 | + break; | ||
| 75 | } | 71 | } |
| 76 | } | 72 | } |
| 77 | } | 73 | } |
| @@ -101,12 +97,87 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, | @@ -101,12 +97,87 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table, | ||
| 101 | thresholds->clear(); | 97 | thresholds->clear(); |
| 102 | } | 98 | } |
| 103 | } | 99 | } |
| 104 | - return true; | 100 | + return !has_oov; |
| 105 | } | 101 | } |
| 106 | 102 | ||
| 107 | -bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, | 103 | +bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, |
| 104 | + const SymbolTable &symbol_table, | ||
| 105 | + const ssentencepiece::Ssentencepiece *bpe_encoder, | ||
| 108 | std::vector<std::vector<int32_t>> *hotwords) { | 106 | std::vector<std::vector<int32_t>> *hotwords) { |
| 109 | - return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr); | 107 | + std::vector<std::string> lines; |
| 108 | + std::string line; | ||
| 109 | + std::string word; | ||
| 110 | + | ||
| 111 | + while (std::getline(is, line)) { | ||
| 112 | + std::string score; | ||
| 113 | + std::string phrase; | ||
| 114 | + | ||
| 115 | + std::ostringstream oss; | ||
| 116 | + std::istringstream iss(line); | ||
| 117 | + while (iss >> word) { | ||
| 118 | + switch (word[0]) { | ||
| 119 | + case ':': // boosting score for current keyword | ||
| 120 | + score = word; | ||
| 121 | + break; | ||
| 122 | + default: | ||
| 123 | + if (!score.empty()) { | ||
| 124 | + SHERPA_ONNX_LOGE( | ||
| 125 | + "Boosting score should be put after the words/phrase, given " | ||
| 126 | + "%s.", | ||
| 127 | + line.c_str()); | ||
| 128 | + return false; | ||
| 129 | + } | ||
| 130 | + oss << " " << word; | ||
| 131 | + break; | ||
| 132 | + } | ||
| 133 | + } | ||
| 134 | + phrase = oss.str().substr(1); | ||
| 135 | + std::istringstream piss(phrase); | ||
| 136 | + oss.clear(); | ||
| 137 | + oss.str(""); | ||
| 138 | + while (piss >> word) { | ||
| 139 | + if (modeling_unit == "cjkchar") { | ||
| 140 | + for (const auto &w : SplitUtf8(word)) { | ||
| 141 | + oss << " " << w; | ||
| 142 | + } | ||
| 143 | + } else if (modeling_unit == "bpe") { | ||
| 144 | + std::vector<std::string> bpes; | ||
| 145 | + bpe_encoder->Encode(word, &bpes); | ||
| 146 | + for (const auto &bpe : bpes) { | ||
| 147 | + oss << " " << bpe; | ||
| 148 | + } | ||
| 149 | + } else { | ||
| 150 | + if (modeling_unit != "cjkchar+bpe") { | ||
| 151 | + SHERPA_ONNX_LOGE( | ||
| 152 | + "modeling_unit should be one of bpe, cjkchar or cjkchar+bpe, " | ||
| 153 | + "given " | ||
| 154 | + "%s", | ||
| 155 | + modeling_unit.c_str()); | ||
| 156 | + exit(-1); | ||
| 157 | + } | ||
| 158 | + for (const auto &w : SplitUtf8(word)) { | ||
| 159 | + if (isalpha(w[0])) { | ||
| 160 | + std::vector<std::string> bpes; | ||
| 161 | + bpe_encoder->Encode(w, &bpes); | ||
| 162 | + for (const auto &bpe : bpes) { | ||
| 163 | + oss << " " << bpe; | ||
| 164 | + } | ||
| 165 | + } else { | ||
| 166 | + oss << " " << w; | ||
| 167 | + } | ||
| 168 | + } | ||
| 169 | + } | ||
| 170 | + } | ||
| 171 | + std::string encoded_phrase = oss.str().substr(1); | ||
| 172 | + oss.clear(); | ||
| 173 | + oss.str(""); | ||
| 174 | + oss << encoded_phrase; | ||
| 175 | + if (!score.empty()) { | ||
| 176 | + oss << " " << score; | ||
| 177 | + } | ||
| 178 | + lines.push_back(oss.str()); | ||
| 179 | + } | ||
| 180 | + return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr); | ||
| 110 | } | 181 | } |
| 111 | 182 | ||
| 112 | bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, | 183 | bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, |
| @@ -114,7 +185,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, | @@ -114,7 +185,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, | ||
| 114 | std::vector<std::string> *keywords, | 185 | std::vector<std::string> *keywords, |
| 115 | std::vector<float> *boost_scores, | 186 | std::vector<float> *boost_scores, |
| 116 | std::vector<float> *threshold) { | 187 | std::vector<float> *threshold) { |
| 117 | - return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores, | 188 | + std::vector<std::string> lines; |
| 189 | + std::string line; | ||
| 190 | + while (std::getline(is, line)) { | ||
| 191 | + lines.push_back(line); | ||
| 192 | + } | ||
| 193 | + return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores, | ||
| 118 | threshold); | 194 | threshold); |
| 119 | } | 195 | } |
| 120 | 196 |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include <vector> | 8 | #include <vector> |
| 9 | 9 | ||
| 10 | #include "sherpa-onnx/csrc/symbol-table.h" | 10 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 11 | +#include "ssentencepiece/csrc/ssentencepiece.h" | ||
| 11 | 12 | ||
| 12 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 13 | 14 | ||
| @@ -25,7 +26,9 @@ namespace sherpa_onnx { | @@ -25,7 +26,9 @@ namespace sherpa_onnx { | ||
| 25 | * @return If all the symbols from ``is`` are in the symbol_table, returns true | 26 | * @return If all the symbols from ``is`` are in the symbol_table, returns true |
| 26 | * otherwise returns false. | 27 | * otherwise returns false. |
| 27 | */ | 28 | */ |
| 28 | -bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table, | 29 | +bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, |
| 30 | + const SymbolTable &symbol_table, | ||
| 31 | + const ssentencepiece::Ssentencepiece *bpe_encoder_, | ||
| 29 | std::vector<std::vector<int32_t>> *hotwords_id); | 32 | std::vector<std::vector<int32_t>> *hotwords_id); |
| 30 | 33 | ||
| 31 | /* Encode the keywords in an input stream to be tokens ids. | 34 | /* Encode the keywords in an input stream to be tokens ids. |
| @@ -76,6 +76,18 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { | @@ -76,6 +76,18 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { | ||
| 76 | ans.model_config.model_type = p; | 76 | ans.model_config.model_type = p; |
| 77 | env->ReleaseStringUTFChars(s, p); | 77 | env->ReleaseStringUTFChars(s, p); |
| 78 | 78 | ||
| 79 | + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;"); | ||
| 80 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 81 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 82 | + ans.model_config.modeling_unit = p; | ||
| 83 | + env->ReleaseStringUTFChars(s, p); | ||
| 84 | + | ||
| 85 | + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;"); | ||
| 86 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 87 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 88 | + ans.model_config.bpe_vocab = p; | ||
| 89 | + env->ReleaseStringUTFChars(s, p); | ||
| 90 | + | ||
| 79 | // transducer | 91 | // transducer |
| 80 | fid = env->GetFieldID(model_config_cls, "transducer", | 92 | fid = env->GetFieldID(model_config_cls, "transducer", |
| 81 | "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;"); | 93 | "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;"); |
| @@ -195,6 +195,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | @@ -195,6 +195,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | ||
| 195 | ans.model_config.model_type = p; | 195 | ans.model_config.model_type = p; |
| 196 | env->ReleaseStringUTFChars(s, p); | 196 | env->ReleaseStringUTFChars(s, p); |
| 197 | 197 | ||
| 198 | + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;"); | ||
| 199 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 200 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 201 | + ans.model_config.modeling_unit = p; | ||
| 202 | + env->ReleaseStringUTFChars(s, p); | ||
| 203 | + | ||
| 204 | + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;"); | ||
| 205 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 206 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 207 | + ans.model_config.bpe_vocab = p; | ||
| 208 | + env->ReleaseStringUTFChars(s, p); | ||
| 209 | + | ||
| 198 | //---------- rnn lm model config ---------- | 210 | //---------- rnn lm model config ---------- |
| 199 | fid = env->GetFieldID(cls, "lmConfig", | 211 | fid = env->GetFieldID(cls, "lmConfig", |
| 200 | "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); | 212 | "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); |
| @@ -40,6 +40,8 @@ data class OfflineModelConfig( | @@ -40,6 +40,8 @@ data class OfflineModelConfig( | ||
| 40 | var provider: String = "cpu", | 40 | var provider: String = "cpu", |
| 41 | var modelType: String = "", | 41 | var modelType: String = "", |
| 42 | var tokens: String, | 42 | var tokens: String, |
| 43 | + var modelingUnit: String = "", | ||
| 44 | + var bpeVocab: String = "", | ||
| 43 | ) | 45 | ) |
| 44 | 46 | ||
| 45 | data class OfflineRecognizerConfig( | 47 | data class OfflineRecognizerConfig( |
| @@ -43,6 +43,8 @@ data class OnlineModelConfig( | @@ -43,6 +43,8 @@ data class OnlineModelConfig( | ||
| 43 | var debug: Boolean = false, | 43 | var debug: Boolean = false, |
| 44 | var provider: String = "cpu", | 44 | var provider: String = "cpu", |
| 45 | var modelType: String = "", | 45 | var modelType: String = "", |
| 46 | + var modelingUnit: String = "", | ||
| 47 | + var bpeVocab: String = "", | ||
| 46 | ) | 48 | ) |
| 47 | 49 | ||
| 48 | data class OnlineLMConfig( | 50 | data class OnlineLMConfig( |
| @@ -36,7 +36,8 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -36,7 +36,8 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 36 | const OfflineTdnnModelConfig &, | 36 | const OfflineTdnnModelConfig &, |
| 37 | const OfflineZipformerCtcModelConfig &, | 37 | const OfflineZipformerCtcModelConfig &, |
| 38 | const OfflineWenetCtcModelConfig &, const std::string &, | 38 | const OfflineWenetCtcModelConfig &, const std::string &, |
| 39 | - int32_t, bool, const std::string &, const std::string &>(), | 39 | + int32_t, bool, const std::string &, const std::string &, |
| 40 | + const std::string &, const std::string &>(), | ||
| 40 | py::arg("transducer") = OfflineTransducerModelConfig(), | 41 | py::arg("transducer") = OfflineTransducerModelConfig(), |
| 41 | py::arg("paraformer") = OfflineParaformerModelConfig(), | 42 | py::arg("paraformer") = OfflineParaformerModelConfig(), |
| 42 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), | 43 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), |
| @@ -45,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -45,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 45 | py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), | 46 | py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), |
| 46 | py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), | 47 | py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), |
| 47 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, | 48 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, |
| 48 | - py::arg("provider") = "cpu", py::arg("model_type") = "") | 49 | + py::arg("provider") = "cpu", py::arg("model_type") = "", |
| 50 | + py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "") | ||
| 49 | .def_readwrite("transducer", &PyClass::transducer) | 51 | .def_readwrite("transducer", &PyClass::transducer) |
| 50 | .def_readwrite("paraformer", &PyClass::paraformer) | 52 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 51 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) | 53 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) |
| @@ -58,6 +60,8 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -58,6 +60,8 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 58 | .def_readwrite("debug", &PyClass::debug) | 60 | .def_readwrite("debug", &PyClass::debug) |
| 59 | .def_readwrite("provider", &PyClass::provider) | 61 | .def_readwrite("provider", &PyClass::provider) |
| 60 | .def_readwrite("model_type", &PyClass::model_type) | 62 | .def_readwrite("model_type", &PyClass::model_type) |
| 63 | + .def_readwrite("modeling_unit", &PyClass::modeling_unit) | ||
| 64 | + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) | ||
| 61 | .def("validate", &PyClass::Validate) | 65 | .def("validate", &PyClass::Validate) |
| 62 | .def("__str__", &PyClass::ToString); | 66 | .def("__str__", &PyClass::ToString); |
| 63 | } | 67 | } |
| @@ -32,6 +32,7 @@ void PybindOnlineModelConfig(py::module *m) { | @@ -32,6 +32,7 @@ void PybindOnlineModelConfig(py::module *m) { | ||
| 32 | const OnlineZipformer2CtcModelConfig &, | 32 | const OnlineZipformer2CtcModelConfig &, |
| 33 | const OnlineNeMoCtcModelConfig &, const std::string &, | 33 | const OnlineNeMoCtcModelConfig &, const std::string &, |
| 34 | int32_t, int32_t, bool, const std::string &, | 34 | int32_t, int32_t, bool, const std::string &, |
| 35 | + const std::string &, const std::string &, | ||
| 35 | const std::string &>(), | 36 | const std::string &>(), |
| 36 | py::arg("transducer") = OnlineTransducerModelConfig(), | 37 | py::arg("transducer") = OnlineTransducerModelConfig(), |
| 37 | py::arg("paraformer") = OnlineParaformerModelConfig(), | 38 | py::arg("paraformer") = OnlineParaformerModelConfig(), |
| @@ -40,7 +41,8 @@ void PybindOnlineModelConfig(py::module *m) { | @@ -40,7 +41,8 @@ void PybindOnlineModelConfig(py::module *m) { | ||
| 40 | py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"), | 41 | py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"), |
| 41 | py::arg("num_threads"), py::arg("warm_up") = 0, | 42 | py::arg("num_threads"), py::arg("warm_up") = 0, |
| 42 | py::arg("debug") = false, py::arg("provider") = "cpu", | 43 | py::arg("debug") = false, py::arg("provider") = "cpu", |
| 43 | - py::arg("model_type") = "") | 44 | + py::arg("model_type") = "", py::arg("modeling_unit") = "", |
| 45 | + py::arg("bpe_vocab") = "") | ||
| 44 | .def_readwrite("transducer", &PyClass::transducer) | 46 | .def_readwrite("transducer", &PyClass::transducer) |
| 45 | .def_readwrite("paraformer", &PyClass::paraformer) | 47 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 46 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) | 48 | .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) |
| @@ -51,6 +53,8 @@ void PybindOnlineModelConfig(py::module *m) { | @@ -51,6 +53,8 @@ void PybindOnlineModelConfig(py::module *m) { | ||
| 51 | .def_readwrite("debug", &PyClass::debug) | 53 | .def_readwrite("debug", &PyClass::debug) |
| 52 | .def_readwrite("provider", &PyClass::provider) | 54 | .def_readwrite("provider", &PyClass::provider) |
| 53 | .def_readwrite("model_type", &PyClass::model_type) | 55 | .def_readwrite("model_type", &PyClass::model_type) |
| 56 | + .def_readwrite("modeling_unit", &PyClass::modeling_unit) | ||
| 57 | + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) | ||
| 54 | .def("validate", &PyClass::Validate) | 58 | .def("validate", &PyClass::Validate) |
| 55 | .def("__str__", &PyClass::ToString); | 59 | .def("__str__", &PyClass::ToString); |
| 56 | } | 60 | } |
| @@ -49,6 +49,8 @@ class OfflineRecognizer(object): | @@ -49,6 +49,8 @@ class OfflineRecognizer(object): | ||
| 49 | hotwords_file: str = "", | 49 | hotwords_file: str = "", |
| 50 | hotwords_score: float = 1.5, | 50 | hotwords_score: float = 1.5, |
| 51 | blank_penalty: float = 0.0, | 51 | blank_penalty: float = 0.0, |
| 52 | + modeling_unit: str = "cjkchar", | ||
| 53 | + bpe_vocab: str = "", | ||
| 52 | debug: bool = False, | 54 | debug: bool = False, |
| 53 | provider: str = "cpu", | 55 | provider: str = "cpu", |
| 54 | model_type: str = "transducer", | 56 | model_type: str = "transducer", |
| @@ -91,6 +93,16 @@ class OfflineRecognizer(object): | @@ -91,6 +93,16 @@ class OfflineRecognizer(object): | ||
| 91 | hotwords_file is given with modified_beam_search as decoding method. | 93 | hotwords_file is given with modified_beam_search as decoding method. |
| 92 | blank_penalty: | 94 | blank_penalty: |
| 93 | The penalty applied on blank symbol during decoding. | 95 | The penalty applied on blank symbol during decoding. |
| 96 | + modeling_unit: | ||
| 97 | + The modeling unit of the model, commonly used units are bpe, cjkchar, | ||
| 98 | + cjkchar+bpe, etc. Currently, it is needed only when hotwords are | ||
| 99 | + provided, we need it to encode the hotwords into token sequence. | ||
| 100 | + and the modeling unit is bpe or cjkchar+bpe. | ||
| 101 | + bpe_vocab: | ||
| 102 | + The vocabulary generated by google's sentencepiece program. | ||
| 103 | + It is a file has two columns, one is the token, the other is | ||
| 104 | + the log probability, you can get it from the directory where | ||
| 105 | + your bpe model is generated. Only used when hotwords provided | ||
| 94 | debug: | 106 | debug: |
| 95 | True to show debug messages. | 107 | True to show debug messages. |
| 96 | provider: | 108 | provider: |
| @@ -107,6 +119,8 @@ class OfflineRecognizer(object): | @@ -107,6 +119,8 @@ class OfflineRecognizer(object): | ||
| 107 | num_threads=num_threads, | 119 | num_threads=num_threads, |
| 108 | debug=debug, | 120 | debug=debug, |
| 109 | provider=provider, | 121 | provider=provider, |
| 122 | + modeling_unit=modeling_unit, | ||
| 123 | + bpe_vocab=bpe_vocab, | ||
| 110 | model_type=model_type, | 124 | model_type=model_type, |
| 111 | ) | 125 | ) |
| 112 | 126 |
| @@ -58,6 +58,8 @@ class OnlineRecognizer(object): | @@ -58,6 +58,8 @@ class OnlineRecognizer(object): | ||
| 58 | hotwords_file: str = "", | 58 | hotwords_file: str = "", |
| 59 | provider: str = "cpu", | 59 | provider: str = "cpu", |
| 60 | model_type: str = "", | 60 | model_type: str = "", |
| 61 | + modeling_unit: str = "cjkchar", | ||
| 62 | + bpe_vocab: str = "", | ||
| 61 | lm: str = "", | 63 | lm: str = "", |
| 62 | lm_scale: float = 0.1, | 64 | lm_scale: float = 0.1, |
| 63 | temperature_scale: float = 2.0, | 65 | temperature_scale: float = 2.0, |
| @@ -136,6 +138,16 @@ class OnlineRecognizer(object): | @@ -136,6 +138,16 @@ class OnlineRecognizer(object): | ||
| 136 | model_type: | 138 | model_type: |
| 137 | Online transducer model type. Valid values are: conformer, lstm, | 139 | Online transducer model type. Valid values are: conformer, lstm, |
| 138 | zipformer, zipformer2. All other values lead to loading the model twice. | 140 | zipformer, zipformer2. All other values lead to loading the model twice. |
| 141 | + modeling_unit: | ||
| 142 | + The modeling unit of the model, commonly used units are bpe, cjkchar, | ||
| 143 | + cjkchar+bpe, etc. Currently, it is needed only when hotwords are | ||
| 144 | + provided, we need it to encode the hotwords into token sequence. | ||
| 145 | + bpe_vocab: | ||
| 146 | + The vocabulary generated by google's sentencepiece program. | ||
| 147 | + It is a file has two columns, one is the token, the other is | ||
| 148 | + the log probability, you can get it from the directory where | ||
| 149 | + your bpe model is generated. Only used when hotwords provided | ||
| 150 | + and the modeling unit is bpe or cjkchar+bpe. | ||
| 139 | """ | 151 | """ |
| 140 | self = cls.__new__(cls) | 152 | self = cls.__new__(cls) |
| 141 | _assert_file_exists(tokens) | 153 | _assert_file_exists(tokens) |
| @@ -157,6 +169,8 @@ class OnlineRecognizer(object): | @@ -157,6 +169,8 @@ class OnlineRecognizer(object): | ||
| 157 | num_threads=num_threads, | 169 | num_threads=num_threads, |
| 158 | provider=provider, | 170 | provider=provider, |
| 159 | model_type=model_type, | 171 | model_type=model_type, |
| 172 | + modeling_unit=modeling_unit, | ||
| 173 | + bpe_vocab=bpe_vocab, | ||
| 160 | debug=debug, | 174 | debug=debug, |
| 161 | ) | 175 | ) |
| 162 | 176 |
-
请 注册 或 登录 后发表评论