Wei Kang
Committed by GitHub

Support contextual-biasing for streaming model (#184)

* Support contextual-biasing for streaming model

* The whole pipeline runs normally

* Fix comments
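
For reference, a minimal usage sketch of how the new pieces fit together from Python. The model/file names below are placeholders, and the recognizer constructor arguments other than context_score are assumed from the existing Python API; encode_contexts and create_stream(contexts_list=...) are the parts added by this change:

    import sherpa_onnx
    import sentencepiece as spm

    # Placeholders: point these at a real streaming transducer model.
    recognizer = sherpa_onnx.OnlineRecognizer(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        decoding_method="modified_beam_search",
        max_active_paths=4,
        context_score=1.5,  # new: per-token bonus for biasing phrases
    )

    # Tokenize the biasing phrases into token IDs.
    sp = spm.SentencePieceProcessor()
    sp.load("bpe.model")
    tokens_table = {}  # token symbol -> integer ID, read from tokens.txt
    with open("tokens.txt", encoding="utf-8") as f:
        for line in f:
            sym, idx = line.split()
            tokens_table[sym] = int(idx)

    contexts_list = sherpa_onnx.encode_contexts(
        modeling_unit="bpe",
        contexts=["HELLO WORLD", "I LOVE YOU"],
        sp=sp,
        tokens_table=tokens_table,
    )

    # New in this change: attach the biasing phrases to a stream.
    stream = recognizer.create_stream(contexts_list=contexts_list)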
@@ -20,9 +20,10 @@ import argparse
 import time
 import wave
 from pathlib import Path
-from typing import Tuple
+from typing import List, Tuple

 import numpy as np
+import sentencepiece as spm
 import sherpa_onnx


@@ -70,6 +71,59 @@ def get_args():
     )

     parser.add_argument(
+        "--max-active-paths",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is modified_beam_search.
+        It specifies the number of active paths to keep during decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="",
+        help="""
+        Path to bpe.model. It is used to tokenize the context biasing phrases.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--modeling-unit",
+        type=str,
+        default="char",
+        help="""
+        The type of modeling unit. It is used to tokenize the context biasing phrases.
+        Valid values are bpe, bpe+char, char.
+        Note: char here means characters in CJK languages.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--contexts",
+        type=str,
+        default="",
+        help="""
+        The context list. It is a string containing words/phrases separated
+        by /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY'.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=1.5,
+        help="""
+        The bonus score of each token for biasing words/phrases. Used only if
+        --contexts is given.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
         "sound_files",
         type=str,
         nargs="+",
@@ -116,6 +170,27 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
     return samples_float32, f.getframerate()


+def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
+    sp = None
+    if "bpe" in args.modeling_unit:
+        assert_file_exists(args.bpe_model)
+        sp = spm.SentencePieceProcessor()
+        sp.load(args.bpe_model)
+    tokens = {}
+    with open(args.tokens, "r", encoding="utf-8") as f:
+        for line in f:
+            toks = line.strip().split()
+            assert len(toks) == 2, len(toks)
+            assert toks[0] not in tokens, f"Duplicate token: {toks}"
+            tokens[toks[0]] = int(toks[1])
+    return sherpa_onnx.encode_contexts(
+        modeling_unit=args.modeling_unit,
+        contexts=contexts,
+        sp=sp,
+        tokens_table=tokens,
+    )
+
+
 def main():
     args = get_args()
     assert_file_exists(args.encoder)
@@ -132,11 +207,20 @@ def main():
         sample_rate=16000,
         feature_dim=80,
         decoding_method=args.decoding_method,
+        max_active_paths=args.max_active_paths,
+        context_score=args.context_score,
     )

     print("Started!")
     start_time = time.time()

+    contexts_list = []
+    contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
+    if contexts:
+        print(f"Contexts list: {contexts}")
+        contexts_list = encode_contexts(args, contexts)
+
+
     streams = []
     total_duration = 0
     for wave_filename in args.sound_files:
@@ -145,7 +229,11 @@ def main():
         duration = len(samples) / sample_rate
         total_duration += duration

-        s = recognizer.create_stream()
+        if contexts_list:
+            s = recognizer.create_stream(contexts_list=contexts_list)
+        else:
+            s = recognizer.create_stream()
+
         s.accept_waveform(sample_rate, samples)

         tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
@@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
                "True to enable endpoint detection. False to disable it.");
   po->Register("max-active-paths", &max_active_paths,
                "beam size used in modified beam search.");
+  po->Register("context-score", &context_score,
+               "The bonus score for each token in a context word/phrase. "
+               "Used only when decoding_method is modified_beam_search.");
   po->Register("decoding-method", &decoding_method,
                "decoding method,"
                "now support greedy_search and modified_beam_search.");
@@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
   os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
   os << "max_active_paths=" << max_active_paths << ", ";
+  os << "context_score=" << context_score << ", ";
   os << "decoding_method=\"" << decoding_method << "\")";

   return os.str();
@@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
   }
 #endif

+  void InitOnlineStream(OnlineStream *stream) const {
+    auto r = decoder_->GetEmptyResult();
+
+    if (config_.decoding_method == "modified_beam_search" &&
+        nullptr != stream->GetContextGraph()) {
+      // r.hyps has only one element.
+      for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
+        it->second.context_state = stream->GetContextGraph()->Root();
+      }
+    }
+
+    stream->SetResult(r);
+    stream->SetStates(model_->GetEncoderInitStates());
+  }
+
   std::unique_ptr<OnlineStream> CreateStream() const {
     auto stream = std::make_unique<OnlineStream>(config_.feat_config);
-    stream->SetResult(decoder_->GetEmptyResult());
-    stream->SetStates(model_->GetEncoderInitStates());
+    InitOnlineStream(stream.get());
+    return stream;
+  }
+
+  std::unique_ptr<OnlineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &contexts) const {
+    // We create the context_graph at this level because we might later add a
+    // default context_graph that belongs to the whole model rather than to
+    // each stream.
+    auto context_graph =
+        std::make_shared<ContextGraph>(contexts, config_.context_score);
+    auto stream =
+        std::make_unique<OnlineStream>(config_.feat_config, context_graph);
+    InitOnlineStream(stream.get());
     return stream;
   }

@@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
     std::vector<float> features_vec(n * chunk_size * feature_dim);
     std::vector<std::vector<Ort::Value>> states_vec(n);
     std::vector<int64_t> all_processed_frames(n);
+    bool has_context_graph = false;

     for (int32_t i = 0; i != n; ++i) {
+      if (!has_context_graph && ss[i]->GetContextGraph())
+        has_context_graph = true;
+
       const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
       std::vector<float> features =
           ss[i]->GetFrames(num_processed_frames, chunk_size);
@@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
     auto pair = model_->RunEncoder(std::move(x), std::move(states),
                                    std::move(processed_frames));

-    decoder_->Decode(std::move(pair.first), &results);
+    if (has_context_graph) {
+      decoder_->Decode(std::move(pair.first), ss, &results);
+    } else {
+      decoder_->Decode(std::move(pair.first), &results);
+    }

     std::vector<std::vector<Ort::Value>> next_states =
         model_->UnStackStates(pair.second);
@@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
   return impl_->CreateStream();
 }

+std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
+    const std::vector<std::vector<int32_t>> &context_list) const {
+  return impl_->CreateStream(context_list);
+}
+
 bool OnlineRecognizer::IsReady(OnlineStream *s) const {
   return impl_->IsReady(s);
 }
@@ -75,7 +75,10 @@ struct OnlineRecognizerConfig {
   std::string decoding_method = "greedy_search";
   // now support modified_beam_search and greedy_search

-  int32_t max_active_paths = 4;  // used only for modified_beam_search
+  /// used only for modified_beam_search
+  int32_t max_active_paths = 4;
+  /// used only for modified_beam_search
+  float context_score = 1.5;

   OnlineRecognizerConfig() = default;

@@ -85,13 +88,14 @@ struct OnlineRecognizerConfig {
                          const EndpointConfig &endpoint_config,
                          bool enable_endpoint,
                          const std::string &decoding_method,
-                         int32_t max_active_paths)
+                         int32_t max_active_paths, float context_score)
       : feat_config(feat_config),
         model_config(model_config),
         endpoint_config(endpoint_config),
         enable_endpoint(enable_endpoint),
         decoding_method(decoding_method),
-        max_active_paths(max_active_paths) {}
+        max_active_paths(max_active_paths),
+        context_score(context_score) {}

   void Register(ParseOptions *po);
   bool Validate() const;
@@ -112,6 +116,10 @@ class OnlineRecognizer {
   /// Create a stream for decoding.
   std::unique_ptr<OnlineStream> CreateStream() const;

+  /// Create a stream with context phrases for decoding.
+  std::unique_ptr<OnlineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &context_list) const;
+
   /**
    * Return true if the given stream has enough frames for decoding.
    * Return false otherwise
@@ -13,8 +13,9 @@ namespace sherpa_onnx {

 class OnlineStream::Impl {
  public:
-  explicit Impl(const FeatureExtractorConfig &config)
-      : feat_extractor_(config) {}
+  explicit Impl(const FeatureExtractorConfig &config,
+                ContextGraphPtr context_graph)
+      : feat_extractor_(config), context_graph_(context_graph) {}

   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
@@ -54,16 +55,21 @@ class OnlineStream::Impl {

   std::vector<Ort::Value> &GetStates() { return states_; }

+  const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
+
  private:
   FeatureExtractor feat_extractor_;
+  /// For contextual biasing
+  ContextGraphPtr context_graph_;
   int32_t num_processed_frames_ = 0;  // before subsampling
   int32_t start_frame_index_ = 0;     // never reset
   OnlineTransducerDecoderResult result_;
   std::vector<Ort::Value> states_;
 };

-OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
-    : impl_(std::make_unique<Impl>(config)) {}
+OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
+                           ContextGraphPtr context_graph /*= nullptr */)
+    : impl_(std::make_unique<Impl>(config, context_graph)) {}

 OnlineStream::~OnlineStream() = default;

@@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
   return impl_->GetStates();
 }

+const ContextGraphPtr &OnlineStream::GetContextGraph() const {
+  return impl_->GetContextGraph();
+}
+
 }  // namespace sherpa_onnx
@@ -9,6 +9,7 @@
 #include <vector>

 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/context-graph.h"
 #include "sherpa-onnx/csrc/features.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"

@@ -16,7 +17,8 @@ namespace sherpa_onnx {

 class OnlineStream {
  public:
-  explicit OnlineStream(const FeatureExtractorConfig &config = {});
+  explicit OnlineStream(const FeatureExtractorConfig &config = {},
+                        ContextGraphPtr context_graph = nullptr);
   ~OnlineStream();

   /**
@@ -71,6 +73,13 @@ class OnlineStream {
   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();

+  /**
+   * Get the context graph corresponding to this stream.
+   *
+   * @return Return the context graph for this stream.
+   */
+  const ContextGraphPtr &GetContextGraph() const;
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;
@@ -9,6 +9,7 @@

 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/hypothesis.h"
+#include "sherpa-onnx/csrc/macros.h"

 namespace sherpa_onnx {

@@ -45,6 +46,7 @@ struct OnlineTransducerDecoderResult {
                                  OnlineTransducerDecoderResult &&other);
 };

+class OnlineStream;
 class OnlineTransducerDecoder {
  public:
   virtual ~OnlineTransducerDecoder() = default;
@@ -76,6 +78,26 @@ class OnlineTransducerDecoder {
   virtual void Decode(Ort::Value encoder_out,
                       std::vector<OnlineTransducerDecoderResult> *result) = 0;

+  /** Run transducer beam search given the output from the encoder model.
+   *
+   * Note: Currently this interface is for the contextual-biasing feature,
+   * which needs a ContextGraph owned by the OnlineStream.
+   *
+   * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
+   * @param ss A list of OnlineStreams.
+   * @param result It is modified in-place.
+   *
+   * @note There is no need to pass encoder_out_length here since, in the
+   *       online decoding case, each utterance has the same number of frames
+   *       and there are no paddings.
+   */
+  virtual void Decode(Ort::Value encoder_out, OnlineStream **ss,
+                      std::vector<OnlineTransducerDecoderResult> *result) {
+    SHERPA_ONNX_LOGE(
+        "This interface is for OnlineTransducerModifiedBeamSearchDecoder.");
+    exit(-1);
+  }
+
   // used for endpointing. We need to keep decoder_out after reset
   virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
 };
@@ -9,6 +9,7 @@
 #include <utility>
 #include <vector>

+#include "sherpa-onnx/csrc/log.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"

 namespace sherpa_onnx {
@@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
 void OnlineTransducerModifiedBeamSearchDecoder::Decode(
     Ort::Value encoder_out,
     std::vector<OnlineTransducerDecoderResult> *result) {
+  Decode(std::move(encoder_out), nullptr, result);
+}
+
+void OnlineTransducerModifiedBeamSearchDecoder::Decode(
+    Ort::Value encoder_out, OnlineStream **ss,
+    std::vector<OnlineTransducerDecoderResult> *result) {
   std::vector<int64_t> encoder_out_shape =
       encoder_out.GetTensorTypeAndShapeInfo().GetShape();

@@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
   }

   int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
+
   int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
   int32_t vocab_size = model_->VocabSize();

@@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(

         Hypothesis new_hyp = prev[hyp_index];
         const float prev_lm_log_prob = new_hyp.lm_log_prob;
+        float context_score = 0;
+        auto context_state = new_hyp.context_state;
+
         if (new_token != 0) {
           new_hyp.ys.push_back(new_token);
           new_hyp.timestamps.push_back(t + frame_offset);
           new_hyp.num_trailing_blanks = 0;
+          if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
+            auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
+                context_state, new_token);
+            context_score = context_res.first;
+            new_hyp.context_state = context_res.second;
+          }
           if (lm_) {
             lm_->ComputeLMScore(lm_scale_, &new_hyp);
           }
         } else {
           ++new_hyp.num_trailing_blanks;
         }
-        new_hyp.log_prob =
-            p_logprob[k] - prev_lm_log_prob;  // log_prob only includes the
+        new_hyp.log_prob = p_logprob[k] + context_score -
+                           prev_lm_log_prob;  // log_prob only includes the
                                               // score of the transducer
         hyps.Add(std::move(new_hyp));
       }  // for (auto k : topk)
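
The context_score bonus added to new_hyp.log_prob above comes from walking a per-stream context graph one token at a time. Below is a simplified Python illustration of the idea only, not the actual ContextGraph implementation; in particular it omits the fail/rollback handling of partially matched phrases:

    from typing import Dict, List, Tuple


    class SimpleContextGraph:
        """Trie over biasing-phrase token sequences; matching tokens earn a bonus."""

        def __init__(self, contexts: List[List[int]], context_score: float):
            self.context_score = context_score
            # Each node is a dict: token -> child node index; node 0 is the root.
            self.children: List[Dict[int, int]] = [{}]
            for phrase in contexts:
                node = 0
                for token in phrase:
                    nxt = self.children[node].get(token)
                    if nxt is None:
                        self.children.append({})
                        nxt = len(self.children) - 1
                        self.children[node][token] = nxt
                    node = nxt

        def root(self) -> int:
            return 0

        def forward_one_step(self, state: int, token: int) -> Tuple[float, int]:
            """Return (score_delta, next_state) after decoding one token."""
            nxt = self.children[state].get(token)
            if nxt is not None:
                return self.context_score, nxt  # token extends a biasing phrase
            return 0.0, 0  # mismatch: no bonus, restart from the root


    # Example: bias the token sequence [5, 7, 9]; decoding 5 then 7 earns 2 * 1.5.
    graph = SimpleContextGraph(contexts=[[5, 7, 9]], context_score=1.5)
    state = graph.root()
    total = 0.0
    for tok in [5, 7, 2]:
        delta, state = graph.forward_one_step(state, tok)
        total += delta
    print(total)  # 3.0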
@@ -9,6 +9,7 @@
 #include <vector>

 #include "sherpa-onnx/csrc/online-lm.h"
+#include "sherpa-onnx/csrc/online-stream.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model.h"

@@ -33,6 +34,9 @@ class OnlineTransducerModifiedBeamSearchDecoder
   void Decode(Ort::Value encoder_out,
               std::vector<OnlineTransducerDecoderResult> *result) override;

+  void Decode(Ort::Value encoder_out, OnlineStream **ss,
+              std::vector<OnlineTransducerDecoderResult> *result) override;
+
   void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;

  private:
@@ -22,18 +22,19 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
   py::class_<PyClass>(*m, "OnlineRecognizerConfig")
       .def(py::init<const FeatureExtractorConfig &,
                     const OnlineTransducerModelConfig &, const OnlineLMConfig &,
-                    const EndpointConfig &, bool, const std::string &,
-                    int32_t>(),
+                    const EndpointConfig &, bool, const std::string &, int32_t,
+                    float>(),
           py::arg("feat_config"), py::arg("model_config"),
           py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
           py::arg("enable_endpoint"), py::arg("decoding_method"),
-          py::arg("max_active_paths"))
+          py::arg("max_active_paths"), py::arg("context_score"))
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
       .def_readwrite("endpoint_config", &PyClass::endpoint_config)
       .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
       .def_readwrite("decoding_method", &PyClass::decoding_method)
       .def_readwrite("max_active_paths", &PyClass::max_active_paths)
+      .def_readwrite("context_score", &PyClass::context_score)
       .def("__str__", &PyClass::ToString);
 }

@@ -44,7 +45,15 @@ void PybindOnlineRecognizer(py::module *m) {
   using PyClass = OnlineRecognizer;
   py::class_<PyClass>(*m, "OnlineRecognizer")
       .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
-      .def("create_stream", &PyClass::CreateStream)
+      .def("create_stream",
+           [](const PyClass &self) { return self.CreateStream(); })
+      .def(
+          "create_stream",
+          [](PyClass &self,
+             const std::vector<std::vector<int32_t>> &contexts_list) {
+            return self.CreateStream(contexts_list);
+          },
+          py::arg("contexts_list"))
       .def("is_ready", &PyClass::IsReady)
       .def("decode_stream", &PyClass::DecodeStream)
       .def("decode_streams",
@@ -1,6 +1,6 @@
 # Copyright (c) 2023 Xiaomi Corporation
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 from _sherpa_onnx import (
     EndpointConfig,
@@ -39,6 +39,7 @@ class OnlineRecognizer(object):
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         max_active_paths: int = 4,
+        context_score: float = 1.5,
         provider: str = "cpu",
     ):
         """
@@ -124,13 +125,17 @@ class OnlineRecognizer(object):
             enable_endpoint=enable_endpoint_detection,
             decoding_method=decoding_method,
             max_active_paths=max_active_paths,
+            context_score=context_score,
         )

         self.recognizer = _Recognizer(recognizer_config)
         self.config = recognizer_config

-    def create_stream(self):
-        return self.recognizer.create_stream()
+    def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
+        if contexts_list is None:
+            return self.recognizer.create_stream()
+        else:
+            return self.recognizer.create_stream(contexts_list)

     def decode_stream(self, s: OnlineStream):
         self.recognizer.decode_stream(s)