Wei Kang
Committed by GitHub

Refactor hotwords, support loading hotwords from file (#296)

Showing 34 changed files with 800 additions and 297 deletions
... ... @@ -166,3 +166,8 @@ python3 ./python-api-examples/offline-decode-files.py \
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
rm -rf $repo
# test text2token
git clone https://github.com/pkufool/sherpa-test-data /tmp/sherpa-test-data
python3 sherpa-onnx/python/tests/test_text2token.py --verbose
... ...
... ... @@ -39,7 +39,7 @@ jobs:
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip numpy
python3 -m pip install --upgrade pip numpy sentencepiece
- name: Install sherpa-onnx
shell: bash
... ...
... ... @@ -39,7 +39,7 @@ jobs:
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip numpy
python3 -m pip install --upgrade pip numpy sentencepiece
- name: Install sherpa-onnx
shell: bash
... ...
... ... @@ -326,6 +326,31 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
)
def add_hotwords_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--hotwords-file",
type=str,
default="",
help="""
The file containing hotwords, one word/phrase per line. Within each
phrase, the bpe/cjkchar tokens are separated by spaces. For example:
▁HE LL O ▁WORLD
你 好 世 界
""",
)
parser.add_argument(
"--hotwords-score",
type=float,
default=1.5,
help="""
The bonus score for each token of a hotword, used to bias the
word/phrase. Used only if --hotwords-file is given.
""",
)
def check_args(args):
if not Path(args.tokens).is_file():
raise ValueError(f"{args.tokens} does not exist")
... ... @@ -342,6 +367,10 @@ def check_args(args):
assert Path(args.decoder).is_file(), args.decoder
assert Path(args.joiner).is_file(), args.joiner
if args.hotwords_file != "":
assert args.decoding_method == "modified_beam_search", args.decoding_method
assert Path(args.hotwords_file).is_file(), args.hotwords_file
def get_args():
parser = argparse.ArgumentParser(
... ... @@ -351,6 +380,7 @@ def get_args():
add_model_args(parser)
add_feature_config_args(parser)
add_decoding_args(parser)
add_hotwords_args(parser)
parser.add_argument(
"--port",
... ... @@ -792,6 +822,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
)
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
... ...
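For reference, the hotwords file expected by --hotwords-file holds one phrase per line, with tokens separated by spaces. A minimal sketch for producing one by hand (the file name is a placeholder; the token sequences must match the model's tokens.txt):

hotwords = [
    "▁HE LL O ▁WORLD",  # BPE pieces for "HELLO WORLD"
    "你 好 世 界",  # CJK characters, one token each
]
with open("hotwords.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(hotwords) + "\n")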
... ... @@ -82,7 +82,6 @@ from pathlib import Path
from typing import List, Tuple
import numpy as np
import sentencepiece as spm
import sherpa_onnx
... ... @@ -98,43 +97,25 @@ def get_args():
)
parser.add_argument(
"--bpe-model",
"--hotwords-file",
type=str,
default="",
help="""
Path to bpe.model,
Used only when --decoding-method=modified_beam_search
""",
)
The file containing hotwords, one word/phrase per line. Within each
phrase, the bpe/cjkchar tokens are separated by spaces. For example:
parser.add_argument(
"--modeling-unit",
type=str,
default="char",
help="""
The type of modeling unit.
Valid values are bpe, bpe+char, char.
Note: the char here means characters in CJK languages.
▁HE LL O ▁WORLD
你 好 世 界
""",
)
parser.add_argument(
"--contexts",
type=str,
default="",
help="""
The context list, it is a string containing some words/phrases separated
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
""",
)
parser.add_argument(
"--context-score",
"--hotwords-score",
type=float,
default=1.5,
help="""
The context score of each token for biasing word/phrase. Used only if
--contexts is given.
The bonus score for each token of a hotword, used to bias the
word/phrase. Used only if --hotwords-file is given.
""",
)
... ... @@ -273,25 +254,6 @@ def assert_file_exists(filename: str):
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
sp = None
if "bpe" in args.modeling_unit:
assert_file_exists(args.bpe_model)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
tokens = {}
with open(args.tokens, "r", encoding="utf-8") as f:
for line in f:
toks = line.strip().split()
assert len(toks) == 2, len(toks)
assert toks[0] not in tokens, f"Duplicate token: {toks} "
tokens[toks[0]] = int(toks[1])
return sherpa_onnx.encode_contexts(
modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
... ... @@ -322,7 +284,6 @@ def main():
assert_file_exists(args.tokens)
assert args.num_threads > 0, args.num_threads
contexts_list = []
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
... ... @@ -330,11 +291,6 @@ def main():
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
print(f"Contexts list: {contexts}")
contexts_list = encode_contexts(args, contexts)
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
... ... @@ -348,7 +304,8 @@ def main():
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
context_score=args.context_score,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
debug=args.debug,
)
elif args.paraformer:
... ... @@ -425,12 +382,7 @@ def main():
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
total_duration += duration
if contexts_list:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
s = recognizer.create_stream(contexts_list=contexts_list)
else:
s = recognizer.create_stream()
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
streams.append(s)
... ...
... ... @@ -48,7 +48,6 @@ from pathlib import Path
from typing import List, Tuple
import numpy as np
import sentencepiece as spm
import sherpa_onnx
... ... @@ -124,46 +123,25 @@ def get_args():
)
parser.add_argument(
"--bpe-model",
"--hotwords-file",
type=str,
default="",
help="""
Path to bpe.model, it will be used to tokenize contexts biasing phrases.
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--modeling-unit",
type=str,
default="char",
help="""
The type of modeling unit, it will be used to tokenize contexts biasing phrases.
Valid values are bpe, bpe+char, char.
Note: the char here means characters in CJK languages.
Used only when --decoding-method=modified_beam_search
""",
)
The file containing hotwords, one word/phrase per line. Within each
phrase, the bpe/cjkchar tokens are separated by spaces. For example:
parser.add_argument(
"--contexts",
type=str,
default="",
help="""
The context list, it is a string containing some words/phrases separated
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
Used only when --decoding-method=modified_beam_search
▁HE LL O ▁WORLD
你 好 世 界
""",
)
parser.add_argument(
"--context-score",
"--hotwords-score",
type=float,
default=1.5,
help="""
The context score of each token for biasing word/phrase. Used only if
--contexts is given.
Used only when --decoding-method=modified_beam_search
The bonus score for each token of a hotword, used to bias the
word/phrase. Used only if --hotwords-file is given.
""",
)
... ... @@ -214,27 +192,6 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
return samples_float32, f.getframerate()
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
sp = None
if "bpe" in args.modeling_unit:
assert_file_exists(args.bpe_model)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
tokens = {}
with open(args.tokens, "r", encoding="utf-8") as f:
for line in f:
toks = line.strip().split()
assert len(toks) == 2, len(toks)
assert toks[0] not in tokens, f"Duplicate token: {toks} "
tokens[toks[0]] = int(toks[1])
return sherpa_onnx.encode_contexts(
modeling_unit=args.modeling_unit,
contexts=contexts,
sp=sp,
tokens_table=tokens,
)
def main():
args = get_args()
assert_file_exists(args.tokens)
... ... @@ -258,7 +215,8 @@ def main():
feature_dim=80,
decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths,
context_score=args.context_score,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
)
elif args.paraformer_encoder:
recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(
... ... @@ -277,12 +235,6 @@ def main():
print("Started!")
start_time = time.time()
contexts_list = []
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
print(f"Contexts list: {contexts}")
contexts_list = encode_contexts(args, contexts)
streams = []
total_duration = 0
for wave_filename in args.sound_files:
... ... @@ -291,10 +243,7 @@ def main():
duration = len(samples) / sample_rate
total_duration += duration
if contexts_list:
s = recognizer.create_stream(contexts_list=contexts_list)
else:
s = recognizer.create_stream()
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
... ...
... ... @@ -79,6 +79,30 @@ def get_args():
help="Valid values: cpu, cuda, coreml",
)
parser.add_argument(
"--hotwords-file",
type=str,
default="",
help="""
The file containing hotwords, one word/phrase per line. Within each
phrase, the bpe/cjkchar tokens are separated by spaces. For example:
▁HE LL O ▁WORLD
你 好 世 界
""",
)
parser.add_argument(
"--hotwords-score",
type=float,
default=1.5,
help="""
The bonus score for each token of a hotword, used to bias the
word/phrase. Used only if --hotwords-file is given.
""",
)
return parser.parse_args()
... ... @@ -104,6 +128,8 @@ def create_recognizer(args):
rule3_min_utterance_length=300, # it essentially disables this rule
decoding_method=args.decoding_method,
provider=args.provider,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
)
return recognizer
... ...
... ... @@ -11,7 +11,6 @@ import sys
from pathlib import Path
from typing import List
import sentencepiece as spm
try:
import sounddevice as sd
... ... @@ -90,49 +89,29 @@ def get_args():
)
parser.add_argument(
"--bpe-model",
"--hotwords-file",
type=str,
default="",
help="""
Path to bpe.model, it will be used to tokenize contexts biasing phrases.
Used only when --decoding-method=modified_beam_search
""",
)
The file containing hotwords, one word/phrase per line. Within each
phrase, the bpe/cjkchar tokens are separated by spaces. For example:
parser.add_argument(
"--modeling-unit",
type=str,
default="char",
help="""
The type of modeling unit, it will be used to tokenize contexts biasing phrases.
Valid values are bpe, bpe+char, char.
Note: the char here means characters in CJK languages.
Used only when --decoding-method=modified_beam_search
▁HE LL O ▁WORLD
你 好 世 界
""",
)
parser.add_argument(
"--contexts",
type=str,
default="",
help="""
The context list, it is a string containing some words/phrases separated
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--context-score",
"--hotwords-score",
type=float,
default=1.5,
help="""
The context score of each token for biasing word/phrase. Used only if
--contexts is given.
Used only when --decoding-method=modified_beam_search
The bonus score for each token of a hotword, used to bias the
word/phrase. Used only if --hotwords-file is given.
""",
)
return parser.parse_args()
... ... @@ -155,32 +134,12 @@ def create_recognizer(args):
decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths,
provider=args.provider,
context_score=args.context_score,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
)
return recognizer
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
sp = None
if "bpe" in args.modeling_unit:
assert_file_exists(args.bpe_model)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
tokens = {}
with open(args.tokens, "r", encoding="utf-8") as f:
for line in f:
toks = line.strip().split()
assert len(toks) == 2, len(toks)
assert toks[0] not in tokens, f"Duplicate token: {toks} "
tokens[toks[0]] = int(toks[1])
return sherpa_onnx.encode_contexts(
modeling_unit=args.modeling_unit,
contexts=contexts,
sp=sp,
tokens_table=tokens,
)
def main():
args = get_args()
... ... @@ -193,12 +152,6 @@ def main():
default_input_device_idx = sd.default.device[0]
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
contexts_list = []
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
print(f"Contexts list: {contexts}")
contexts_list = encode_contexts(args, contexts)
recognizer = create_recognizer(args)
print("Started! Please speak")
... ... @@ -207,10 +160,7 @@ def main():
sample_rate = 48000
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
last_result = ""
if contexts_list:
stream = recognizer.create_stream(contexts_list=contexts_list)
else:
stream = recognizer.create_stream()
stream = recognizer.create_stream()
with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
while True:
samples, _ = s.read(samples_per_read) # a blocking read
... ...
... ... @@ -87,6 +87,30 @@ def get_args():
""",
)
parser.add_argument(
"--hotwords-file",
type=str,
default="",
help="""
The file containing hotwords, one word/phrase per line. Within each
phrase, the bpe/cjkchar tokens are separated by spaces. For example:
▁HE LL O ▁WORLD
你 好 世 界
""",
)
parser.add_argument(
"--hotwords-score",
type=float,
default=1.5,
help="""
The bonus score for each token of a hotword, used to bias the
word/phrase. Used only if --hotwords-file is given.
""",
)
return parser.parse_args()
... ... @@ -107,6 +131,8 @@ def create_recognizer(args):
rule1_min_trailing_silence=2.4,
rule2_min_trailing_silence=1.2,
rule3_min_utterance_length=300, # it essentially disables this rule
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
)
return recognizer
... ...
... ... @@ -187,6 +187,32 @@ def add_decoding_args(parser: argparse.ArgumentParser):
add_modified_beam_search_args(parser)
def add_hotwords_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--hotwords-file",
type=str,
default="",
help="""
The file containing hotwords, one word/phrase per line. Within each
phrase, the bpe/cjkchar tokens are separated by spaces. For example:
▁HE LL O ▁WORLD
你 好 世 界
""",
)
parser.add_argument(
"--hotwords-score",
type=float,
default=1.5,
help="""
The bonus score for each token of a hotword, used to bias the
word/phrase. Used only if --hotwords-file is given.
""",
)
def add_modified_beam_search_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-active-paths",
... ... @@ -239,6 +265,7 @@ def get_args():
add_model_args(parser)
add_decoding_args(parser)
add_endpointing_args(parser)
add_hotwords_args(parser)
parser.add_argument(
"--port",
... ... @@ -343,6 +370,8 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
feature_dim=args.feat_dim,
decoding_method=args.decoding_method,
max_active_paths=args.num_active_paths,
hotwords_score=args.hotwords_score,
hotwords_file=args.hotwords_file,
enable_endpoint_detection=args.use_endpoint != 0,
rule1_min_trailing_silence=args.rule1_min_trailing_silence,
rule2_min_trailing_silence=args.rule2_min_trailing_silence,
... ...
#!/usr/bin/env python3
"""
This script encodes the texts (given line by line through ``text``) as tokens
and writes the results to the file given by ``output``.
Usage:
If the tokens_type is bpe:
python3 ./text2token.py \
--text texts.txt \
--tokens tokens.txt \
--tokens-type bpe \
--bpe-model bpe.model \
--output hotwords.txt
If the tokens_type is cjkchar:
python3 ./text2token.py \
--text texts.txt \
--tokens tokens.txt \
--tokens-type cjkchar \
--output hotwords.txt
If the tokens_type is cjkchar+bpe:
python3 ./text2token.py \
--text texts.txt \
--tokens tokens.txt \
--tokens-type cjkchar+bpe \
--bpe-model bpe.model \
--output hotwords.txt
"""
import argparse
from sherpa_onnx import text2token
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--text",
type=str,
required=True,
help="Path to the input texts",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="The path to tokens.txt.",
)
parser.add_argument(
"--tokens-type",
type=str,
required=True,
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
)
parser.add_argument(
"--bpe-model",
type=str,
help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.",
)
parser.add_argument(
"--output",
type=str,
required=True,
help="Path where the encoded tokens will be written to.",
)
return parser.parse_args()
def main():
args = get_args()
texts = []
with open(args.text, "r", encoding="utf8") as f:
for line in f:
texts.append(line.strip())
encoded_texts = text2token(
texts,
tokens=args.tokens,
tokens_type=args.tokens_type,
bpe_model=args.bpe_model,
)
with open(args.output, "w", encoding="utf8") as f:
for txt in encoded_texts:
f.write(" ".join(txt) + "\n")
if __name__ == "__main__":
main()
... ...
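For example, if texts.txt contains the two lines HELLO WORLD and I LOVE YOU, the bpe invocation above would write something like the following to hotwords.txt (assuming an English BPE model such as the one used by the new tests):

▁HE LL O ▁WORLD
▁I ▁LOVE ▁YOU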
... ... @@ -39,6 +39,7 @@ install_requires = [
"numpy",
"sentencepiece==0.1.96; python_version < '3.11'",
"sentencepiece; python_version >= '3.11'",
"click>=7.1.1",
]
... ... @@ -93,6 +94,11 @@ setuptools.setup(
"Programming Language :: Python",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
entry_points={
'console_scripts': [
'sherpa-onnx-cli=sherpa_onnx.cli:cli',
],
},
license="Apache licensed, as found in the LICENSE file",
)
... ...
... ... @@ -72,6 +72,7 @@ set(sources
text-utils.cc
transpose.cc
unbind.cc
utils.cc
wave-reader.cc
)
... ...
... ... @@ -4,11 +4,14 @@
#include "sherpa-onnx/csrc/context-graph.h"
#include <chrono> // NOLINT
#include <map>
#include <random>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
... ... @@ -41,4 +44,29 @@ TEST(ContextGraph, TestBasic) {
}
}
TEST(ContextGraph, Benchmark) {
std::random_device rd;
std::mt19937 mt(rd());
std::uniform_int_distribution<int32_t> char_dist(0, 25);
std::uniform_int_distribution<int32_t> len_dist(3, 8);
for (int32_t num = 10; num <= 10000; num *= 10) {
std::vector<std::vector<int32_t>> contexts;
for (int32_t i = 0; i < num; ++i) {
std::vector<int32_t> tmp;
int32_t word_len = len_dist(mt);
for (int32_t j = 0; j < word_len; ++j) {
tmp.push_back(char_dist(mt));
}
contexts.push_back(std::move(tmp));
}
auto start = std::chrono::high_resolution_clock::now();
auto context_graph = ContextGraph(contexts, 1);
auto stop = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
duration.count());
}
}
} // namespace sherpa_onnx
... ...
... ... @@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
... ... @@ -32,7 +33,7 @@ class OfflineRecognizerImpl {
virtual ~OfflineRecognizerImpl() = default;
virtual std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
const std::string &hotwords) const {
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
exit(-1);
}
... ...
... ... @@ -5,7 +5,9 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
#include <fstream>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
... ... @@ -16,6 +18,7 @@
#endif
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
... ... @@ -25,6 +28,7 @@
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
... ... @@ -60,6 +64,9 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
: config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
... ... @@ -105,17 +112,24 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
#endif
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const override {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
const std::string &hotwords) const override {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, symbol_table_, &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
auto context_graph =
std::make_shared<ContextGraph>(context_list, config_.context_score);
std::make_shared<ContextGraph>(current, config_.hotwords_score);
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
return std::make_unique<OfflineStream>(config_.feat_config,
hotwords_graph_);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
... ... @@ -171,9 +185,29 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}
void InitHotwords() {
// each line in hotwords_file contains space-separated tokens
std::ifstream is(config_.hotwords_file);
if (!is) {
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
config_.hotwords_file.c_str());
exit(-1);
}
if (!EncodeHotwords(is, symbol_table_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<OfflineTransducerModel> model_;
std::unique_ptr<OfflineTransducerDecoder> decoder_;
std::unique_ptr<OfflineLM> lm_;
... ...
... ... @@ -26,7 +26,15 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register("max-active-paths", &max_active_paths,
"Used only when decoding_method is modified_beam_search");
po->Register("context-score", &context_score,
po->Register(
"hotwords-file", &hotwords_file,
"The file containing hotwords, one words/phrases per line, and for each"
"phrase the bpe/cjkchar are separated by a space. For example: "
"▁HE LL O ▁WORLD"
"你 好 世 界");
po->Register("hotwords-score", &hotwords_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
}
... ... @@ -53,7 +61,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "lm_config=" << lm_config.ToString() << ", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ")";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "hotwords_score=" << hotwords_score << ")";
return os.str();
}
... ... @@ -70,8 +79,8 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
OfflineRecognizer::~OfflineRecognizer() = default;
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
const std::string &hotwords) const {
return impl_->CreateStream(hotwords);
}
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
... ...
... ... @@ -31,7 +31,10 @@ struct OfflineRecognizerConfig {
std::string decoding_method = "greedy_search";
int32_t max_active_paths = 4;
float context_score = 1.5;
std::string hotwords_file;
float hotwords_score = 1.5;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
... ... @@ -40,13 +43,16 @@ struct OfflineRecognizerConfig {
const OfflineModelConfig &model_config,
const OfflineLMConfig &lm_config,
const std::string &decoding_method,
int32_t max_active_paths, float context_score)
int32_t max_active_paths,
const std::string &hotwords_file,
float hotwords_score)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
decoding_method(decoding_method),
max_active_paths(max_active_paths),
context_score(context_score) {}
hotwords_file(hotwords_file),
hotwords_score(hotwords_score) {}
void Register(ParseOptions *po);
bool Validate() const;
... ... @@ -69,9 +75,17 @@ class OfflineRecognizer {
/// Create a stream for decoding.
std::unique_ptr<OfflineStream> CreateStream() const;
/// Create a stream for decoding.
/** Create a stream for decoding.
*
* @param hotwords The hotwords for this stream. It may contain several
*                 hotwords separated by "/". Within each hotword, the
*                 tokens (cjkchars or bpe pieces) are separated by
*                 spaces (" "). For example, the hotwords I LOVE YOU
*                 and HELLO WORLD look like:
*
*                 "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const;
const std::string &hotwords) const;
/** Decode a single stream
*
... ...
... ... @@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
... ... @@ -29,7 +30,7 @@ class OnlineRecognizerImpl {
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &contexts) const {
const std::string &hotwords) const {
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
exit(-1);
}
... ...
... ... @@ -7,6 +7,8 @@
#include <algorithm>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <utility>
#include <vector>
... ... @@ -20,6 +22,7 @@
#include "sherpa-onnx/csrc/online-transducer-model.h"
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
namespace sherpa_onnx {
... ... @@ -57,6 +60,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (!config_.hotwords_file.empty()) {
InitHotwords();
}
if (sym_.contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
... ... @@ -106,18 +112,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
#endif
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, hotwords_graph_);
InitOnlineStream(stream.get());
return stream;
}
std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &contexts) const override {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
const std::string &hotwords) const override {
auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
std::istringstream is(hws);
std::vector<std::vector<int32_t>> current;
if (!EncodeHotwords(is, sym_, &current)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
current.insert(current.end(), hotwords_.begin(), hotwords_.end());
auto context_graph =
std::make_shared<ContextGraph>(contexts, config_.context_score);
std::make_shared<ContextGraph>(current, config_.hotwords_score);
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
InitOnlineStream(stream.get());
... ... @@ -253,6 +265,24 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
s->Reset();
}
void InitHotwords() {
// each line in hotwords_file contains space-separated tokens
std::ifstream is(config_.hotwords_file);
if (!is) {
SHERPA_ONNX_LOGE("Open hotwords file failed: %s",
config_.hotwords_file.c_str());
exit(-1);
}
if (!EncodeHotwords(is, sym_, &hotwords_)) {
SHERPA_ONNX_LOGE("Encode hotwords failed.");
exit(-1);
}
hotwords_graph_ =
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
}
private:
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
... ... @@ -271,6 +301,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
private:
OnlineRecognizerConfig config_;
std::vector<std::vector<int32_t>> hotwords_;
ContextGraphPtr hotwords_graph_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineLM> lm_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;
... ...
... ... @@ -57,9 +57,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"True to enable endpoint detection. False to disable it.");
po->Register("max-active-paths", &max_active_paths,
"beam size used in modified beam search.");
po->Register("context-score", &context_score,
po->Register("hotwords-score", &hotwords_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
po->Register(
"hotwords-file", &hotwords_file,
"The file containing hotwords, one words/phrases per line, and for each"
"phrase the bpe/cjkchar are separated by a space. For example: "
"▁HE LL O ▁WORLD"
"你 好 世 界");
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
... ... @@ -87,7 +93,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ", ";
os << "hotwords_score=" << hotwords_score << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "decoding_method=\"" << decoding_method << "\")";
return os.str();
... ... @@ -109,8 +116,8 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
}
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
const std::string &hotwords) const {
return impl_->CreateStream(hotwords);
}
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
... ...
... ... @@ -78,8 +78,10 @@ struct OnlineRecognizerConfig {
// used only for modified_beam_search
int32_t max_active_paths = 4;
/// used only for modified_beam_search
float context_score = 1.5;
float hotwords_score = 1.5;
std::string hotwords_file;
OnlineRecognizerConfig() = default;
... ... @@ -89,14 +91,16 @@ struct OnlineRecognizerConfig {
const EndpointConfig &endpoint_config,
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths, float context_score)
int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score)
: feat_config(feat_config),
model_config(model_config),
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths),
context_score(context_score) {}
hotwords_score(hotwords_score),
hotwords_file(hotwords_file) {}
void Register(ParseOptions *po);
bool Validate() const;
... ... @@ -119,9 +123,16 @@ class OnlineRecognizer {
/// Create a stream for decoding.
std::unique_ptr<OnlineStream> CreateStream() const;
// Create a stream with context phrases
std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const;
/** Create a stream for decoding.
*
* @param hotwords The hotwords for this stream. It may contain several
*                 hotwords separated by "/". Within each hotword, the
*                 tokens (cjkchars or bpe pieces) are separated by
*                 spaces (" "). For example, the hotwords I LOVE YOU
*                 and HELLO WORLD look like:
*
*                 "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
std::unique_ptr<OnlineStream> CreateStream(const std::string &hotwords) const;
/**
* Return true if the given stream has enough frames for decoding.
... ...
// sherpa-onnx/csrc/utils.cc
//
// Copyright 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/utils.h"
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *hotwords) {
hotwords->clear();
std::string line;
std::string word;
while (std::getline(is, line)) {
// Declared inside the loop so that each hotword starts from an empty vector.
std::vector<int32_t> tmp;
std::istringstream iss(line);
std::vector<std::string> syms;
while (iss >> word) {
if (word.size() >= 3) {
// For BPE-based models, we replace ▁ with a space
// Unicode 9601, hex 0x2581, utf8 0xe29681
const uint8_t *p = reinterpret_cast<const uint8_t *>(word.c_str());
if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
word = word.replace(0, 3, " ");
}
}
if (symbol_table.contains(word)) {
int32_t number = symbol_table[word];
tmp.push_back(number);
} else {
SHERPA_ONNX_LOGE(
"Cannot find ID for hotword %s at line: %s. (Hint: words on "
"the "
"same line are separated by spaces)",
word.c_str(), line.c_str());
return false;
}
}
hotwords->push_back(std::move(tmp));
}
return true;
}
} // namespace sherpa_onnx
... ...
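The encoding logic above is compact enough to mirror in Python; the following sketch is an illustration of the C++ code, not part of the commit (it returns None where the C++ version logs the offending line and returns false):

from typing import Dict, List, Optional

def encode_hotwords(
    lines: List[str], symbol_table: Dict[str, int]
) -> Optional[List[List[int]]]:
    """Map each line of space-separated tokens to a list of token ids."""
    hotwords = []
    for line in lines:
        ids = []
        for word in line.split():
            # For BPE-based models a leading ▁ (U+2581) marks a word boundary;
            # like the C++ code, replace it with a space before the lookup.
            if word.startswith("\u2581"):
                word = " " + word[1:]
            if word not in symbol_table:
                return None  # unknown token
            ids.append(symbol_table[word])
        hotwords.append(ids)
    return hotwords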
// sherpa-onnx/csrc/utils.h
//
// Copyright 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_UTILS_H_
#define SHERPA_ONNX_CSRC_UTILS_H_
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
/* Encode the hotwords from an input stream into token ids.
*
* @param is The input stream; it contains several lines, one hotword per
* line. For each hotword, the tokens (cjkchar or bpe) are separated
* by spaces.
* @param symbol_table The symbol table mapping symbols to ids. All the symbols
* in the stream should be in the symbol_table; if not, this
* function returns false.
*
* @param hotwords The vector to which the encoded token ids are written.
*
* @return Return true if all the symbols from ``is`` are in the symbol_table,
* otherwise return false.
*/
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
std::vector<std::vector<int32_t>> *hotwords);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_UTILS_H_
... ...
... ... @@ -16,17 +16,19 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
.def(py::init<const OfflineFeatureExtractorConfig &,
const OfflineModelConfig &, const OfflineLMConfig &,
const std::string &, int32_t, float>(),
const std::string &, int32_t, const std::string &, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5)
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 1.5)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("context_score", &PyClass::context_score)
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def("__str__", &PyClass::ToString);
}
... ... @@ -40,11 +42,10 @@ void PybindOfflineRecognizer(py::module *m) {
[](const PyClass &self) { return self.CreateStream(); })
.def(
"create_stream",
[](PyClass &self,
const std::vector<std::vector<int32_t>> &contexts_list) {
return self.CreateStream(contexts_list);
[](PyClass &self, const std::string &hotwords) {
return self.CreateStream(hotwords);
},
py::arg("contexts_list"))
py::arg("hotwords"))
.def("decode_stream", &PyClass::DecodeStream)
.def("decode_streams",
[](const PyClass &self, std::vector<OfflineStream *> ss) {
... ...
... ... @@ -21,8 +21,8 @@ void PybindOnlineModelConfig(py::module *m) {
using PyClass = OnlineModelConfig;
py::class_<PyClass>(*m, "OnlineModelConfig")
.def(py::init<const OnlineTransducerModelConfig &,
const OnlineParaformerModelConfig &, std::string &, int32_t,
bool, const std::string &, const std::string &>(),
const OnlineParaformerModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
... ...
... ... @@ -29,18 +29,20 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
const OnlineLMConfig &, const EndpointConfig &, bool,
const std::string &, int32_t, float>(),
const std::string &, int32_t, const std::string &, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths") = 4, py::arg("context_score") = 0)
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("context_score", &PyClass::context_score)
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def("__str__", &PyClass::ToString);
}
... ... @@ -55,11 +57,10 @@ void PybindOnlineRecognizer(py::module *m) {
[](const PyClass &self) { return self.CreateStream(); })
.def(
"create_stream",
[](PyClass &self,
const std::vector<std::vector<int32_t>> &contexts_list) {
return self.CreateStream(contexts_list);
[](PyClass &self, const std::string &hotwords) {
return self.CreateStream(hotwords);
},
py::arg("contexts_list"))
py::arg("hotwords"))
.def("is_ready", &PyClass::IsReady)
.def("decode_stream", &PyClass::DecodeStream)
.def("decode_streams",
... ...
... ... @@ -4,4 +4,4 @@ from _sherpa_onnx import Display, OfflineStream, OnlineStream
from .offline_recognizer import OfflineRecognizer
from .online_recognizer import OnlineRecognizer
from .utils import encode_contexts
from .utils import text2token
... ...
# Copyright (c) 2023 Xiaomi Corporation
import logging
from pathlib import Path
import click
from sherpa_onnx import text2token
@click.group()
def cli():
"""
The shell entry point to sherpa-onnx.
"""
logging.basicConfig(
format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
level=logging.INFO,
)
@cli.command(name="text2token")
@click.argument("input", type=click.Path(exists=True, dir_okay=False))
@click.argument("output", type=click.Path())
@click.option(
"--tokens",
type=str,
required=True,
help="The path to tokens.txt.",
)
@click.option(
"--tokens-type",
type=str,
required=True,
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
)
@click.option(
"--bpe-model",
type=str,
help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.",
)
def encode_text(
input: Path, output: Path, tokens: Path, tokens_type: str, bpe_model: Path
):
"""
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
"""
texts = []
with open(input, "r", encoding="utf8") as f:
for line in f:
texts.append(line.strip())
encoded_texts = text2token(
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
)
with open(output, "w", encoding="utf8") as f:
for txt in encoded_texts:
f.write(" ".join(txt) + "\n")
... ...
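The new click command group can be exercised without installing the console script, for instance via click's test runner; a sketch, where the file paths are placeholders and input.txt must exist because click checks it:

from click.testing import CliRunner
from sherpa_onnx.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "text2token",
        "--tokens", "tokens.txt",  # placeholder path
        "--tokens-type", "bpe",
        "--bpe-model", "bpe.model",  # placeholder path
        "input.txt",  # must exist
        "output.txt",
    ],
)
assert result.exit_code == 0, result.output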
... ... @@ -43,7 +43,8 @@ class OfflineRecognizer(object):
feature_dim: int = 80,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
context_score: float = 1.5,
hotwords_file: str = "",
hotwords_score: float = 1.5,
debug: bool = False,
provider: str = "cpu",
):
... ... @@ -105,7 +106,8 @@ class OfflineRecognizer(object):
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
context_score=context_score,
hotwords_file=hotwords_file,
hotwords_score=hotwords_score,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -379,11 +381,11 @@ class OfflineRecognizer(object):
self.config = recognizer_config
return self
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
def create_stream(self, hotwords: Optional[str] = None):
if hotwords is None:
return self.recognizer.create_stream()
else:
return self.recognizer.create_stream(contexts_list)
return self.recognizer.create_stream(hotwords)
def decode_stream(self, s: OfflineStream):
self.recognizer.decode_stream(s)
... ...
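Putting the pieces together, the new offline Python API can presumably be used as follows (a sketch: model file paths are placeholders, and the keyword names follow the diff above and the existing from_transducer factory):

import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    encoder="encoder.onnx",  # placeholder paths
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    tokens="tokens.txt",
    decoding_method="modified_beam_search",  # required for hotwords
    hotwords_file="hotwords.txt",  # recognizer-level hotwords (optional)
    hotwords_score=1.5,
)
# Per-stream hotwords: phrases separated by "/", tokens by spaces.
s = recognizer.create_stream(hotwords="▁HE LL O ▁WORLD/▁I ▁LOVE ▁YOU")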
... ... @@ -42,7 +42,8 @@ class OnlineRecognizer(object):
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
context_score: float = 1.5,
hotwords_score: float = 1.5,
hotwords_file: str = "",
provider: str = "cpu",
model_type: str = "",
):
... ... @@ -138,7 +139,8 @@ class OnlineRecognizer(object):
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
max_active_paths=max_active_paths,
context_score=context_score,
hotwords_score=hotwords_score,
hotwords_file=hotwords_file,
)
self.recognizer = _Recognizer(recognizer_config)
... ... @@ -248,11 +250,11 @@ class OnlineRecognizer(object):
self.config = recognizer_config
return self
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
def create_stream(self, hotwords: Optional[str] = None):
if hotwords is None:
return self.recognizer.create_stream()
else:
return self.recognizer.create_stream(contexts_list)
return self.recognizer.create_stream(hotwords)
def decode_stream(self, s: OnlineStream):
self.recognizer.decode_stream(s)
... ...
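The streaming recognizer mirrors this; again a sketch with placeholder model paths:

import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens="tokens.txt",  # placeholder paths
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    decoding_method="modified_beam_search",
    hotwords_file="hotwords.txt",
    hotwords_score=1.5,
)
stream = recognizer.create_stream()  # uses the recognizer-level hotwords graph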
from typing import Dict, List, Optional
# Copyright (c) 2023 Xiaomi Corporation
import re
from pathlib import Path
from typing import List, Optional, Union
def encode_contexts(
modeling_unit: str,
contexts: List[str],
sp: Optional["SentencePieceProcessor"] = None,
tokens_table: Optional[Dict[str, int]] = None,
) -> List[List[int]]:
import sentencepiece as spm
def text2token(
texts: List[str],
tokens: str,
tokens_type: str = "cjkchar",
bpe_model: Optional[str] = None,
output_ids: bool = False,
) -> List[List[Union[str, int]]]:
"""
Encode the given contexts (a list of string) to a list of a list of token ids.
Encode the given texts (a list of string) to a list of a list of tokens.
Args:
modeling_unit:
The valid values are bpe, char, bpe+char.
Note: char here means characters in CJK languages, not English like languages.
contexts:
texts:
The given texts (a list of strings).
sp:
An instance of SentencePieceProcessor.
tokens_table:
The tokens_table containing the tokens and the corresponding ids.
tokens:
The path of the tokens.txt.
tokens_type:
The valid values are cjkchar, bpe, cjkchar+bpe.
bpe_model:
The path of the bpe model. Only required when tokens_type is bpe or
cjkchar+bpe.
output_ids:
True to output token ids otherwise tokens.
Returns:
Return the contexts_list, it is a list of a list of token ids.
Return the encoded texts: a list of lists of token ids if output_ids
is True, otherwise a list of lists of tokens.
"""
contexts_list = []
if "bpe" in modeling_unit:
assert sp is not None
if "char" in modeling_unit:
assert tokens_table is not None
assert len(tokens_table) > 0, len(tokens_table)
assert Path(tokens).is_file(), f"File does not exist: {tokens}"
tokens_table = {}
with open(tokens, "r", encoding="utf-8") as f:
for line in f:
toks = line.strip().split()
assert len(toks) == 2, len(toks)
assert toks[0] not in tokens_table, f"Duplicate token: {toks} "
tokens_table[toks[0]] = int(toks[1])
if "char" == modeling_unit:
for context in contexts:
assert ' ' not in context
ids = [
tokens_table[txt] if txt in tokens_table else tokens_table["<unk>"]
for txt in context
]
contexts_list.append(ids)
elif "bpe" == modeling_unit:
contexts_list = sp.encode(contexts, out_type=int)
else:
assert modeling_unit == "bpe+char", modeling_unit
if "bpe" in tokens_type:
assert Path(bpe_model).is_file(), f"File does not exist: {bpe_model}"
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
texts_list: List[List[str]] = []
if tokens_type == "cjkchar":
texts_list = [list("".join(text.split())) for text in texts]
elif tokens_type == "bpe":
texts_list = sp.encode(texts, out_type=str)
else:
assert (
tokens_type == "cjkchar+bpe"
), f"Supported tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}"
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
pattern = re.compile(r"([\u4e00-\u9fff])")
for context in contexts:
for text in texts:
# Example:
# txt = "你好 IT'S OKAY 的"
# chars = ["你", "好", " IT'S OKAY ", "的"]
chars = pattern.split(context.upper())
chars = pattern.split(text)
mix_chars = [w for w in chars if len(w.strip()) > 0]
ids = []
text_list = []
for ch_or_w in mix_chars:
# ch_or_w is a single CJK character (e.g., "你"); do nothing.
if pattern.fullmatch(ch_or_w) is not None:
ids.append(
tokens_table[ch_or_w]
if ch_or_w in tokens_table
else tokens_table["<unk>"]
)
text_list.append(ch_or_w)
# ch_or_w contains non-CJK characters (e.g., " IT'S OKAY "),
# so encode it with the bpe model.
else:
for p in sp.encode_as_pieces(ch_or_w):
ids.append(
tokens_table[p]
if p in tokens_table
else tokens_table["<unk>"]
)
contexts_list.append(ids)
return contexts_list
text_list += sp.encode_as_pieces(ch_or_w)
texts_list.append(text_list)
result: List[List[Union[int, str]]] = []
for text in texts_list:
text_list = []
contain_oov = False
for txt in text:
if txt in tokens_table:
text_list.append(tokens_table[txt] if output_ids else txt)
else:
print(f"OOV token : {txt}, skipping text : {text}.")
contain_oov = True
break
if contain_oov:
continue
else:
result.append(text_list)
return result
... ...
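A quick sanity check of text2token, matching the behaviour encoded in the new unit test below (paths are placeholders for files from a real model directory):

from sherpa_onnx import text2token

encoded = text2token(
    ["HELLO WORLD", "I LOVE YOU"],
    tokens="tokens.txt",  # placeholder path
    tokens_type="bpe",
    bpe_model="bpe.model",  # placeholder path
)
# With an English BPE model this yields something like:
# [["▁HE", "LL", "O", "▁WORLD"], ["▁I", "▁LOVE", "▁YOU"]]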
... ... @@ -6,12 +6,14 @@ function(sherpa_onnx_add_py_test source)
COMMAND
"${PYTHON_EXECUTABLE}"
"${CMAKE_CURRENT_SOURCE_DIR}/${source}"
WORKING_DIRECTORY
${CMAKE_CURRENT_SOURCE_DIR}
)
get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY)
set_property(TEST ${name}
PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$<TARGET_FILE_DIR:_sherpa_onnx>:$ENV{PYTHONPATH}"
)
endfunction()
... ... @@ -21,6 +23,7 @@ set(py_test_files
test_offline_recognizer.py
test_online_recognizer.py
test_online_transducer_model_config.py
test_text2token.py
)
foreach(source IN LISTS py_test_files)
... ...
# sherpa-onnx/python/tests/test_text2token.py
#
# Copyright (c) 2023 Xiaomi Corporation
#
# To run this single test, use
#
# ctest --verbose -R test_text2token_py
import unittest
from pathlib import Path
import sherpa_onnx
d = "/tmp/sherpa-test-data"
# Please refer to
# https://github.com/pkufool/sherpa-test-data
# to download test data for testing
class TestText2Token(unittest.TestCase):
def test_bpe(self):
tokens = f"{d}/text2token/tokens_en.txt"
bpe_model = f"{d}/text2token/bpe_en.model"
if not Path(tokens).is_file() or not Path(bpe_model).is_file():
print(
f"No test data found, skipping test_bpe().\n"
f"You can download the test data by: \n"
f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
)
return
texts = ["HELLO WORLD", "I LOVE YOU"]
encoded_texts = sherpa_onnx.text2token(
texts,
tokens=tokens,
tokens_type="bpe",
bpe_model=bpe_model,
)
assert encoded_texts == [
["▁HE", "LL", "O", "▁WORLD"],
["▁I", "▁LOVE", "▁YOU"],
], encoded_texts
encoded_ids = sherpa_onnx.text2token(
texts,
tokens=tokens,
tokens_type="bpe",
bpe_model=bpe_model,
output_ids=True,
)
assert encoded_ids == [[22, 58, 24, 425], [19, 370, 47]], encoded_ids
def test_cjkchar(self):
tokens = f"{d}/text2token/tokens_cn.txt"
if not Path(tokens).is_file():
print(
f"No test data found, skipping test_cjkchar().\n"
f"You can download the test data by: \n"
f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
)
return
texts = ["世界人民大团结", "中国 VS 美国"]
encoded_texts = sherpa_onnx.text2token(
texts, tokens=tokens, tokens_type="cjkchar"
)
assert encoded_texts == [
["世", "界", "人", "民", "大", "团", "结"],
["中", "国", "V", "S", "美", "国"],
], encoded_texts
encoded_ids = sherpa_onnx.text2token(
texts,
tokens=tokens,
tokens_type="cjkchar",
output_ids=True,
)
assert encoded_ids == [
[379, 380, 72, 874, 93, 1251, 489],
[262, 147, 3423, 2476, 21, 147],
], encoded_ids
def test_cjkchar_bpe(self):
tokens = f"{d}/text2token/tokens_mix.txt"
bpe_model = f"{d}/text2token/bpe_mix.model"
if not Path(tokens).is_file() or not Path(bpe_model).is_file():
print(
f"No test data found, skipping test_cjkchar_bpe().\n"
f"You can download the test data by: \n"
f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
)
return
texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"]
encoded_texts = sherpa_onnx.text2token(
texts,
tokens=tokens,
tokens_type="cjkchar+bpe",
bpe_model=bpe_model,
)
assert encoded_texts == [
["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"],
["中", "国", "▁GO", "ES", "▁WITH", "美", "国"],
], encoded_texts
encoded_ids = sherpa_onnx.text2token(
texts,
tokens=tokens,
tokens_type="cjkchar+bpe",
bpe_model=bpe_model,
output_ids=True,
)
assert encoded_ids == [
[1368, 1392, 557, 680, 275, 178, 475],
[685, 736, 275, 178, 179, 921, 736],
], encoded_ids
if __name__ == "__main__":
unittest.main()
... ...