Wei Kang
Committed by GitHub

Encode hotwords in C++ side (#828)

* Encode hotwords in C++ side
Showing 43 changed files with 713 additions and 101 deletions
... ... @@ -8,6 +8,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
echo "EXE is $EXE"
echo "PATH: $PATH"
... ...
... ... @@ -8,6 +8,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
echo "EXE is $EXE"
echo "PATH: $PATH"
... ...
... ... @@ -8,6 +8,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
echo "EXE is $EXE"
echo "PATH: $PATH"
... ...
... ... @@ -8,6 +8,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
echo "EXE is $EXE"
echo "PATH: $PATH"
... ...
... ... @@ -8,6 +8,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
echo "EXE is $EXE"
echo "PATH: $PATH"
... ...
... ... @@ -8,6 +8,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
echo "EXE is $EXE"
echo "PATH: $PATH"
... ...
... ... @@ -8,6 +8,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
echo "EXE is $EXE"
echo "PATH: $PATH"
... ...
... ... @@ -8,6 +8,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
log "test online NeMo CTC"
url=https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2
... ...
... ... @@ -8,6 +8,8 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export GIT_CLONE_PROTECTION_ACTIVE=false
echo "EXE is $EXE"
echo "PATH: $PATH"
... ...
... ... @@ -234,6 +234,7 @@ endif()
include(kaldi-native-fbank)
include(kaldi-decoder)
include(onnxruntime)
include(simple-sentencepiece)
set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR})
message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}")
... ...
... ... @@ -126,7 +126,7 @@ echo "Generate xcframework"
mkdir -p "build/simulator/lib"
for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \
libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a; do
libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a libssentencepiece_core.a; do
lipo -create build/simulator_arm64/lib/${f} \
build/simulator_x86_64/lib/${f} \
-output build/simulator/lib/${f}
... ... @@ -140,7 +140,8 @@ libtool -static -o build/simulator/sherpa-onnx.a \
build/simulator/lib/libsherpa-onnx-core.a \
build/simulator/lib/libsherpa-onnx-fst.a \
build/simulator/lib/libsherpa-onnx-kaldifst-core.a \
build/simulator/lib/libkaldi-decoder-core.a
build/simulator/lib/libkaldi-decoder-core.a \
build/simulator/lib/libssentencepiece_core.a
libtool -static -o build/os64/sherpa-onnx.a \
build/os64/lib/libkaldi-native-fbank-core.a \
... ... @@ -148,7 +149,8 @@ libtool -static -o build/os64/sherpa-onnx.a \
build/os64/lib/libsherpa-onnx-core.a \
build/os64/lib/libsherpa-onnx-fst.a \
build/os64/lib/libsherpa-onnx-kaldifst-core.a \
build/os64/lib/libkaldi-decoder-core.a
build/os64/lib/libkaldi-decoder-core.a \
build/os64/lib/libssentencepiece_core.a
rm -rf sherpa-onnx.xcframework
... ...
... ... @@ -129,7 +129,7 @@ echo "Generate xcframework"
mkdir -p "build/simulator/lib"
for f in libkaldi-native-fbank-core.a libsherpa-onnx-c-api.a libsherpa-onnx-core.a \
libsherpa-onnx-fstfar.a \
libsherpa-onnx-fstfar.a libssentencepiece_core.a \
libsherpa-onnx-fst.a libsherpa-onnx-kaldifst-core.a libkaldi-decoder-core.a \
libucd.a libpiper_phonemize.a libespeak-ng.a; do
lipo -create build/simulator_arm64/lib/${f} \
... ... @@ -150,6 +150,7 @@ libtool -static -o build/simulator/sherpa-onnx.a \
build/simulator/lib/libucd.a \
build/simulator/lib/libpiper_phonemize.a \
build/simulator/lib/libespeak-ng.a \
build/simulator/lib/libssentencepiece_core.a
libtool -static -o build/os64/sherpa-onnx.a \
build/os64/lib/libkaldi-native-fbank-core.a \
... ... @@ -162,6 +163,7 @@ libtool -static -o build/os64/sherpa-onnx.a \
build/os64/lib/libucd.a \
build/os64/lib/libpiper_phonemize.a \
build/os64/lib/libespeak-ng.a \
build/os64/lib/libssentencepiece_core.a
rm -rf sherpa-onnx.xcframework
... ...
... ... @@ -33,4 +33,5 @@ libtool -static -o ./install/lib/libsherpa-onnx.a \
./install/lib/libkaldi-decoder-core.a \
./install/lib/libucd.a \
./install/lib/libpiper_phonemize.a \
./install/lib/libespeak-ng.a
./install/lib/libespeak-ng.a \
./install/lib/libssentencepiece_core.a
... ...
# Fetches simple-sentencepiece (used to encode hotwords into BPE token
# sequences on the C++ side) and adds it to the build, exporting the
# ssentencepiece_core target. A locally pre-downloaded tarball is used
# when one is found, so the build also works without Internet access.
function(download_simple_sentencepiece)
  include(FetchContent)

  set(simple-sentencepiece_URL  "https://github.com/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz")
  # Mirror URL, tried when the primary one is unreachable.
  set(simple-sentencepiece_URL2 "https://hub.nauu.cf/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz")
  set(simple-sentencepiece_HASH "SHA256=1748a822060a35baa9f6609f84efc8eb54dc0e74b9ece3d82367b7119fdc75af")

  # If you don't have access to the Internet,
  # please pre-download simple-sentencepiece to one of these locations.
  set(possible_file_locations
    $ENV{HOME}/Downloads/simple-sentencepiece-0.7.tar.gz
    ${CMAKE_SOURCE_DIR}/simple-sentencepiece-0.7.tar.gz
    ${CMAKE_BINARY_DIR}/simple-sentencepiece-0.7.tar.gz
    /tmp/simple-sentencepiece-0.7.tar.gz
    /star-fj/fangjun/download/github/simple-sentencepiece-0.7.tar.gz
  )

  foreach(f IN LISTS possible_file_locations)
    # Quote the expansion: paths (e.g. under $ENV{HOME}) may contain spaces.
    if(EXISTS "${f}")
      set(simple-sentencepiece_URL "${f}")
      file(TO_CMAKE_PATH "${simple-sentencepiece_URL}" simple-sentencepiece_URL)
      message(STATUS "Found local downloaded simple-sentencepiece: ${simple-sentencepiece_URL}")
      # Drop the mirror URL so FetchContent uses only the local file.
      set(simple-sentencepiece_URL2)
      break()
    endif()
  endforeach()

  # Configure the sub-project before it is added: we only need the core
  # library, not its tests or Python bindings. FORCE is required here to
  # override any stale values cached by a previous configure run.
  set(SBPE_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
  set(SBPE_BUILD_PYTHON OFF CACHE BOOL "" FORCE)

  FetchContent_Declare(simple-sentencepiece
    URL
      ${simple-sentencepiece_URL}
      ${simple-sentencepiece_URL2}
    URL_HASH
      ${simple-sentencepiece_HASH}
  )

  FetchContent_GetProperties(simple-sentencepiece)
  if(NOT simple-sentencepiece_POPULATED)
    message(STATUS "Downloading simple-sentencepiece ${simple-sentencepiece_URL}")
    FetchContent_Populate(simple-sentencepiece)
  endif()
  message(STATUS "simple-sentencepiece is downloaded to ${simple-sentencepiece_SOURCE_DIR}")

  # EXCLUDE_FROM_ALL: only targets actually linked against are built.
  add_subdirectory(${simple-sentencepiece_SOURCE_DIR} ${simple-sentencepiece_BINARY_DIR} EXCLUDE_FROM_ALL)

  # Let consumers #include "ssentencepiece/csrc/..." headers.
  target_include_directories(ssentencepiece_core
    PUBLIC
      ${simple-sentencepiece_SOURCE_DIR}/
  )

  if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
    install(TARGETS ssentencepiece_core DESTINATION ..)
  else()
    install(TARGETS ssentencepiece_core DESTINATION lib)
  endif()

  if(WIN32 AND BUILD_SHARED_LIBS)
    install(TARGETS ssentencepiece_core DESTINATION bin)
  endif()
endfunction()

download_simple_sentencepiece()
... ...
... ... @@ -60,7 +60,7 @@ function testSpeakerEmbeddingExtractor() {
function testOnlineAsr() {
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then
git lfs install
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
GIT_CLONE_PROTECTION_ACTIVE=false git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
fi
if [ ! -f ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt ]; then
... ...
... ... @@ -18,6 +18,7 @@
piper_phonemize.lib;
espeak-ng.lib;
ucd.lib;
ssentencepiece_core.lib;
</SherpaOnnxLibraries>
</PropertyGroup>
<ItemDefinitionGroup>
... ...
... ... @@ -18,6 +18,7 @@
piper_phonemize.lib;
espeak-ng.lib;
ucd.lib;
ssentencepiece_core.lib;
</SherpaOnnxLibraries>
</PropertyGroup>
<ItemDefinitionGroup>
... ...
... ... @@ -18,6 +18,7 @@
piper_phonemize.lib;
espeak-ng.lib;
ucd.lib;
ssentencepiece_core.lib;
</SherpaOnnxLibraries>
</PropertyGroup>
<ItemDefinitionGroup>
... ...
... ... @@ -110,11 +110,9 @@ def get_args():
type=str,
default="",
help="""
The file containing hotwords, one words/phrases per line, and for each
phrase the bpe/cjkchar are separated by a space. For example:
▁HE LL O ▁WORLD
你 好 世 界
The file containing hotwords, one words/phrases per line, like
HELLO WORLD
你好世界
""",
)
... ... @@ -129,6 +127,28 @@ def get_args():
)
parser.add_argument(
"--modeling-unit",
type=str,
default="",
help="""
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
Used only when hotwords-file is given.
""",
)
parser.add_argument(
"--bpe-vocab",
type=str,
default="",
help="""
The path to the bpe vocabulary, the bpe vocabulary is generated by
sentencepiece, you can also export the bpe vocabulary through a bpe model
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
and modeling-unit is bpe or cjkchar+bpe.
""",
)
parser.add_argument(
"--encoder",
default="",
type=str,
... ... @@ -347,6 +367,8 @@ def main():
decoding_method=args.decoding_method,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
modeling_unit=args.modeling_unit,
bpe_vocab=args.bpe_vocab,
blank_penalty=args.blank_penalty,
debug=args.debug,
)
... ...
... ... @@ -198,11 +198,9 @@ def get_args():
type=str,
default="",
help="""
The file containing hotwords, one words/phrases per line, and for each
phrase the bpe/cjkchar are separated by a space. For example:
▁HE LL O ▁WORLD
你 好 世 界
The file containing hotwords, one words/phrases per line, like
HELLO WORLD
你好世界
""",
)
... ... @@ -217,6 +215,28 @@ def get_args():
)
parser.add_argument(
"--modeling-unit",
type=str,
default="",
help="""
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
Used only when hotwords-file is given.
""",
)
parser.add_argument(
"--bpe-vocab",
type=str,
default="",
help="""
The path to the bpe vocabulary, the bpe vocabulary is generated by
sentencepiece, you can also export the bpe vocabulary through a bpe model
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
and modeling-unit is bpe or cjkchar+bpe.
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
... ... @@ -302,6 +322,8 @@ def main():
lm_scale=args.lm_scale,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
modeling_unit=args.modeling_unit,
bpe_vocab=args.bpe_vocab,
blank_penalty=args.blank_penalty,
)
elif args.zipformer2_ctc:
... ...
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
import argparse
from typing import Dict
try:
import sentencepiece as spm
except ImportError:
print('Please run')
print(' pip install sentencepiece')
print('before you continue')
raise
def get_args():
    """Parse command-line arguments.

    Returns:
      argparse.Namespace with a single attribute:
        bpe_model: Path to the sentencepiece bpe model file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bpe-model",
        type=str,
        # Without this, a missing flag yields None and main() crashes with
        # AttributeError on None.replace(); fail early with a clear message.
        required=True,
        help="The path to the bpe model.",
    )
    return parser.parse_args()
def main():
    """Export the vocabulary of a sentencepiece bpe model to a .vocab file.

    The output file is written next to the input model (".model" replaced by
    ".vocab") and contains two tab-separated columns per line: the token and
    its log-probability score. This is the format consumed by the
    --bpe-vocab option of sherpa-onnx.
    """
    args = get_args()
    model_path = args.bpe_model
    # NOTE(review): str.replace substitutes every occurrence of ".model",
    # not just the extension; model file names are assumed to contain it once.
    vocab_path = model_path.replace(".model", ".vocab")

    processor = spm.SentencePieceProcessor()
    processor.Load(model_path)

    pieces = [processor.IdToPiece(i) for i in range(processor.GetPieceSize())]
    with open(vocab_path, "w") as out:
        for piece in pieces:
            piece_id = processor.PieceToId(piece)
            out.write(f"{piece}\t{processor.GetScore(piece_id)}\n")

    print(f"Vocabulary file is written to {vocab_path}")


if __name__ == "__main__":
    main()
... ...
... ... @@ -165,6 +165,7 @@ endif()
target_link_libraries(sherpa-onnx-core
kaldi-native-fbank-core
kaldi-decoder-core
ssentencepiece_core
)
if(SHERPA_ONNX_ENABLE_GPU)
... ... @@ -491,6 +492,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
pad-sequence-test.cc
slice-test.cc
stack-test.cc
text2token-test.cc
transpose-test.cc
unbind-test.cc
utfcpp-test.cc
... ...
... ... @@ -35,6 +35,17 @@ void OfflineModelConfig::Register(ParseOptions *po) {
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
"tdnn, zipformer2_ctc"
"All other values lead to loading the model twice.");
po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
"hotwords are provided, we need it to encode the hotwords into "
"token sequence.");
po->Register("bpe-vocab", &bpe_vocab,
"The vocabulary generated by google's sentencepiece program. "
"It is a file has two columns, one is the token, the other is "
"the log probability, you can get it from the directory where "
"your bpe model is generated. Only used when hotwords provided "
"and the modeling unit is bpe or cjkchar+bpe");
}
bool OfflineModelConfig::Validate() const {
... ... @@ -48,6 +59,14 @@ bool OfflineModelConfig::Validate() const {
return false;
}
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
return false;
}
}
if (!paraformer.model.empty()) {
return paraformer.Validate();
}
... ... @@ -90,7 +109,9 @@ std::string OfflineModelConfig::ToString() const {
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\", ";
os << "model_type=\"" << model_type << "\")";
os << "model_type=\"" << model_type << "\", ";
os << "modeling_unit=\"" << modeling_unit << "\", ";
os << "bpe_vocab=\"" << bpe_vocab << "\")";
return os.str();
}
... ...
... ... @@ -41,6 +41,9 @@ struct OfflineModelConfig {
// All other values are invalid and lead to loading the model twice.
std::string model_type;
std::string modeling_unit = "cjkchar";
std::string bpe_vocab;
OfflineModelConfig() = default;
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
const OfflineParaformerModelConfig &paraformer,
... ... @@ -50,7 +53,9 @@ struct OfflineModelConfig {
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
const std::string &provider, const std::string &model_type,
const std::string &modeling_unit,
const std::string &bpe_vocab)
: transducer(transducer),
paraformer(paraformer),
nemo_ctc(nemo_ctc),
... ... @@ -62,7 +67,9 @@ struct OfflineModelConfig {
num_threads(num_threads),
debug(debug),
provider(provider),
model_type(model_type) {}
model_type(model_type),
modeling_unit(modeling_unit),
bpe_vocab(bpe_vocab) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -31,6 +31,7 @@
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
namespace sherpa_onnx {
... ... @@ -76,9 +77,6 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
model_.get(), config_.blank_penalty);
... ... @@ -87,6 +85,15 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
lm_ = OfflineLM::Create(config.lm_config);
}
if (!config_.model_config.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model_config.bpe_vocab);
}
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.blank_penalty);
... ... @@ -112,6 +119,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
lm_ = OfflineLM::Create(mgr, config.lm_config);
}
if (!config_.model_config.bpe_vocab.empty()) {
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
std::istringstream iss(std::string(buf.begin(), buf.end()));
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
}
if (!config_.hotwords_file.empty()) {
InitHotwords(mgr);
}
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.blank_penalty);
... ... @@ -128,7 +145,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, symbol_table_, &current)) {
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
... ... @@ -207,19 +225,47 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
exit(-1);
}
if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &hotwords_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
#if __ANDROID_API__ >= 9
void InitHotwords(AAssetManager *mgr) {
// each line in hotwords_file contains space-separated words
auto buf = ReadFile(mgr, config_.hotwords_file);
std::istringstream is(std::string(buf.begin(), buf.end()));
if (!is) {
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
config_.hotwords_file.c_str());
exit(-1);
}
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &hotwords_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
#endif
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
std::unique_ptr<OfflineTransducerModel> model_;
std::unique_ptr<OfflineTransducerDecoder> decoder_;
std::unique_ptr<OfflineLM> lm_;
... ...
... ... @@ -37,10 +37,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register(
"hotwords-file", &hotwords_file,
"The file containing hotwords, one words/phrases per line, and for each"
"phrase the bpe/cjkchar are separated by a space. For example: "
"▁HE LL O ▁WORLD"
"你 好 世 界");
"The file containing hotwords, one words/phrases per line, For example: "
"HELLO WORLD"
"你好世界");
po->Register("hotwords-score", &hotwords_score,
"The bonus score for each token in context word/phrase. "
... ...
... ... @@ -32,6 +32,19 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
"hotwords are provided, we need it to encode the hotwords into "
"token sequence.");
po->Register("bpe-vocab", &bpe_vocab,
"The vocabulary generated by google's sentencepiece program. "
"It is a file has two columns, one is the token, the other is "
"the log probability, you can get it from the directory where "
"your bpe model is generated. Only used when hotwords provided "
"and the modeling unit is bpe or cjkchar+bpe");
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2, "
... ... @@ -50,6 +63,14 @@ bool OnlineModelConfig::Validate() const {
return false;
}
if (!modeling_unit.empty() &&
(modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) {
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("bpe_vocab: %s does not exist", bpe_vocab.c_str());
return false;
}
}
if (!paraformer.encoder.empty()) {
return paraformer.Validate();
}
... ... @@ -83,7 +104,9 @@ std::string OnlineModelConfig::ToString() const {
os << "warm_up=" << warm_up << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\", ";
os << "model_type=\"" << model_type << "\")";
os << "model_type=\"" << model_type << "\", ";
os << "modeling_unit=\"" << modeling_unit << "\", ";
os << "bpe_vocab=\"" << bpe_vocab << "\")";
return os.str();
}
... ...
... ... @@ -37,6 +37,13 @@ struct OnlineModelConfig {
// All other values are invalid and lead to loading the model twice.
std::string model_type;
// Valid values:
// - cjkchar
// - bpe
// - cjkchar+bpe
std::string modeling_unit = "cjkchar";
std::string bpe_vocab;
OnlineModelConfig() = default;
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
const OnlineParaformerModelConfig &paraformer,
... ... @@ -45,7 +52,9 @@ struct OnlineModelConfig {
const OnlineNeMoCtcModelConfig &nemo_ctc,
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug, const std::string &provider,
const std::string &model_type)
const std::string &model_type,
const std::string &modeling_unit,
const std::string &bpe_vocab)
: transducer(transducer),
paraformer(paraformer),
wenet_ctc(wenet_ctc),
... ... @@ -56,7 +65,9 @@ struct OnlineModelConfig {
warm_up(warm_up),
debug(debug),
provider(provider),
model_type(model_type) {}
model_type(model_type),
modeling_unit(modeling_unit),
bpe_vocab(bpe_vocab) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -15,8 +15,6 @@
#include <vector>
#if __ANDROID_API__ >= 9
#include <strstream>
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
... ... @@ -33,6 +31,7 @@
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
namespace sherpa_onnx {
... ... @@ -94,6 +93,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_->SetFeatureDim(config.feat_config.feature_dim);
if (config.decoding_method == "modified_beam_search") {
if (!config_.model_config.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model_config.bpe_vocab);
}
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
... ... @@ -140,6 +144,12 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
#endif
if (!config_.model_config.bpe_vocab.empty()) {
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
std::istringstream iss(std::string(buf.begin(), buf.end()));
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
}
if (!config_.hotwords_file.empty()) {
InitHotwords(mgr);
}
... ... @@ -174,7 +184,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, sym_, &current)) {
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
... ... @@ -363,9 +374,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
exit(-1);
}
if (!EncodeHotwords(is, sym_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
... ... @@ -377,7 +390,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
auto buf = ReadFile(mgr, config_.hotwords_file);
std::istrstream is(buf.data(), buf.size());
std::istringstream is(std::string(buf.begin(), buf.end()));
if (!is) {
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
... ... @@ -385,9 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
exit(-1);
}
if (!EncodeHotwords(is, sym_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
... ... @@ -413,6 +428,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
OnlineRecognizerConfig config_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineLM> lm_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;
... ...
... ... @@ -51,9 +51,7 @@ std::string VecToString<std::string>(const std::vector<std::string> &vec,
std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{ ";
os << "\"text\": "
<< "\"" << text << "\""
<< ", ";
os << "\"text\": " << "\"" << text << "\"" << ", ";
os << "\"tokens\": " << VecToString(tokens) << ", ";
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
... ... @@ -89,10 +87,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"Used only when decoding_method is modified_beam_search");
po->Register(
"hotwords-file", &hotwords_file,
"The file containing hotwords, one words/phrases per line, and for each"
"phrase the bpe/cjkchar are separated by a space. For example: "
"▁HE LL O ▁WORLD"
"你 好 世 界");
"The file containing hotwords, one words/phrases per line, For example: "
"HELLO WORLD"
"你好世界");
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
... ...
... ... @@ -38,35 +38,6 @@ void SymbolTable::Init(std::istream &is) {
std::string sym;
int32_t id;
while (is >> sym >> id) {
if (sym.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
sym = sym.replace(0, 3, " ");
}
}
// for byte-level BPE
// id 0 is blank, id 1 is sos/eos, id 2 is unk
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
std::ostringstream os;
os << std::hex << std::uppercase << (id - 3);
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
uint8_t i = id - 3;
sym = std::string(&i, &i + 1);
}
}
assert(!sym.empty());
// for byte bpe, after replacing ▁ with a space, whose ascii is also 0x20,
// there is a conflict between the real byte 0x20 and ▁, so we disable
// the following check.
//
// Note: Only id2sym_ matters as we use it to convert ID to symbols.
#if 0
// we disable the test here since for some multi-lingual BPE models
// from NeMo, the same symbol can appear multiple times with different IDs.
... ... @@ -92,8 +63,30 @@ std::string SymbolTable::ToString() const {
return os.str();
}
const std::string &SymbolTable::operator[](int32_t id) const {
return id2sym_.at(id);
const std::string SymbolTable::operator[](int32_t id) const {
std::string sym = id2sym_.at(id);
if (sym.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
sym = sym.replace(0, 3, " ");
}
}
// for byte-level BPE
// id 0 is blank, id 1 is sos/eos, id 2 is unk
if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' &&
sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') {
std::ostringstream os;
os << std::hex << std::uppercase << (id - 3);
if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) {
uint8_t i = id - 3;
sym = std::string(&i, &i + 1);
}
}
return sym;
}
int32_t SymbolTable::operator[](const std::string &sym) const {
... ...
... ... @@ -35,7 +35,7 @@ class SymbolTable {
std::string ToString() const;
/// Return the symbol corresponding to the given ID.
const std::string &operator[](int32_t id) const;
const std::string operator[](int32_t id) const;
/// Return the ID corresponding to the given symbol.
int32_t operator[](const std::string &sym) const;
... ...
// sherpa-onnx/csrc/text2token-test.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include <fstream>
#include <sstream>
#include <string>
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
namespace sherpa_onnx {
// Please refer to
// https://github.com/pkufool/sherpa-test-data
// to download test data for testing
// Root directory where the test-data repository is expected to be cloned;
// each TEST below skips itself when the required files are absent.
static const char dir[] = "/tmp/sherpa-test-data";
// Verifies that pure-CJK hotwords are encoded character-by-character into
// token IDs; no bpe processor is needed for the "cjkchar" modeling unit.
TEST(TEXT2TOKEN, TEST_cjkchar) {
  std::string tokens = std::string(dir) + "/text2token/tokens_cn.txt";
  if (!std::ifstream(tokens).good()) {
    SHERPA_ONNX_LOGE(
        "No test data found, skipping TEST_cjkchar()."
        "You can download the test data by: "
        "git clone https://github.com/pkufool/sherpa-test-data.git "
        "/tmp/sherpa-test-data");
    return;
  }
  SymbolTable symbol_table(tokens);
  // Two hotword phrases, one per line.
  std::istringstream input("世界人民大团结\n中国 V S 美国");
  std::vector<std::vector<int32_t>> encoded;
  auto ok = EncodeHotwords(input, "cjkchar", symbol_table, nullptr, &encoded);
  std::vector<std::vector<int32_t>> expected(
      {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
  EXPECT_EQ(encoded, expected);
}
// Verifies that English hotwords are encoded with a bpe vocabulary.
TEST(TEXT2TOKEN, TEST_bpe) {
  std::string tokens = std::string(dir) + "/text2token/tokens_en.txt";
  std::string bpe = std::string(dir) + "/text2token/bpe_en.vocab";
  if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
    SHERPA_ONNX_LOGE(
        "No test data found, skipping TEST_bpe()."
        "You can download the test data by: "
        "git clone https://github.com/pkufool/sherpa-test-data.git "
        "/tmp/sherpa-test-data");
    return;
  }
  SymbolTable symbol_table(tokens);
  auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
  // Two hotword phrases, one per line.
  std::istringstream input("HELLO WORLD\nI LOVE YOU");
  std::vector<std::vector<int32_t>> encoded;
  auto ok =
      EncodeHotwords(input, "bpe", symbol_table, bpe_processor.get(), &encoded);
  std::vector<std::vector<int32_t>> expected(
      {{22, 58, 24, 425}, {19, 370, 47}});
  EXPECT_EQ(encoded, expected);
}
// Verifies mixed Chinese/English hotwords with the "cjkchar+bpe" modeling
// unit: CJK characters map to per-character IDs, Latin words go through bpe.
TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
  std::string tokens = std::string(dir) + "/text2token/tokens_mix.txt";
  std::string bpe = std::string(dir) + "/text2token/bpe_mix.vocab";
  if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
    SHERPA_ONNX_LOGE(
        "No test data found, skipping TEST_cjkchar_bpe()."
        "You can download the test data by: "
        "git clone https://github.com/pkufool/sherpa-test-data.git "
        "/tmp/sherpa-test-data");
    return;
  }
  SymbolTable symbol_table(tokens);
  auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);
  // Two hotword phrases, one per line.
  std::istringstream input("世界人民 GOES TOGETHER\n中国 GOES WITH 美国");
  std::vector<std::vector<int32_t>> encoded;
  auto ok = EncodeHotwords(input, "cjkchar+bpe", symbol_table,
                           bpe_processor.get(), &encoded);
  std::vector<std::vector<int32_t>> expected(
      {{1368, 1392, 557, 680, 275, 178, 475},
       {685, 736, 275, 178, 179, 921, 736}});
  EXPECT_EQ(encoded, expected);
}
// Verifies that EncodeHotwords() works with a byte-level BPE (bbpe) model:
// Chinese text is encoded through the "bpe" modeling unit into byte-piece
// token IDs. Skips when test data is missing.
TEST(TEXT2TOKEN, TEST_bbpe) {
  std::ostringstream oss;
  oss << dir << "/text2token/tokens_bbpe.txt";
  std::string tokens = oss.str();
  oss.clear();
  oss.str("");
  oss << dir << "/text2token/bbpe.vocab";
  std::string bpe = oss.str();

  if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) {
    SHERPA_ONNX_LOGE(
        "No test data found, skipping TEST_bbpe()."
        "You can download the test data by: "
        "git clone https://github.com/pkufool/sherpa-test-data.git "
        "/tmp/sherpa-test-data");
    return;
  }

  auto sym_table = SymbolTable(tokens);
  auto bpe_processor = std::make_unique<ssentencepiece::Ssentencepiece>(bpe);

  std::string text = "频繁\n李鞑靼";
  std::istringstream iss(text);
  std::vector<std::vector<int32_t>> ids;
  auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids);
  // Fix: the return value was previously ignored; encoding is expected to
  // succeed for this input.
  EXPECT_TRUE(r);

  std::vector<std::vector<int32_t>> expected_ids(
      {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
  EXPECT_EQ(ids, expected_ids);
}
} // namespace sherpa_onnx
... ...
... ... @@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/utils.h"
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
... ... @@ -12,15 +13,16 @@
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
static bool EncodeBase(const std::vector<std::string> &lines,
const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *ids,
std::vector<std::string> *phrases,
std::vector<float> *scores,
std::vector<float> *thresholds) {
SHERPA_ONNX_CHECK(ids != nullptr);
ids->clear();
std::vector<int32_t> tmp_ids;
... ... @@ -33,22 +35,15 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
bool has_scores = false;
bool has_thresholds = false;
bool has_phrases = false;
bool has_oov = false;
while (std::getline(is, line)) {
for (const auto &line : lines) {
float score = 0;
float threshold = 0;
std::string phrase = "";
std::istringstream iss(line);
while (iss >> word) {
if (word.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
word = word.replace(0, 3, " ");
}
}
if (symbol_table.Contains(word)) {
int32_t id = symbol_table[word];
tmp_ids.push_back(id);
... ... @@ -71,7 +66,8 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
"Cannot find ID for token %s at line: %s. (Hint: words on "
"the same line are separated by spaces)",
word.c_str(), line.c_str());
return false;
has_oov = true;
break;
}
}
}
... ... @@ -101,12 +97,87 @@ static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
thresholds->clear();
}
}
return true;
return !has_oov;
}
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
const SymbolTable &symbol_table,
const ssentencepiece::Ssentencepiece *bpe_encoder,
std::vector<std::vector<int32_t>> *hotwords) {
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
std::vector<std::string> lines;
std::string line;
std::string word;
while (std::getline(is, line)) {
std::string score;
std::string phrase;
std::ostringstream oss;
std::istringstream iss(line);
while (iss >> word) {
switch (word[0]) {
case ':': // boosting score for current keyword
score = word;
break;
default:
if (!score.empty()) {
SHERPA_ONNX_LOGE(
"Boosting score should be put after the words/phrase, given "
"%s.",
line.c_str());
return false;
}
oss << " " << word;
break;
}
}
phrase = oss.str().substr(1);
std::istringstream piss(phrase);
oss.clear();
oss.str("");
while (piss >> word) {
if (modeling_unit == "cjkchar") {
for (const auto &w : SplitUtf8(word)) {
oss << " " << w;
}
} else if (modeling_unit == "bpe") {
std::vector<std::string> bpes;
bpe_encoder->Encode(word, &bpes);
for (const auto &bpe : bpes) {
oss << " " << bpe;
}
} else {
if (modeling_unit != "cjkchar+bpe") {
SHERPA_ONNX_LOGE(
"modeling_unit should be one of bpe, cjkchar or cjkchar+bpe, "
"given "
"%s",
modeling_unit.c_str());
exit(-1);
}
for (const auto &w : SplitUtf8(word)) {
if (isalpha(w[0])) {
std::vector<std::string> bpes;
bpe_encoder->Encode(w, &bpes);
for (const auto &bpe : bpes) {
oss << " " << bpe;
}
} else {
oss << " " << w;
}
}
}
}
std::string encoded_phrase = oss.str().substr(1);
oss.clear();
oss.str("");
oss << encoded_phrase;
if (!score.empty()) {
oss << " " << score;
}
lines.push_back(oss.str());
}
return EncodeBase(lines, symbol_table, hotwords, nullptr, nullptr, nullptr);
}
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
... ... @@ -114,7 +185,12 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::string> *keywords,
std::vector<float> *boost_scores,
std::vector<float> *threshold) {
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
std::vector<std::string> lines;
std::string line;
while (std::getline(is, line)) {
lines.push_back(line);
}
return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores,
threshold);
}
... ...
... ... @@ -8,6 +8,7 @@
#include <vector>
#include "sherpa-onnx/csrc/symbol-table.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
namespace sherpa_onnx {
... ... @@ -25,7 +26,9 @@ namespace sherpa_onnx {
* @return If all the symbols from ``is`` are in the symbol_table, returns true
* otherwise returns false.
*/
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
const SymbolTable &symbol_table,
const ssentencepiece::Ssentencepiece *bpe_encoder_,
std::vector<std::vector<int32_t>> *hotwords_id);
/* Encode the keywords in an input stream to be tokens ids.
... ...
... ... @@ -76,6 +76,18 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.modeling_unit = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.bpe_vocab = p;
env->ReleaseStringUTFChars(s, p);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;");
... ...
... ... @@ -195,6 +195,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.modeling_unit = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.bpe_vocab = p;
env->ReleaseStringUTFChars(s, p);
//---------- rnn lm model config ----------
fid = env->GetFieldID(cls, "lmConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
... ...
... ... @@ -40,6 +40,8 @@ data class OfflineModelConfig(
var provider: String = "cpu",
var modelType: String = "",
var tokens: String,
var modelingUnit: String = "",
var bpeVocab: String = "",
)
data class OfflineRecognizerConfig(
... ...
... ... @@ -43,6 +43,8 @@ data class OnlineModelConfig(
var debug: Boolean = false,
var provider: String = "cpu",
var modelType: String = "",
var modelingUnit: String = "",
var bpeVocab: String = "",
)
data class OnlineLMConfig(
... ...
... ... @@ -36,7 +36,8 @@ void PybindOfflineModelConfig(py::module *m) {
const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
int32_t, bool, const std::string &, const std::string &,
const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
... ... @@ -45,7 +46,8 @@ void PybindOfflineModelConfig(py::module *m) {
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
py::arg("provider") = "cpu", py::arg("model_type") = "",
py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
... ... @@ -58,6 +60,8 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def_readwrite("modeling_unit", &PyClass::modeling_unit)
.def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -32,6 +32,7 @@ void PybindOnlineModelConfig(py::module *m) {
const OnlineZipformer2CtcModelConfig &,
const OnlineNeMoCtcModelConfig &, const std::string &,
int32_t, int32_t, bool, const std::string &,
const std::string &, const std::string &,
const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
... ... @@ -40,7 +41,8 @@ void PybindOnlineModelConfig(py::module *m) {
py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
py::arg("num_threads"), py::arg("warm_up") = 0,
py::arg("debug") = false, py::arg("provider") = "cpu",
py::arg("model_type") = "")
py::arg("model_type") = "", py::arg("modeling_unit") = "",
py::arg("bpe_vocab") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
... ... @@ -51,6 +53,8 @@ void PybindOnlineModelConfig(py::module *m) {
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def_readwrite("modeling_unit", &PyClass::modeling_unit)
.def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -49,6 +49,8 @@ class OfflineRecognizer(object):
hotwords_file: str = "",
hotwords_score: float = 1.5,
blank_penalty: float = 0.0,
modeling_unit: str = "cjkchar",
bpe_vocab: str = "",
debug: bool = False,
provider: str = "cpu",
model_type: str = "transducer",
... ... @@ -91,6 +93,16 @@ class OfflineRecognizer(object):
hotwords_file is given with modified_beam_search as decoding method.
blank_penalty:
The penalty applied on blank symbol during decoding.
modeling_unit:
The modeling unit of the model, commonly used units are bpe, cjkchar,
cjkchar+bpe, etc. Currently, it is needed only when hotwords are
provided, we need it to encode the hotwords into token sequence.
and the modeling unit is bpe or cjkchar+bpe.
bpe_vocab:
The vocabulary generated by google's sentencepiece program.
It is a file has two columns, one is the token, the other is
the log probability, you can get it from the directory where
your bpe model is generated. Only used when hotwords provided
debug:
True to show debug messages.
provider:
... ... @@ -107,6 +119,8 @@ class OfflineRecognizer(object):
num_threads=num_threads,
debug=debug,
provider=provider,
modeling_unit=modeling_unit,
bpe_vocab=bpe_vocab,
model_type=model_type,
)
... ...
... ... @@ -58,6 +58,8 @@ class OnlineRecognizer(object):
hotwords_file: str = "",
provider: str = "cpu",
model_type: str = "",
modeling_unit: str = "cjkchar",
bpe_vocab: str = "",
lm: str = "",
lm_scale: float = 0.1,
temperature_scale: float = 2.0,
... ... @@ -136,6 +138,16 @@ class OnlineRecognizer(object):
model_type:
Online transducer model type. Valid values are: conformer, lstm,
zipformer, zipformer2. All other values lead to loading the model twice.
modeling_unit:
The modeling unit of the model, commonly used units are bpe, cjkchar,
cjkchar+bpe, etc. Currently, it is needed only when hotwords are
provided, we need it to encode the hotwords into token sequence.
bpe_vocab:
The vocabulary generated by google's sentencepiece program.
It is a file has two columns, one is the token, the other is
the log probability, you can get it from the directory where
your bpe model is generated. Only used when hotwords provided
and the modeling unit is bpe or cjkchar+bpe.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
... ... @@ -157,6 +169,8 @@ class OnlineRecognizer(object):
num_threads=num_threads,
provider=provider,
model_type=model_type,
modeling_unit=modeling_unit,
bpe_vocab=bpe_vocab,
debug=debug,
)
... ...