Wei Kang
Committed by GitHub

Implement context biasing with an Aho–Corasick automaton (#145)

* Implement context graph

* Modify the interface to support context biasing

* Support context biasing in modified beam search; add python wrapper

* Support context biasing in python api example

* Minor fixes

* Fix context graph

* Minor fixes

* Fix tests

* Fix style

* Fix style

* Fix comments

* Minor fixes

* Add missing header

* Replace std::shared_ptr with std::unique_ptr for efficiency

* Build graph in constructor

* Fix comments

* Minor fixes

* Fix docs
@@ -54,7 +54,7 @@ jobs: @@ -54,7 +54,7 @@ jobs:
54 - name: Install Python dependencies 54 - name: Install Python dependencies
55 shell: bash 55 shell: bash
56 run: | 56 run: |
57 - python3 -m pip install --upgrade pip numpy 57 + python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96
58 58
59 - name: Install sherpa-onnx 59 - name: Install sherpa-onnx
60 shell: bash 60 shell: bash
@@ -43,9 +43,10 @@ import argparse @@ -43,9 +43,10 @@ import argparse
43 import time 43 import time
44 import wave 44 import wave
45 from pathlib import Path 45 from pathlib import Path
46 -from typing import Tuple 46 +from typing import List, Tuple
47 47
48 import numpy as np 48 import numpy as np
  49 +import sentencepiece as spm
49 import sherpa_onnx 50 import sherpa_onnx
50 51
51 52
@@ -61,6 +62,47 @@ def get_args(): @@ -61,6 +62,47 @@ def get_args():
61 ) 62 )
62 63
63 parser.add_argument( 64 parser.add_argument(
  65 + "--bpe-model",
  66 + type=str,
  67 + default="",
  68 + help="""
  69 + Path to bpe.model,
  70 + Used only when --decoding-method=modified_beam_search
  71 + """,
  72 + )
  73 +
  74 + parser.add_argument(
  75 + "--modeling-unit",
  76 + type=str,
  77 + default="char",
  78 + help="""
  79 + The type of modeling unit.
  80 + Valid values are bpe, bpe+char, char.
  81 + Note: the char here means characters in CJK languages.
  82 + """,
  83 + )
  84 +
  85 + parser.add_argument(
  86 + "--contexts",
  87 + type=str,
  88 + default="",
  89 + help="""
  90 + The context list; it is a string containing some words/phrases separated
  91 + with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY'.
  92 + """,
  93 + )
  94 +
  95 + parser.add_argument(
  96 + "--context-score",
  97 + type=float,
  98 + default=1.5,
  99 + help="""
  100 + The context score of each token for biasing word/phrase. Used only if
  101 + --contexts is given.
  102 + """,
  103 + )
  104 +
  105 + parser.add_argument(
64 "--encoder", 106 "--encoder",
65 default="", 107 default="",
66 type=str, 108 type=str,
@@ -153,6 +195,24 @@ def assert_file_exists(filename: str): @@ -153,6 +195,24 @@ def assert_file_exists(filename: str):
153 ) 195 )
154 196
155 197
  198 +def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
  199 + sp = None
  200 + if "bpe" in args.modeling_unit:
  201 + assert_file_exists(args.bpe_model)
  202 + sp = spm.SentencePieceProcessor()
  203 + sp.load(args.bpe_model)
  204 + tokens = {}
  205 + with open(args.tokens, "r", encoding="utf-8") as f:
  206 + for line in f:
  207 + toks = line.strip().split()
  208 + assert len(toks) == 2, len(toks)
  209 + assert toks[0] not in tokens, f"Duplicate token: {toks} "
  210 + tokens[toks[0]] = int(toks[1])
  211 + return sherpa_onnx.encode_contexts(
  212 + modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens
  213 + )
  214 +
  215 +
156 def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: 216 def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
157 """ 217 """
158 Args: 218 Args:
@@ -182,10 +242,17 @@ def main(): @@ -182,10 +242,17 @@ def main():
182 args = get_args() 242 args = get_args()
183 assert_file_exists(args.tokens) 243 assert_file_exists(args.tokens)
184 assert args.num_threads > 0, args.num_threads 244 assert args.num_threads > 0, args.num_threads
  245 +
  246 + contexts_list = []
185 if args.encoder: 247 if args.encoder:
186 assert len(args.paraformer) == 0, args.paraformer 248 assert len(args.paraformer) == 0, args.paraformer
187 assert len(args.nemo_ctc) == 0, args.nemo_ctc 249 assert len(args.nemo_ctc) == 0, args.nemo_ctc
188 250
  251 + contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
  252 + if contexts:
  253 + print(f"Contexts list: {contexts}")
  254 + contexts_list = encode_contexts(args, contexts)
  255 +
189 assert_file_exists(args.encoder) 256 assert_file_exists(args.encoder)
190 assert_file_exists(args.decoder) 257 assert_file_exists(args.decoder)
191 assert_file_exists(args.joiner) 258 assert_file_exists(args.joiner)
@@ -199,6 +266,7 @@ def main(): @@ -199,6 +266,7 @@ def main():
199 sample_rate=args.sample_rate, 266 sample_rate=args.sample_rate,
200 feature_dim=args.feature_dim, 267 feature_dim=args.feature_dim,
201 decoding_method=args.decoding_method, 268 decoding_method=args.decoding_method,
  269 + context_score=args.context_score,
202 debug=args.debug, 270 debug=args.debug,
203 ) 271 )
204 elif args.paraformer: 272 elif args.paraformer:
@@ -238,8 +306,12 @@ def main(): @@ -238,8 +306,12 @@ def main():
238 samples, sample_rate = read_wave(wave_filename) 306 samples, sample_rate = read_wave(wave_filename)
239 duration = len(samples) / sample_rate 307 duration = len(samples) / sample_rate
240 total_duration += duration 308 total_duration += duration
241 -  
242 - s = recognizer.create_stream() 309 + if contexts_list:
  310 + assert len(args.paraformer) == 0, args.paraformer
  311 + assert len(args.nemo_ctc) == 0, args.nemo_ctc
  312 + s = recognizer.create_stream(contexts_list=contexts_list)
  313 + else:
  314 + s = recognizer.create_stream()
243 s.accept_waveform(sample_rate, samples) 315 s.accept_waveform(sample_rate, samples)
244 316
245 streams.append(s) 317 streams.append(s)
@@ -37,6 +37,7 @@ with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f: @@ -37,6 +37,7 @@ with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f:
37 37
38 install_requires = [ 38 install_requires = [
39 "numpy", 39 "numpy",
  40 + "sentencepiece==0.1.96",
40 ] 41 ]
41 42
42 43
@@ -12,6 +12,7 @@ endif() @@ -12,6 +12,7 @@ endif()
12 12
13 set(sources 13 set(sources
14 cat.cc 14 cat.cc
  15 + context-graph.cc
15 endpoint.cc 16 endpoint.cc
16 features.cc 17 features.cc
17 file-utils.cc 18 file-utils.cc
@@ -248,6 +249,7 @@ endif() @@ -248,6 +249,7 @@ endif()
248 if(SHERPA_ONNX_ENABLE_TESTS) 249 if(SHERPA_ONNX_ENABLE_TESTS)
249 set(sherpa_onnx_test_srcs 250 set(sherpa_onnx_test_srcs
250 cat-test.cc 251 cat-test.cc
  252 + context-graph-test.cc
251 packed-sequence-test.cc 253 packed-sequence-test.cc
252 pad-sequence-test.cc 254 pad-sequence-test.cc
253 slice-test.cc 255 slice-test.cc
  1 +// sherpa-onnx/csrc/context-graph-test.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/context-graph.h"
  6 +
  7 +#include <map>
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "gtest/gtest.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +TEST(ContextGraph, TestBasic) {
  16 + std::vector<std::string> contexts_str(
  17 + {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
  18 + std::vector<std::vector<int32_t>> contexts;
  19 + for (int32_t i = 0; i < contexts_str.size(); ++i) {
  20 + contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
  21 + }
  22 + auto context_graph = ContextGraph(contexts, 1);
  23 +
  24 + auto queries = std::map<std::string, float>{
  25 + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
  26 + {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
  27 +
  28 + for (const auto &iter : queries) {
  29 + float total_scores = 0;
  30 + auto state = context_graph.Root();
  31 + for (auto q : iter.first) {
  32 + auto res = context_graph.ForwardOneStep(state, q);
  33 + total_scores += res.first;
  34 + state = res.second;
  35 + }
  36 + auto res = context_graph.Finalize(state);
  37 + EXPECT_EQ(res.second->token, -1);
  38 + total_scores += res.first;
  39 + EXPECT_EQ(total_scores, iter.second);
  40 + }
  41 +}
  42 +
  43 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/context-graph.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/context-graph.h"
  6 +
  7 +#include <cassert>
  8 +#include <queue>
  9 +#include <utility>
  10 +
  11 +namespace sherpa_onnx {
  12 +void ContextGraph::Build(
  13 + const std::vector<std::vector<int32_t>> &token_ids) const {
  14 + for (int32_t i = 0; i < token_ids.size(); ++i) {
  15 + auto node = root_.get();
  16 + for (int32_t j = 0; j < token_ids[i].size(); ++j) {
  17 + int32_t token = token_ids[i][j];
  18 + if (0 == node->next.count(token)) {
  19 + bool is_end = j == token_ids[i].size() - 1;
  20 + node->next[token] = std::make_unique<ContextState>(
  21 + token, context_score_, node->node_score + context_score_,
  22 + is_end ? 0 : node->local_node_score + context_score_, is_end);
  23 + }
  24 + node = node->next[token].get();
  25 + }
  26 + }
  27 + FillFailOutput();
  28 +}
  29 +
  30 +std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
  31 + const ContextState *state, int32_t token) const {
  32 + const ContextState *node;
  33 + float score;
  34 + if (1 == state->next.count(token)) {
  35 + node = state->next.at(token).get();
  36 + score = node->token_score;
  37 + if (state->is_end) score += state->node_score;
  38 + } else {
  39 + node = state->fail;
  40 + while (0 == node->next.count(token)) {
  41 + node = node->fail;
  42 + if (-1 == node->token) break; // root
  43 + }
  44 + if (1 == node->next.count(token)) {
  45 + node = node->next.at(token).get();
  46 + }
  47 + score = node->node_score - state->local_node_score;
  48 + }
  49 + SHERPA_ONNX_CHECK(nullptr != node);
  50 + float matched_score = 0;
  51 + auto output = node->output;
  52 + while (nullptr != output) {
  53 + matched_score += output->node_score;
  54 + output = output->output;
  55 + }
  56 + return std::make_pair(score + matched_score, node);
  57 +}
  58 +
  59 +std::pair<float, const ContextState *> ContextGraph::Finalize(
  60 + const ContextState *state) const {
  61 + float score = -state->node_score;
  62 + if (state->is_end) {
  63 + score = 0;
  64 + }
  65 + return std::make_pair(score, root_.get());
  66 +}
  67 +
  68 +void ContextGraph::FillFailOutput() const {
  69 + std::queue<const ContextState *> node_queue;
  70 + for (auto &kv : root_->next) {
  71 + kv.second->fail = root_.get();
  72 + node_queue.push(kv.second.get());
  73 + }
  74 + while (!node_queue.empty()) {
  75 + auto current_node = node_queue.front();
  76 + node_queue.pop();
  77 + for (auto &kv : current_node->next) {
  78 + auto fail = current_node->fail;
  79 + if (1 == fail->next.count(kv.first)) {
  80 + fail = fail->next.at(kv.first).get();
  81 + } else {
  82 + fail = fail->fail;
  83 + while (0 == fail->next.count(kv.first)) {
  84 + fail = fail->fail;
  85 + if (-1 == fail->token) break;
  86 + }
  87 + if (1 == fail->next.count(kv.first))
  88 + fail = fail->next.at(kv.first).get();
  89 + }
  90 + kv.second->fail = fail;
  91 + // fill the output arc
  92 + auto output = fail;
  93 + while (!output->is_end) {
  94 + output = output->fail;
  95 + if (-1 == output->token) {
  96 + output = nullptr;
  97 + break;
  98 + }
  99 + }
  100 + kv.second->output = output;
  101 + node_queue.push(kv.second.get());
  102 + }
  103 + }
  104 +}
  105 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/context-graph.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
  6 +#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
  7 +
  8 +#include <memory>
  9 +#include <unordered_map>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#include "sherpa-onnx/csrc/log.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class ContextGraph;
  18 +using ContextGraphPtr = std::shared_ptr<ContextGraph>;
  19 +
  20 +struct ContextState {
  21 + int32_t token;
  22 + float token_score;
  23 + float node_score;
  24 + float local_node_score;
  25 + bool is_end;
  26 + std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
  27 + const ContextState *fail = nullptr;
  28 + const ContextState *output = nullptr;
  29 +
  30 + ContextState() = default;
  31 + ContextState(int32_t token, float token_score, float node_score,
  32 + float local_node_score, bool is_end)
  33 + : token(token),
  34 + token_score(token_score),
  35 + node_score(node_score),
  36 + local_node_score(local_node_score),
  37 + is_end(is_end) {}
  38 +};
  39 +
  40 +class ContextGraph {
  41 + public:
  42 + ContextGraph() = default;
  43 + ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
  44 + float context_score)
  45 + : context_score_(context_score) {
  46 + root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
  47 + root_->fail = root_.get();
  48 + Build(token_ids);
  49 + }
  50 +
  51 + std::pair<float, const ContextState *> ForwardOneStep(
  52 + const ContextState *state, int32_t token_id) const;
  53 + std::pair<float, const ContextState *> Finalize(
  54 + const ContextState *state) const;
  55 +
  56 + const ContextState *Root() const { return root_.get(); }
  57 +
  58 + private:
  59 + float context_score_;
  60 + std::unique_ptr<ContextState> root_;
  61 + void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
  62 + void FillFailOutput() const;
  63 +};
  64 +
  65 +} // namespace sherpa_onnx
  66 +#endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
@@ -14,6 +14,7 @@ @@ -14,6 +14,7 @@
14 #include <vector> 14 #include <vector>
15 15
16 #include "onnxruntime_cxx_api.h" // NOLINT 16 #include "onnxruntime_cxx_api.h" // NOLINT
  17 +#include "sherpa-onnx/csrc/context-graph.h"
17 #include "sherpa-onnx/csrc/math.h" 18 #include "sherpa-onnx/csrc/math.h"
18 #include "sherpa-onnx/csrc/onnx-utils.h" 19 #include "sherpa-onnx/csrc/onnx-utils.h"
19 20
@@ -39,11 +40,18 @@ struct Hypothesis { @@ -39,11 +40,18 @@ struct Hypothesis {
39 // the nn lm states 40 // the nn lm states
40 std::vector<CopyableOrtValue> nn_lm_states; 41 std::vector<CopyableOrtValue> nn_lm_states;
41 42
  43 + const ContextState *context_state;
  44 +
  45 + // TODO(fangjun): Make it configurable
  46 + // the minimum number of tokens in a chunk for streaming RNN LM
  47 + int32_t lm_rescore_min_chunk = 2; // a const
  48 +
42 int32_t num_trailing_blanks = 0; 49 int32_t num_trailing_blanks = 0;
43 50
44 Hypothesis() = default; 51 Hypothesis() = default;
45 - Hypothesis(const std::vector<int64_t> &ys, double log_prob)  
46 - : ys(ys), log_prob(log_prob) {} 52 + Hypothesis(const std::vector<int64_t> &ys, double log_prob,
  53 + const ContextState *context_state = nullptr)
  54 + : ys(ys), log_prob(log_prob), context_state(context_state) {}
47 55
48 double TotalLogProb() const { return log_prob + lm_log_prob; } 56 double TotalLogProb() const { return log_prob + lm_log_prob; }
49 57
@@ -6,7 +6,9 @@ @@ -6,7 +6,9 @@
6 #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ 6 #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
7 7
8 #include <memory> 8 #include <memory>
  9 +#include <vector>
9 10
  11 +#include "sherpa-onnx/csrc/macros.h"
10 #include "sherpa-onnx/csrc/offline-recognizer.h" 12 #include "sherpa-onnx/csrc/offline-recognizer.h"
11 #include "sherpa-onnx/csrc/offline-stream.h" 13 #include "sherpa-onnx/csrc/offline-stream.h"
12 14
@@ -19,6 +21,12 @@ class OfflineRecognizerImpl { @@ -19,6 +21,12 @@ class OfflineRecognizerImpl {
19 21
20 virtual ~OfflineRecognizerImpl() = default; 22 virtual ~OfflineRecognizerImpl() = default;
21 23
  24 + virtual std::unique_ptr<OfflineStream> CreateStream(
  25 + const std::vector<std::vector<int32_t>> &context_list) const {
  26 + SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
  27 + exit(-1);
  28 + }
  29 +
22 virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; 30 virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
23 31
24 virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; 32 virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
@@ -10,6 +10,7 @@ @@ -10,6 +10,7 @@
10 #include <utility> 10 #include <utility>
11 #include <vector> 11 #include <vector>
12 12
  13 +#include "sherpa-onnx/csrc/context-graph.h"
13 #include "sherpa-onnx/csrc/macros.h" 14 #include "sherpa-onnx/csrc/macros.h"
14 #include "sherpa-onnx/csrc/offline-recognizer-impl.h" 15 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
15 #include "sherpa-onnx/csrc/offline-recognizer.h" 16 #include "sherpa-onnx/csrc/offline-recognizer.h"
@@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
72 } 73 }
73 } 74 }
74 75
  76 + std::unique_ptr<OfflineStream> CreateStream(
  77 + const std::vector<std::vector<int32_t>> &context_list) const override {
  78 + // We create the context_graph at this level, because we might have a
  79 + // default context_graph (to be added later if needed) that belongs to the
  80 + // whole model rather than each stream.
  81 + auto context_graph =
  82 + std::make_shared<ContextGraph>(context_list, config_.context_score);
  83 + return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
  84 + }
  85 +
75 std::unique_ptr<OfflineStream> CreateStream() const override { 86 std::unique_ptr<OfflineStream> CreateStream() const override {
76 return std::make_unique<OfflineStream>(config_.feat_config); 87 return std::make_unique<OfflineStream>(config_.feat_config);
77 } 88 }
@@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
117 -23.025850929940457f); 128 -23.025850929940457f);
118 129
119 auto t = model_->RunEncoder(std::move(x), std::move(x_length)); 130 auto t = model_->RunEncoder(std::move(x), std::move(x_length));
120 - auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); 131 + auto results =
  132 + decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
121 133
122 int32_t frame_shift_ms = 10; 134 int32_t frame_shift_ms = 10;
123 for (int32_t i = 0; i != n; ++i) { 135 for (int32_t i = 0; i != n; ++i) {
@@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { @@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
26 26
27 po->Register("max-active-paths", &max_active_paths, 27 po->Register("max-active-paths", &max_active_paths,
28 "Used only when decoding_method is modified_beam_search"); 28 "Used only when decoding_method is modified_beam_search");
  29 + po->Register("context-score", &context_score,
  30 + "The bonus score for each token in context word/phrase. "
  31 + "Used only when decoding_method is modified_beam_search");
29 } 32 }
30 33
31 bool OfflineRecognizerConfig::Validate() const { 34 bool OfflineRecognizerConfig::Validate() const {
@@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const { @@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const {
49 os << "model_config=" << model_config.ToString() << ", "; 52 os << "model_config=" << model_config.ToString() << ", ";
50 os << "lm_config=" << lm_config.ToString() << ", "; 53 os << "lm_config=" << lm_config.ToString() << ", ";
51 os << "decoding_method=\"" << decoding_method << "\", "; 54 os << "decoding_method=\"" << decoding_method << "\", ";
52 - os << "max_active_paths=" << max_active_paths << ")"; 55 + os << "max_active_paths=" << max_active_paths << ", ";
  56 + os << "context_score=" << context_score << ")";
53 57
54 return os.str(); 58 return os.str();
55 } 59 }
@@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) @@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
59 63
60 OfflineRecognizer::~OfflineRecognizer() = default; 64 OfflineRecognizer::~OfflineRecognizer() = default;
61 65
  66 +std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
  67 + const std::vector<std::vector<int32_t>> &context_list) const {
  68 + return impl_->CreateStream(context_list);
  69 +}
  70 +
62 std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const { 71 std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
63 return impl_->CreateStream(); 72 return impl_->CreateStream();
64 } 73 }
@@ -26,6 +26,7 @@ struct OfflineRecognizerConfig { @@ -26,6 +26,7 @@ struct OfflineRecognizerConfig {
26 26
27 std::string decoding_method = "greedy_search"; 27 std::string decoding_method = "greedy_search";
28 int32_t max_active_paths = 4; 28 int32_t max_active_paths = 4;
  29 + float context_score = 1.5;
29 // only greedy_search is implemented 30 // only greedy_search is implemented
30 // TODO(fangjun): Implement modified_beam_search 31 // TODO(fangjun): Implement modified_beam_search
31 32
@@ -34,12 +35,13 @@ struct OfflineRecognizerConfig { @@ -34,12 +35,13 @@ struct OfflineRecognizerConfig {
34 const OfflineModelConfig &model_config, 35 const OfflineModelConfig &model_config,
35 const OfflineLMConfig &lm_config, 36 const OfflineLMConfig &lm_config,
36 const std::string &decoding_method, 37 const std::string &decoding_method,
37 - int32_t max_active_paths) 38 + int32_t max_active_paths, float context_score)
38 : feat_config(feat_config), 39 : feat_config(feat_config),
39 model_config(model_config), 40 model_config(model_config),
40 lm_config(lm_config), 41 lm_config(lm_config),
41 decoding_method(decoding_method), 42 decoding_method(decoding_method),
42 - max_active_paths(max_active_paths) {} 43 + max_active_paths(max_active_paths),
  44 + context_score(context_score) {}
43 45
44 void Register(ParseOptions *po); 46 void Register(ParseOptions *po);
45 bool Validate() const; 47 bool Validate() const;
@@ -58,6 +60,10 @@ class OfflineRecognizer { @@ -58,6 +60,10 @@ class OfflineRecognizer {
58 /// Create a stream for decoding. 60 /// Create a stream for decoding.
59 std::unique_ptr<OfflineStream> CreateStream() const; 61 std::unique_ptr<OfflineStream> CreateStream() const;
60 62
  63 + /// Create a stream for decoding.
  64 + std::unique_ptr<OfflineStream> CreateStream(
  65 + const std::vector<std::vector<int32_t>> &context_list) const;
  66 +
61 /** Decode a single stream 67 /** Decode a single stream
62 * 68 *
63 * @param s The stream to decode. 69 * @param s The stream to decode.
@@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const { @@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
75 75
76 class OfflineStream::Impl { 76 class OfflineStream::Impl {
77 public: 77 public:
78 - explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) { 78 + explicit Impl(const OfflineFeatureExtractorConfig &config,
  79 + ContextGraphPtr context_graph)
  80 + : config_(config), context_graph_(context_graph) {
79 opts_.frame_opts.dither = 0; 81 opts_.frame_opts.dither = 0;
80 opts_.frame_opts.snip_edges = false; 82 opts_.frame_opts.snip_edges = false;
81 opts_.frame_opts.samp_freq = config.sampling_rate; 83 opts_.frame_opts.samp_freq = config.sampling_rate;
@@ -152,6 +154,8 @@ class OfflineStream::Impl { @@ -152,6 +154,8 @@ class OfflineStream::Impl {
152 154
153 const OfflineRecognitionResult &GetResult() const { return r_; } 155 const OfflineRecognitionResult &GetResult() const { return r_; }
154 156
  157 + const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
  158 +
155 private: 159 private:
156 void NemoNormalizeFeatures(float *p, int32_t num_frames, 160 void NemoNormalizeFeatures(float *p, int32_t num_frames,
157 int32_t feature_dim) const { 161 int32_t feature_dim) const {
@@ -189,11 +193,13 @@ class OfflineStream::Impl { @@ -189,11 +193,13 @@ class OfflineStream::Impl {
189 std::unique_ptr<knf::OnlineFbank> fbank_; 193 std::unique_ptr<knf::OnlineFbank> fbank_;
190 knf::FbankOptions opts_; 194 knf::FbankOptions opts_;
191 OfflineRecognitionResult r_; 195 OfflineRecognitionResult r_;
  196 + ContextGraphPtr context_graph_;
192 }; 197 };
193 198
194 OfflineStream::OfflineStream( 199 OfflineStream::OfflineStream(
195 - const OfflineFeatureExtractorConfig &config /*= {}*/)  
196 - : impl_(std::make_unique<Impl>(config)) {} 200 + const OfflineFeatureExtractorConfig &config /*= {}*/,
  201 + ContextGraphPtr context_graph /*= nullptr*/)
  202 + : impl_(std::make_unique<Impl>(config, context_graph)) {}
197 203
198 OfflineStream::~OfflineStream() = default; 204 OfflineStream::~OfflineStream() = default;
199 205
@@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) { @@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
212 impl_->SetResult(r); 218 impl_->SetResult(r);
213 } 219 }
214 220
  221 +const ContextGraphPtr &OfflineStream::GetContextGraph() const {
  222 + return impl_->GetContextGraph();
  223 +}
  224 +
215 const OfflineRecognitionResult &OfflineStream::GetResult() const { 225 const OfflineRecognitionResult &OfflineStream::GetResult() const {
216 return impl_->GetResult(); 226 return impl_->GetResult();
217 } 227 }
@@ -10,6 +10,7 @@ @@ -10,6 +10,7 @@
10 #include <string> 10 #include <string>
11 #include <vector> 11 #include <vector>
12 12
  13 +#include "sherpa-onnx/csrc/context-graph.h"
13 #include "sherpa-onnx/csrc/parse-options.h" 14 #include "sherpa-onnx/csrc/parse-options.h"
14 15
15 namespace sherpa_onnx { 16 namespace sherpa_onnx {
@@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig { @@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig {
66 67
67 class OfflineStream { 68 class OfflineStream {
68 public: 69 public:
69 - explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}); 70 + explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
  71 + ContextGraphPtr context_graph = nullptr);
70 ~OfflineStream(); 72 ~OfflineStream();
71 73
72 /** 74 /**
@@ -96,6 +98,9 @@ class OfflineStream { @@ -96,6 +98,9 @@ class OfflineStream {
96 /** Get the recognition result of this stream */ 98 /** Get the recognition result of this stream */
97 const OfflineRecognitionResult &GetResult() const; 99 const OfflineRecognitionResult &GetResult() const;
98 100
  101 + /** Get the ContextGraph of this stream */
  102 + const ContextGraphPtr &GetContextGraph() const;
  103 +
99 private: 104 private:
100 class Impl; 105 class Impl;
101 std::unique_ptr<Impl> impl_; 106 std::unique_ptr<Impl> impl_;
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include <vector> 8 #include <vector>
9 9
10 #include "onnxruntime_cxx_api.h" // NOLINT 10 #include "onnxruntime_cxx_api.h" // NOLINT
  11 +#include "sherpa-onnx/csrc/offline-stream.h"
11 12
12 namespace sherpa_onnx { 13 namespace sherpa_onnx {
13 14
@@ -33,7 +34,8 @@ class OfflineTransducerDecoder { @@ -33,7 +34,8 @@ class OfflineTransducerDecoder {
33 * @return Return a vector of size `N` containing the decoded results. 34 * @return Return a vector of size `N` containing the decoded results.
34 */ 35 */
35 virtual std::vector<OfflineTransducerDecoderResult> Decode( 36 virtual std::vector<OfflineTransducerDecoderResult> Decode(
36 - Ort::Value encoder_out, Ort::Value encoder_out_length) = 0; 37 + Ort::Value encoder_out, Ort::Value encoder_out_length,
  38 + OfflineStream **ss = nullptr, int32_t n = 0) = 0;
37 }; 39 };
38 40
39 } // namespace sherpa_onnx 41 } // namespace sherpa_onnx
@@ -16,7 +16,9 @@ namespace sherpa_onnx { @@ -16,7 +16,9 @@ namespace sherpa_onnx {
16 16
17 std::vector<OfflineTransducerDecoderResult> 17 std::vector<OfflineTransducerDecoderResult>
18 OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, 18 OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
19 - Ort::Value encoder_out_length) { 19 + Ort::Value encoder_out_length,
  20 + OfflineStream **ss /*= nullptr*/,
  21 + int32_t n /*= 0*/) {
20 PackedSequence packed_encoder_out = PackPaddedSequence( 22 PackedSequence packed_encoder_out = PackPaddedSequence(
21 model_->Allocator(), &encoder_out, &encoder_out_length); 23 model_->Allocator(), &encoder_out, &encoder_out_length);
22 24
@@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { @@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
18 : model_(model) {} 18 : model_(model) {}
19 19
20 std::vector<OfflineTransducerDecoderResult> Decode( 20 std::vector<OfflineTransducerDecoderResult> Decode(
21 - Ort::Value encoder_out, Ort::Value encoder_out_length) override; 21 + Ort::Value encoder_out, Ort::Value encoder_out_length,
  22 + OfflineStream **ss = nullptr, int32_t n = 0) override;
22 23
23 private: 24 private:
24 OfflineTransducerModel *model_; // Not owned 25 OfflineTransducerModel *model_; // Not owned
@@ -8,7 +8,9 @@ @@ -8,7 +8,9 @@
8 #include <utility> 8 #include <utility>
9 #include <vector> 9 #include <vector>
10 10
  11 +#include "sherpa-onnx/csrc/context-graph.h"
11 #include "sherpa-onnx/csrc/hypothesis.h" 12 #include "sherpa-onnx/csrc/hypothesis.h"
  13 +#include "sherpa-onnx/csrc/log.h"
12 #include "sherpa-onnx/csrc/onnx-utils.h" 14 #include "sherpa-onnx/csrc/onnx-utils.h"
13 #include "sherpa-onnx/csrc/packed-sequence.h" 15 #include "sherpa-onnx/csrc/packed-sequence.h"
14 #include "sherpa-onnx/csrc/slice.h" 16 #include "sherpa-onnx/csrc/slice.h"
@@ -17,23 +19,39 @@ namespace sherpa_onnx { @@ -17,23 +19,39 @@ namespace sherpa_onnx {
17 19
18 std::vector<OfflineTransducerDecoderResult> 20 std::vector<OfflineTransducerDecoderResult>
19 OfflineTransducerModifiedBeamSearchDecoder::Decode( 21 OfflineTransducerModifiedBeamSearchDecoder::Decode(
20 - Ort::Value encoder_out, Ort::Value encoder_out_length) { 22 + Ort::Value encoder_out, Ort::Value encoder_out_length,
  23 + OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) {
21 PackedSequence packed_encoder_out = PackPaddedSequence( 24 PackedSequence packed_encoder_out = PackPaddedSequence(
22 model_->Allocator(), &encoder_out, &encoder_out_length); 25 model_->Allocator(), &encoder_out, &encoder_out_length);
23 26
24 int32_t batch_size = 27 int32_t batch_size =
25 static_cast<int32_t>(packed_encoder_out.sorted_indexes.size()); 28 static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
26 29
  30 + if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
  31 +
27 int32_t vocab_size = model_->VocabSize(); 32 int32_t vocab_size = model_->VocabSize();
28 int32_t context_size = model_->ContextSize(); 33 int32_t context_size = model_->ContextSize();
29 34
30 std::vector<int64_t> blanks(context_size, 0); 35 std::vector<int64_t> blanks(context_size, 0);
31 - Hypotheses blank_hyp({{blanks, 0}});  
32 36
33 std::deque<Hypotheses> finalized; 37 std::deque<Hypotheses> finalized;
34 - std::vector<Hypotheses> cur(batch_size, blank_hyp); 38 + std::vector<Hypotheses> cur;
35 std::vector<Hypothesis> prev; 39 std::vector<Hypothesis> prev;
36 40
  41 + std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
  42 +
  43 + for (int32_t i = 0; i < batch_size; ++i) {
  44 + const ContextState *context_state;
  45 + if (ss != nullptr) {
  46 + context_graphs[i] =
  47 + ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
  48 + if (context_graphs[i] != nullptr)
  49 + context_state = context_graphs[i]->Root();
  50 + }
  51 + Hypotheses blank_hyp({{blanks, 0, context_state}});
  52 + cur.emplace_back(std::move(blank_hyp));
  53 + }
  54 +
37 int32_t start = 0; 55 int32_t start = 0;
38 int32_t t = 0; 56 int32_t t = 0;
39 for (auto n : packed_encoder_out.batch_sizes) { 57 for (auto n : packed_encoder_out.batch_sizes) {
@@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( @@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
106 int32_t new_token = k % vocab_size; 124 int32_t new_token = k % vocab_size;
107 Hypothesis new_hyp = prev[hyp_index]; 125 Hypothesis new_hyp = prev[hyp_index];
108 126
  127 + float context_score = 0;
  128 + auto context_state = new_hyp.context_state;
109 if (new_token != 0) { 129 if (new_token != 0) {
110 // blank id is fixed to 0 130 // blank id is fixed to 0
111 new_hyp.ys.push_back(new_token); 131 new_hyp.ys.push_back(new_token);
112 new_hyp.timestamps.push_back(t); 132 new_hyp.timestamps.push_back(t);
  133 + if (context_graphs[i] != nullptr) {
  134 + auto context_res =
  135 + context_graphs[i]->ForwardOneStep(context_state, new_token);
  136 + context_score = context_res.first;
  137 + new_hyp.context_state = context_res.second;
  138 + }
113 } 139 }
114 140
115 - new_hyp.log_prob = p_logprob[k]; 141 + new_hyp.log_prob = p_logprob[k] + context_score;
116 hyps.Add(std::move(new_hyp)); 142 hyps.Add(std::move(new_hyp));
117 } // for (auto k : topk) 143 } // for (auto k : topk)
118 p_logprob += (end - start) * vocab_size; 144 p_logprob += (end - start) * vocab_size;
@@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( @@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
126 cur.push_back(std::move(h)); 152 cur.push_back(std::move(h));
127 } 153 }
128 154
  155 + // Finalize context biasing matching. 155 + // Finalize context biasing matching.
  156 + for (int32_t i = 0; i < cur.size(); ++i) {
  157 + for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
  158 + if (context_graphs[i] != nullptr) {
  159 + auto context_res =
  160 + context_graphs[i]->Finalize(iter->second.context_state);
  161 + iter->second.log_prob += context_res.first;
  162 + iter->second.context_state = context_res.second;
  163 + }
  164 + }
  165 + }
  166 +
129 if (lm_) { 167 if (lm_) {
130 // use LM for rescoring 168 // use LM for rescoring
131 lm_->ComputeLMScore(lm_scale_, context_size, &cur); 169 lm_->ComputeLMScore(lm_scale_, context_size, &cur);
@@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder @@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder
26 lm_scale_(lm_scale) {} 26 lm_scale_(lm_scale) {}
27 27
28 std::vector<OfflineTransducerDecoderResult> Decode( 28 std::vector<OfflineTransducerDecoderResult> Decode(
29 - Ort::Value encoder_out, Ort::Value encoder_out_length) override; 29 + Ort::Value encoder_out, Ort::Value encoder_out_length,
  30 + OfflineStream **ss = nullptr, int32_t n = 0) override;
30 31
31 private: 32 private:
32 OfflineTransducerModel *model_; // Not owned 33 OfflineTransducerModel *model_; // Not owned
@@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) { @@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
16 py::class_<PyClass>(*m, "OfflineRecognizerConfig") 16 py::class_<PyClass>(*m, "OfflineRecognizerConfig")
17 .def(py::init<const OfflineFeatureExtractorConfig &, 17 .def(py::init<const OfflineFeatureExtractorConfig &,
18 const OfflineModelConfig &, const OfflineLMConfig &, 18 const OfflineModelConfig &, const OfflineLMConfig &,
19 - const std::string &, int32_t>(), 19 + const std::string &, int32_t, float>(),
20 py::arg("feat_config"), py::arg("model_config"), 20 py::arg("feat_config"), py::arg("model_config"),
21 py::arg("lm_config") = OfflineLMConfig(), 21 py::arg("lm_config") = OfflineLMConfig(),
22 py::arg("decoding_method") = "greedy_search", 22 py::arg("decoding_method") = "greedy_search",
23 - py::arg("max_active_paths") = 4) 23 + py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5)
24 .def_readwrite("feat_config", &PyClass::feat_config) 24 .def_readwrite("feat_config", &PyClass::feat_config)
25 .def_readwrite("model_config", &PyClass::model_config) 25 .def_readwrite("model_config", &PyClass::model_config)
26 .def_readwrite("lm_config", &PyClass::lm_config) 26 .def_readwrite("lm_config", &PyClass::lm_config)
27 .def_readwrite("decoding_method", &PyClass::decoding_method) 27 .def_readwrite("decoding_method", &PyClass::decoding_method)
28 .def_readwrite("max_active_paths", &PyClass::max_active_paths) 28 .def_readwrite("max_active_paths", &PyClass::max_active_paths)
  29 + .def_readwrite("context_score", &PyClass::context_score)
29 .def("__str__", &PyClass::ToString); 30 .def("__str__", &PyClass::ToString);
30 } 31 }
31 32
@@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) { @@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) {
35 using PyClass = OfflineRecognizer; 36 using PyClass = OfflineRecognizer;
36 py::class_<PyClass>(*m, "OfflineRecognizer") 37 py::class_<PyClass>(*m, "OfflineRecognizer")
37 .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config")) 38 .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
38 - .def("create_stream", &PyClass::CreateStream) 39 + .def("create_stream",
  40 + [](const PyClass &self) { return self.CreateStream(); })
  41 + .def(
  42 + "create_stream",
  43 + [](PyClass &self,
  44 + const std::vector<std::vector<int32_t>> &contexts_list) {
  45 + return self.CreateStream(contexts_list);
  46 + },
  47 + py::arg("contexts_list"))
39 .def("decode_stream", &PyClass::DecodeStream) 48 .def("decode_stream", &PyClass::DecodeStream)
40 .def("decode_streams", 49 .def("decode_streams",
41 - [](PyClass &self, std::vector<OfflineStream *> ss) { 50 + [](const PyClass &self, std::vector<OfflineStream *> ss) {
42 self.DecodeStreams(ss.data(), ss.size()); 51 self.DecodeStreams(ss.data(), ss.size());
43 }); 52 });
44 } 53 }
  1 +from typing import Dict, List, Optional
  2 +
1 from _sherpa_onnx import Display 3 from _sherpa_onnx import Display
2 4
3 from .online_recognizer import OnlineRecognizer 5 from .online_recognizer import OnlineRecognizer
4 from .online_recognizer import OnlineStream 6 from .online_recognizer import OnlineStream
5 from .offline_recognizer import OfflineRecognizer 7 from .offline_recognizer import OfflineRecognizer
  8 +
  9 +from .utils import encode_contexts
  10 +
  11 +
  12 +
1 # Copyright (c) 2023 by manyeyes 1 # Copyright (c) 2023 by manyeyes
2 from pathlib import Path 2 from pathlib import Path
3 -from typing import List 3 +from typing import List, Optional
4 4
5 from _sherpa_onnx import ( 5 from _sherpa_onnx import (
6 OfflineFeatureExtractorConfig, 6 OfflineFeatureExtractorConfig,
@@ -39,6 +39,7 @@ class OfflineRecognizer(object): @@ -39,6 +39,7 @@ class OfflineRecognizer(object):
39 sample_rate: int = 16000, 39 sample_rate: int = 16000,
40 feature_dim: int = 80, 40 feature_dim: int = 80,
41 decoding_method: str = "greedy_search", 41 decoding_method: str = "greedy_search",
  42 + context_score: float = 1.5,
42 debug: bool = False, 43 debug: bool = False,
43 provider: str = "cpu", 44 provider: str = "cpu",
44 ): 45 ):
@@ -96,6 +97,7 @@ class OfflineRecognizer(object): @@ -96,6 +97,7 @@ class OfflineRecognizer(object):
96 feat_config=feat_config, 97 feat_config=feat_config,
97 model_config=model_config, 98 model_config=model_config,
98 decoding_method=decoding_method, 99 decoding_method=decoding_method,
  100 + context_score=context_score,
99 ) 101 )
100 self.recognizer = _Recognizer(recognizer_config) 102 self.recognizer = _Recognizer(recognizer_config)
101 return self 103 return self
@@ -216,8 +218,11 @@ class OfflineRecognizer(object): @@ -216,8 +218,11 @@ class OfflineRecognizer(object):
216 self.recognizer = _Recognizer(recognizer_config) 218 self.recognizer = _Recognizer(recognizer_config)
217 return self 219 return self
218 220
219 - def create_stream(self):  
220 - return self.recognizer.create_stream() 221 + def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
  222 + if contexts_list is None:
  223 + return self.recognizer.create_stream()
  224 + else:
  225 + return self.recognizer.create_stream(contexts_list)
221 226
222 def decode_stream(self, s: OfflineStream): 227 def decode_stream(self, s: OfflineStream):
223 self.recognizer.decode_stream(s) 228 self.recognizer.decode_stream(s)
import re
from typing import Dict, List, Optional


def encode_contexts(
    modeling_unit: str,
    contexts: List[str],
    sp: Optional["SentencePieceProcessor"] = None,
    tokens_table: Optional[Dict[str, int]] = None,
) -> List[List[int]]:
    """
    Encode the given contexts (a list of strings) to a list of lists of
    token ids.

    Args:
      modeling_unit:
        The valid values are bpe, char, bpe+char.
        Note: char here means characters in CJK languages, not
        English-like languages.
      contexts:
        The given contexts list (a list of strings).
      sp:
        An instance of SentencePieceProcessor.
        Required when modeling_unit contains "bpe".
      tokens_table:
        The tokens table mapping tokens to their corresponding ids.
        Required when modeling_unit contains "char".
    Returns:
      Return the contexts_list, a list of lists of token ids.
      Tokens not found in tokens_table are mapped to the id of "<unk>".
    """
    contexts_list = []
    if "bpe" in modeling_unit:
        assert sp is not None
    if "char" in modeling_unit:
        assert tokens_table is not None
        assert len(tokens_table) > 0, len(tokens_table)

    if "char" == modeling_unit:
        for context in contexts:
            assert ' ' not in context
            ids = [
                tokens_table[txt] if txt in tokens_table else tokens_table["<unk>"]
                for txt in context
            ]
            contexts_list.append(ids)
    elif "bpe" == modeling_unit:
        contexts_list = sp.encode(contexts, out_type=int)
    else:
        assert modeling_unit == "bpe+char", modeling_unit

        # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        pattern = re.compile(r"([\u4e00-\u9fff])")
        for context in contexts:
            # Example:
            # txt = "你好 ITS'S OKAY 的"
            # chars = ["你", "好", " ITS'S OKAY ", "的"]
            chars = pattern.split(context.upper())
            mix_chars = [w for w in chars if len(w.strip()) > 0]
            ids = []
            for ch_or_w in mix_chars:
                # ch_or_w is a single CJK character (i.e., "你"), do nothing.
                if pattern.fullmatch(ch_or_w) is not None:
                    ids.append(
                        tokens_table[ch_or_w]
                        if ch_or_w in tokens_table
                        else tokens_table["<unk>"]
                    )
                # ch_or_w contains non-CJK characters (i.e., " IT'S OKAY "),
                # encode ch_or_w using bpe_model.
                else:
                    for p in sp.encode_as_pieces(ch_or_w):
                        ids.append(
                            tokens_table[p]
                            if p in tokens_table
                            else tokens_table["<unk>"]
                        )
            contexts_list.append(ids)
    return contexts_list