Committed by GitHub

Support contextual-biasing for streaming model (#184)

* Support contextual-biasing for streaming model
* The whole pipeline runs normally
* Fix comments
Showing 10 changed files with 238 additions and 22 deletions
@@ -20,9 +20,10 @@ import argparse
 import time
 import wave
 from pathlib import Path
-from typing import Tuple
+from typing import List, Tuple
 
 import numpy as np
+import sentencepiece as spm
 import sherpa_onnx
 
 
@@ -70,6 +71,59 @@ def get_args():
     )
 
     parser.add_argument(
+        "--max-active-paths",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is modified_beam_search.
+        It specifies number of active paths to keep during decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="",
+        help="""
+        Path to bpe.model, it will be used to tokenize contexts biasing phrases.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--modeling-unit",
+        type=str,
+        default="char",
+        help="""
+        The type of modeling unit, it will be used to tokenize contexts biasing phrases.
+        Valid values are bpe, bpe+char, char.
+        Note: the char here means characters in CJK languages.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--contexts",
+        type=str,
+        default="",
+        help="""
+        The context list, it is a string containing some words/phrases separated
+        with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY'.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=1.5,
+        help="""
+        The context score of each token for biasing word/phrase. Used only if
+        --contexts is given.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
         "sound_files",
         type=str,
         nargs="+",
@@ -116,6 +170,27 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
     return samples_float32, f.getframerate()
 
 
+def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
+    sp = None
+    if "bpe" in args.modeling_unit:
+        assert_file_exists(args.bpe_model)
+        sp = spm.SentencePieceProcessor()
+        sp.load(args.bpe_model)
+    tokens = {}
+    with open(args.tokens, "r", encoding="utf-8") as f:
+        for line in f:
+            toks = line.strip().split()
+            assert len(toks) == 2, len(toks)
+            assert toks[0] not in tokens, f"Duplicate token: {toks}"
+            tokens[toks[0]] = int(toks[1])
+    return sherpa_onnx.encode_contexts(
+        modeling_unit=args.modeling_unit,
+        contexts=contexts,
+        sp=sp,
+        tokens_table=tokens,
+    )
+
+
 def main():
     args = get_args()
     assert_file_exists(args.encoder)
@@ -132,11 +207,20 @@ def main():
         sample_rate=16000,
         feature_dim=80,
         decoding_method=args.decoding_method,
+        max_active_paths=args.max_active_paths,
+        context_score=args.context_score,
     )
 
     print("Started!")
     start_time = time.time()
 
+    contexts_list = []
+    contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
+    if contexts:
+        print(f"Contexts list: {contexts}")
+        contexts_list = encode_contexts(args, contexts)
+
+
     streams = []
     total_duration = 0
     for wave_filename in args.sound_files:
@@ -145,7 +229,11 @@ def main():
         duration = len(samples) / sample_rate
         total_duration += duration
 
-        s = recognizer.create_stream()
+        if contexts_list:
+            s = recognizer.create_stream(contexts_list=contexts_list)
+        else:
+            s = recognizer.create_stream()
+
         s.accept_waveform(sample_rate, samples)
 
         tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
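For reference, the following is a minimal sketch of the data shape that encode_contexts() above is expected to return: one list of integer token IDs per biasing phrase. The token table, token strings, and IDs below are purely illustrative placeholders, not values from any real tokens.txt or bpe.model.

# Illustrative only: a toy token table standing in for tokens.txt / bpe.model.
toy_tokens = {"▁HELLO": 10, "▁WORLD": 11, "你": 12, "好": 13}

def toy_encode(phrases):
    # phrases: already-tokenized biasing phrases, one list of token strings each
    return [[toy_tokens[t] for t in phrase] for phrase in phrases]

print(toy_encode([["▁HELLO", "▁WORLD"], ["你", "好"]]))
# -> [[10, 11], [12, 13]]  (the kind of contexts_list later passed to create_stream())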
@@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
                "True to enable endpoint detection. False to disable it.");
   po->Register("max-active-paths", &max_active_paths,
                "beam size used in modified beam search.");
+  po->Register("context-score", &context_score,
+               "The bonus score for each token in context word/phrase. "
+               "Used only when decoding_method is modified_beam_search");
   po->Register("decoding-method", &decoding_method,
                "decoding method,"
                "now support greedy_search and modified_beam_search.");
@@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
   os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
   os << "max_active_paths=" << max_active_paths << ", ";
+  os << "context_score=" << context_score << ", ";
   os << "decoding_method=\"" << decoding_method << "\")";
 
   return os.str();
@@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
   }
 #endif
 
+  void InitOnlineStream(OnlineStream *stream) const {
+    auto r = decoder_->GetEmptyResult();
+
+    if (config_.decoding_method == "modified_beam_search" &&
+        nullptr != stream->GetContextGraph()) {
+      // r.hyps has only one element.
+      for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
+        it->second.context_state = stream->GetContextGraph()->Root();
+      }
+    }
+
+    stream->SetResult(r);
+    stream->SetStates(model_->GetEncoderInitStates());
+  }
+
   std::unique_ptr<OnlineStream> CreateStream() const {
     auto stream = std::make_unique<OnlineStream>(config_.feat_config);
-    stream->SetResult(decoder_->GetEmptyResult());
-    stream->SetStates(model_->GetEncoderInitStates());
+    InitOnlineStream(stream.get());
+    return stream;
+  }
+
+  std::unique_ptr<OnlineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &contexts) const {
+    // We create context_graph at this level, because we might have a default
+    // context_graph (will be added later if needed) that belongs to the whole
+    // model rather than each stream.
+    auto context_graph =
+        std::make_shared<ContextGraph>(contexts, config_.context_score);
+    auto stream =
+        std::make_unique<OnlineStream>(config_.feat_config, context_graph);
+    InitOnlineStream(stream.get());
     return stream;
   }
 
@@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
     std::vector<float> features_vec(n * chunk_size * feature_dim);
     std::vector<std::vector<Ort::Value>> states_vec(n);
     std::vector<int64_t> all_processed_frames(n);
+    bool has_context_graph = false;
 
     for (int32_t i = 0; i != n; ++i) {
+      if (!has_context_graph && ss[i]->GetContextGraph())
+        has_context_graph = true;
+
       const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
       std::vector<float> features =
           ss[i]->GetFrames(num_processed_frames, chunk_size);
@@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
     auto pair = model_->RunEncoder(std::move(x), std::move(states),
                                    std::move(processed_frames));
 
-    decoder_->Decode(std::move(pair.first), &results);
+    if (has_context_graph) {
+      decoder_->Decode(std::move(pair.first), ss, &results);
+    } else {
+      decoder_->Decode(std::move(pair.first), &results);
+    }
 
     std::vector<std::vector<Ort::Value>> next_states =
         model_->UnStackStates(pair.second);
@@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
   return impl_->CreateStream();
 }
 
+std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
+    const std::vector<std::vector<int32_t>> &context_list) const {
+  return impl_->CreateStream(context_list);
+}
+
 bool OnlineRecognizer::IsReady(OnlineStream *s) const {
   return impl_->IsReady(s);
 }
@@ -75,7 +75,10 @@ struct OnlineRecognizerConfig {
   std::string decoding_method = "greedy_search";
   // now support modified_beam_search and greedy_search
 
-  int32_t max_active_paths = 4;  // used only for modified_beam_search
+  // used only for modified_beam_search
+  int32_t max_active_paths = 4;
+  /// used only for modified_beam_search
+  float context_score = 1.5;
 
   OnlineRecognizerConfig() = default;
 
@@ -85,13 +88,14 @@ struct OnlineRecognizerConfig {
                          const EndpointConfig &endpoint_config,
                          bool enable_endpoint,
                          const std::string &decoding_method,
-                         int32_t max_active_paths)
+                         int32_t max_active_paths, float context_score)
       : feat_config(feat_config),
         model_config(model_config),
        endpoint_config(endpoint_config),
        enable_endpoint(enable_endpoint),
        decoding_method(decoding_method),
-        max_active_paths(max_active_paths) {}
+        max_active_paths(max_active_paths),
+        context_score(context_score) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
@@ -112,6 +116,10 @@ class OnlineRecognizer {
   /// Create a stream for decoding.
   std::unique_ptr<OnlineStream> CreateStream() const;
 
+  // Create a stream with context phrases
+  std::unique_ptr<OnlineStream> CreateStream(
+      const std::vector<std::vector<int32_t>> &context_list) const;
+
   /**
    * Return true if the given stream has enough frames for decoding.
    * Return false otherwise
@@ -13,8 +13,9 @@ namespace sherpa_onnx {
 
 class OnlineStream::Impl {
  public:
-  explicit Impl(const FeatureExtractorConfig &config)
-      : feat_extractor_(config) {}
+  explicit Impl(const FeatureExtractorConfig &config,
+                ContextGraphPtr context_graph)
+      : feat_extractor_(config), context_graph_(context_graph) {}
 
   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
@@ -54,16 +55,21 @@ class OnlineStream::Impl {
 
   std::vector<Ort::Value> &GetStates() { return states_; }
 
+  const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
+
  private:
   FeatureExtractor feat_extractor_;
+  /// For contextual-biasing
+  ContextGraphPtr context_graph_;
   int32_t num_processed_frames_ = 0;  // before subsampling
   int32_t start_frame_index_ = 0;     // never reset
   OnlineTransducerDecoderResult result_;
   std::vector<Ort::Value> states_;
 };
 
-OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
-    : impl_(std::make_unique<Impl>(config)) {}
+OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
+                           ContextGraphPtr context_graph /*= nullptr */)
+    : impl_(std::make_unique<Impl>(config, context_graph)) {}
 
 OnlineStream::~OnlineStream() = default;
 
@@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
   return impl_->GetStates();
 }
 
+const ContextGraphPtr &OnlineStream::GetContextGraph() const {
+  return impl_->GetContextGraph();
+}
+
 }  // namespace sherpa_onnx
@@ -9,6 +9,7 @@
 #include <vector>
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/context-graph.h"
 #include "sherpa-onnx/csrc/features.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 
@@ -16,7 +17,8 @@ namespace sherpa_onnx {
 
 class OnlineStream {
  public:
-  explicit OnlineStream(const FeatureExtractorConfig &config = {});
+  explicit OnlineStream(const FeatureExtractorConfig &config = {},
+                        ContextGraphPtr context_graph = nullptr);
   ~OnlineStream();
 
   /**
@@ -71,6 +73,13 @@ class OnlineStream {
   void SetStates(std::vector<Ort::Value> states);
   std::vector<Ort::Value> &GetStates();
 
+  /**
+   * Get the context graph corresponding to this stream.
+   *
+   * @return Return the context graph for this stream.
+   */
+  const ContextGraphPtr &GetContextGraph() const;
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;
@@ -9,6 +9,7 @@
 
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/hypothesis.h"
+#include "sherpa-onnx/csrc/macros.h"
 
 namespace sherpa_onnx {
 
@@ -45,6 +46,7 @@ struct OnlineTransducerDecoderResult {
       OnlineTransducerDecoderResult &&other);
 };
 
+class OnlineStream;
 class OnlineTransducerDecoder {
  public:
   virtual ~OnlineTransducerDecoder() = default;
@@ -76,6 +78,26 @@ class OnlineTransducerDecoder {
   virtual void Decode(Ort::Value encoder_out,
                       std::vector<OnlineTransducerDecoderResult> *result) = 0;
 
+  /** Run transducer beam search given the output from the encoder model.
+   *
+   * Note: Currently this interface is for the contextual-biasing feature,
+   * which needs a ContextGraph owned by the OnlineStream.
+   *
+   * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
+   * @param ss A list of OnlineStreams.
+   * @param result It is modified in-place.
+   *
+   * @note There is no need to pass encoder_out_length here since for the
+   *       online decoding case, each utterance has the same number of frames
+   *       and there are no paddings.
+   */
+  virtual void Decode(Ort::Value encoder_out, OnlineStream **ss,
+                      std::vector<OnlineTransducerDecoderResult> *result) {
+    SHERPA_ONNX_LOGE(
+        "This interface is for OnlineTransducerModifiedBeamSearchDecoder.");
+    exit(-1);
+  }
+
   // used for endpointing. We need to keep decoder_out after reset
   virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
 };
@@ -9,6 +9,7 @@
 #include <utility>
 #include <vector>
 
+#include "sherpa-onnx/csrc/log.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 
 namespace sherpa_onnx {
@@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
 void OnlineTransducerModifiedBeamSearchDecoder::Decode(
     Ort::Value encoder_out,
     std::vector<OnlineTransducerDecoderResult> *result) {
+  Decode(std::move(encoder_out), nullptr, result);
+}
+
+void OnlineTransducerModifiedBeamSearchDecoder::Decode(
+    Ort::Value encoder_out, OnlineStream **ss,
+    std::vector<OnlineTransducerDecoderResult> *result) {
   std::vector<int64_t> encoder_out_shape =
       encoder_out.GetTensorTypeAndShapeInfo().GetShape();
 
@@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
   }
 
   int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
+
   int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
   int32_t vocab_size = model_->VocabSize();
 
@@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
 
         Hypothesis new_hyp = prev[hyp_index];
         const float prev_lm_log_prob = new_hyp.lm_log_prob;
+        float context_score = 0;
+        auto context_state = new_hyp.context_state;
+
         if (new_token != 0) {
           new_hyp.ys.push_back(new_token);
           new_hyp.timestamps.push_back(t + frame_offset);
           new_hyp.num_trailing_blanks = 0;
+          if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
+            auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
+                context_state, new_token);
+            context_score = context_res.first;
+            new_hyp.context_state = context_res.second;
+          }
           if (lm_) {
             lm_->ComputeLMScore(lm_scale_, &new_hyp);
          }
         } else {
           ++new_hyp.num_trailing_blanks;
         }
-        new_hyp.log_prob =
-            p_logprob[k] - prev_lm_log_prob;  // log_prob only includes the
+        new_hyp.log_prob = p_logprob[k] + context_score -
+                           prev_lm_log_prob;  // log_prob only includes the
                                               // score of the transducer
         hyps.Add(std::move(new_hyp));
       }  // for (auto k : topk)
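The ForwardOneStep() calls above can be pictured with a deliberately simplified sketch: a trie built from the token-ID sequences of the biasing phrases, in which every token that extends a phrase earns context_score. This is only an illustration of the idea; the ToyContextGraph class below is not part of sherpa-onnx, and the real ContextGraph additionally handles the case where a partial match is abandoned (its accumulated bonus is rolled back).

# Simplified illustration only -- not the sherpa-onnx ContextGraph.
class ToyContextGraph:
    def __init__(self, contexts, context_score=1.5):
        # contexts: list of biasing phrases, each a list of token IDs
        self.context_score = context_score
        self.root = {}
        for phrase in contexts:
            node = self.root
            for tok in phrase:
                node = node.setdefault(tok, {})

    def forward_one_step(self, state, token):
        # Returns (score_delta, next_state) after consuming one token.
        if token in state:
            # Still extending a biasing phrase: reward this step.
            return self.context_score, state[token]
        # Match broken: restart from the root.  (The real graph also subtracts
        # the bonus accumulated by the abandoned partial match.)
        return 0.0, self.root

graph = ToyContextGraph([[10, 11], [12, 13]], context_score=1.5)
delta, state = graph.forward_one_step(graph.root, 10)  # delta == 1.5
delta, state = graph.forward_one_step(state, 99)       # mismatch -> 0.0, back to root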
@@ -9,6 +9,7 @@
 #include <vector>
 
 #include "sherpa-onnx/csrc/online-lm.h"
+#include "sherpa-onnx/csrc/online-stream.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model.h"
 
@@ -33,6 +34,9 @@ class OnlineTransducerModifiedBeamSearchDecoder
   void Decode(Ort::Value encoder_out,
               std::vector<OnlineTransducerDecoderResult> *result) override;
 
+  void Decode(Ort::Value encoder_out, OnlineStream **ss,
+              std::vector<OnlineTransducerDecoderResult> *result) override;
+
   void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
 
  private:
@@ -22,18 +22,19 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
   py::class_<PyClass>(*m, "OnlineRecognizerConfig")
       .def(py::init<const FeatureExtractorConfig &,
                     const OnlineTransducerModelConfig &, const OnlineLMConfig &,
-                    const EndpointConfig &, bool, const std::string &,
-                    int32_t>(),
+                    const EndpointConfig &, bool, const std::string &, int32_t,
+                    float>(),
           py::arg("feat_config"), py::arg("model_config"),
           py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
           py::arg("enable_endpoint"), py::arg("decoding_method"),
-          py::arg("max_active_paths"))
+          py::arg("max_active_paths"), py::arg("context_score"))
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
       .def_readwrite("endpoint_config", &PyClass::endpoint_config)
       .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
       .def_readwrite("decoding_method", &PyClass::decoding_method)
       .def_readwrite("max_active_paths", &PyClass::max_active_paths)
+      .def_readwrite("context_score", &PyClass::context_score)
       .def("__str__", &PyClass::ToString);
 }
 
@@ -44,7 +45,15 @@ void PybindOnlineRecognizer(py::module *m) {
   using PyClass = OnlineRecognizer;
   py::class_<PyClass>(*m, "OnlineRecognizer")
       .def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
-      .def("create_stream", &PyClass::CreateStream)
+      .def("create_stream",
+           [](const PyClass &self) { return self.CreateStream(); })
+      .def(
+          "create_stream",
+          [](PyClass &self,
+             const std::vector<std::vector<int32_t>> &contexts_list) {
+            return self.CreateStream(contexts_list);
+          },
+          py::arg("contexts_list"))
       .def("is_ready", &PyClass::IsReady)
       .def("decode_stream", &PyClass::DecodeStream)
       .def("decode_streams",
@@ -1,6 +1,6 @@
 # Copyright (c) 2023 Xiaomi Corporation
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 
 from _sherpa_onnx import (
     EndpointConfig,
@@ -39,6 +39,7 @@ class OnlineRecognizer(object):
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         max_active_paths: int = 4,
+        context_score: float = 1.5,
         provider: str = "cpu",
     ):
         """
@@ -124,13 +125,17 @@ class OnlineRecognizer(object):
             enable_endpoint=enable_endpoint_detection,
             decoding_method=decoding_method,
             max_active_paths=max_active_paths,
+            context_score=context_score,
         )
 
         self.recognizer = _Recognizer(recognizer_config)
         self.config = recognizer_config
 
-    def create_stream(self):
-        return self.recognizer.create_stream()
+    def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
+        if contexts_list is None:
+            return self.recognizer.create_stream()
+        else:
+            return self.recognizer.create_stream(contexts_list)
 
     def decode_stream(self, s: OnlineStream):
         self.recognizer.decode_stream(s)
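Putting the Python-facing pieces together, a hypothetical end-to-end use of the new API could look like the sketch below. The model paths and the token IDs in contexts_list are placeholders; in practice the IDs come from encode_contexts() (or an equivalent tokenizer) driven by the model's own tokens.txt / bpe.model.

import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer(
    tokens="/path/to/tokens.txt",      # placeholder paths
    encoder="/path/to/encoder.onnx",
    decoder="/path/to/decoder.onnx",
    joiner="/path/to/joiner.onnx",
    decoding_method="modified_beam_search",  # biasing is active only here
    max_active_paths=4,
    context_score=1.5,
)

# One list of token IDs per biasing phrase (placeholder values).
contexts_list = [[10, 11], [12, 13]]
stream = recognizer.create_stream(contexts_list=contexts_list)
# ...then feed audio via stream.accept_waveform(sample_rate, samples) and decode as usual.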