Committed by GitHub
Implement context biasing with an Aho-Corasick automaton (#145)
* Implement context graph
* Modify the interface to support context biasing
* Support context biasing in modified beam search; add python wrapper
* Support context biasing in python api example
* Minor fixes
* Fix context graph
* Minor fixes
* Fix tests
* Fix style
* Fix style
* Fix comments
* Minor fixes
* Add missing header
* Replace std::shared_ptr with std::unique_ptr for efficiency
* Build graph in constructor
* Fix comments
* Minor fixes
* Fix docs
Showing 23 changed files with 515 additions and 29 deletions.
| @@ -54,7 +54,7 @@ jobs: | @@ -54,7 +54,7 @@ jobs: | ||
| 54 | - name: Install Python dependencies | 54 | - name: Install Python dependencies |
| 55 | shell: bash | 55 | shell: bash |
| 56 | run: | | 56 | run: | |
| 57 | - python3 -m pip install --upgrade pip numpy | 57 | + python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96 |
| 58 | 58 | ||
| 59 | - name: Install sherpa-onnx | 59 | - name: Install sherpa-onnx |
| 60 | shell: bash | 60 | shell: bash |
| @@ -43,9 +43,10 @@ import argparse | @@ -43,9 +43,10 @@ import argparse | ||
| 43 | import time | 43 | import time |
| 44 | import wave | 44 | import wave |
| 45 | from pathlib import Path | 45 | from pathlib import Path |
| 46 | -from typing import Tuple | 46 | +from typing import List, Tuple |
| 47 | 47 | ||
| 48 | import numpy as np | 48 | import numpy as np |
| 49 | +import sentencepiece as spm | ||
| 49 | import sherpa_onnx | 50 | import sherpa_onnx |
| 50 | 51 | ||
| 51 | 52 | ||
| @@ -61,6 +62,47 @@ def get_args(): | @@ -61,6 +62,47 @@ def get_args(): | ||
| 61 | ) | 62 | ) |
| 62 | 63 | ||
| 63 | parser.add_argument( | 64 | parser.add_argument( |
| 65 | + "--bpe-model", | ||
| 66 | + type=str, | ||
| 67 | + default="", | ||
| 68 | + help=""" | ||
|  | 69 | + Path to bpe.model. | ||
| 70 | + Used only when --decoding-method=modified_beam_search | ||
| 71 | + """, | ||
| 72 | + ) | ||
| 73 | + | ||
| 74 | + parser.add_argument( | ||
| 75 | + "--modeling-unit", | ||
| 76 | + type=str, | ||
| 77 | + default="char", | ||
| 78 | + help=""" | ||
| 79 | + The type of modeling unit. | ||
| 80 | + Valid values are bpe, bpe+char, char. | ||
| 81 | + Note: the char here means characters in CJK languages. | ||
| 82 | + """, | ||
| 83 | + ) | ||
| 84 | + | ||
| 85 | + parser.add_argument( | ||
| 86 | + "--contexts", | ||
| 87 | + type=str, | ||
| 88 | + default="", | ||
| 89 | + help=""" | ||
|  | 90 | + The context list. It is a string containing words/phrases separated | ||
|  | 91 | + by /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY'. | ||
| 92 | + """, | ||
| 93 | + ) | ||
| 94 | + | ||
| 95 | + parser.add_argument( | ||
| 96 | + "--context-score", | ||
| 97 | + type=float, | ||
| 98 | + default=1.5, | ||
| 99 | + help=""" | ||
|  | 100 | + The bonus score applied to each token of a biasing word/phrase. Used | ||
|  | 101 | + only if --contexts is given. | ||
| 102 | + """, | ||
| 103 | + ) | ||
| 104 | + | ||
| 105 | + parser.add_argument( | ||
| 64 | "--encoder", | 106 | "--encoder", |
| 65 | default="", | 107 | default="", |
| 66 | type=str, | 108 | type=str, |
| @@ -153,6 +195,24 @@ def assert_file_exists(filename: str): | @@ -153,6 +195,24 @@ def assert_file_exists(filename: str): | ||
| 153 | ) | 195 | ) |
| 154 | 196 | ||
| 155 | 197 | ||
| 198 | +def encode_contexts(args, contexts: List[str]) -> List[List[int]]: | ||
| 199 | + sp = None | ||
| 200 | + if "bpe" in args.modeling_unit: | ||
| 201 | + assert_file_exists(args.bpe_model) | ||
| 202 | + sp = spm.SentencePieceProcessor() | ||
| 203 | + sp.load(args.bpe_model) | ||
| 204 | + tokens = {} | ||
| 205 | + with open(args.tokens, "r", encoding="utf-8") as f: | ||
| 206 | + for line in f: | ||
| 207 | + toks = line.strip().split() | ||
| 208 | + assert len(toks) == 2, len(toks) | ||
| 209 | + assert toks[0] not in tokens, f"Duplicate token: {toks} " | ||
| 210 | + tokens[toks[0]] = int(toks[1]) | ||
| 211 | + return sherpa_onnx.encode_contexts( | ||
| 212 | + modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens | ||
| 213 | + ) | ||
| 214 | + | ||
| 215 | + | ||
| 156 | def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: | 216 | def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: |
| 157 | """ | 217 | """ |
| 158 | Args: | 218 | Args: |
| @@ -182,10 +242,17 @@ def main(): | @@ -182,10 +242,17 @@ def main(): | ||
| 182 | args = get_args() | 242 | args = get_args() |
| 183 | assert_file_exists(args.tokens) | 243 | assert_file_exists(args.tokens) |
| 184 | assert args.num_threads > 0, args.num_threads | 244 | assert args.num_threads > 0, args.num_threads |
| 245 | + | ||
| 246 | + contexts_list = [] | ||
| 185 | if args.encoder: | 247 | if args.encoder: |
| 186 | assert len(args.paraformer) == 0, args.paraformer | 248 | assert len(args.paraformer) == 0, args.paraformer |
| 187 | assert len(args.nemo_ctc) == 0, args.nemo_ctc | 249 | assert len(args.nemo_ctc) == 0, args.nemo_ctc |
| 188 | 250 | ||
| 251 | + contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] | ||
| 252 | + if contexts: | ||
| 253 | + print(f"Contexts list: {contexts}") | ||
| 254 | + contexts_list = encode_contexts(args, contexts) | ||
| 255 | + | ||
| 189 | assert_file_exists(args.encoder) | 256 | assert_file_exists(args.encoder) |
| 190 | assert_file_exists(args.decoder) | 257 | assert_file_exists(args.decoder) |
| 191 | assert_file_exists(args.joiner) | 258 | assert_file_exists(args.joiner) |
| @@ -199,6 +266,7 @@ def main(): | @@ -199,6 +266,7 @@ def main(): | ||
| 199 | sample_rate=args.sample_rate, | 266 | sample_rate=args.sample_rate, |
| 200 | feature_dim=args.feature_dim, | 267 | feature_dim=args.feature_dim, |
| 201 | decoding_method=args.decoding_method, | 268 | decoding_method=args.decoding_method, |
| 269 | + context_score=args.context_score, | ||
| 202 | debug=args.debug, | 270 | debug=args.debug, |
| 203 | ) | 271 | ) |
| 204 | elif args.paraformer: | 272 | elif args.paraformer: |
| @@ -238,8 +306,12 @@ def main(): | @@ -238,8 +306,12 @@ def main(): | ||
| 238 | samples, sample_rate = read_wave(wave_filename) | 306 | samples, sample_rate = read_wave(wave_filename) |
| 239 | duration = len(samples) / sample_rate | 307 | duration = len(samples) / sample_rate |
| 240 | total_duration += duration | 308 | total_duration += duration |
| 241 | - | ||
| 242 | - s = recognizer.create_stream() | 309 | + if contexts_list: |
| 310 | + assert len(args.paraformer) == 0, args.paraformer | ||
| 311 | + assert len(args.nemo_ctc) == 0, args.nemo_ctc | ||
| 312 | + s = recognizer.create_stream(contexts_list=contexts_list) | ||
| 313 | + else: | ||
| 314 | + s = recognizer.create_stream() | ||
| 243 | s.accept_waveform(sample_rate, samples) | 315 | s.accept_waveform(sample_rate, samples) |
| 244 | 316 | ||
| 245 | streams.append(s) | 317 | streams.append(s) |
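Pulling the example changes above together, the usage pattern looks roughly as follows. This is a hedged sketch, not code from the PR: the file names are placeholders, the `from_transducer` constructor name is an assumption, and the silence buffer stands in for real audio; only `encode_contexts`, `create_stream(contexts_list=...)`, `accept_waveform`, and `decode_stream` are taken from this diff.

```python
# Hedged end-to-end sketch of the new contextual-biasing API. File names are
# placeholders and the `from_transducer` constructor name is an assumption.
import numpy as np
import sentencepiece as spm

import sherpa_onnx

sp = spm.SentencePieceProcessor()
sp.load("bpe.model")  # placeholder path

tokens = {}  # {symbol: id} table read from tokens.txt
with open("tokens.txt", encoding="utf-8") as f:
    for line in f:
        sym, idx = line.split()
        tokens[sym] = int(idx)

# CLI equivalent: --contexts "HELLO WORLD/I LOVE YOU" --modeling-unit bpe
contexts_list = sherpa_onnx.encode_contexts(
    modeling_unit="bpe",
    contexts=["HELLO WORLD", "I LOVE YOU"],
    sp=sp,
    tokens_table=tokens,
)

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(  # name assumed
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    tokens="tokens.txt",
    decoding_method="modified_beam_search",  # biasing requires this method
    context_score=1.5,
)

s = recognizer.create_stream(contexts_list=contexts_list)
samples = np.zeros(16000, dtype=np.float32)  # stand-in for real audio
s.accept_waveform(16000, samples)
recognizer.decode_stream(s)
```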
| @@ -12,6 +12,7 @@ endif() | @@ -12,6 +12,7 @@ endif() | ||
| 12 | 12 | ||
| 13 | set(sources | 13 | set(sources |
| 14 | cat.cc | 14 | cat.cc |
| 15 | + context-graph.cc | ||
| 15 | endpoint.cc | 16 | endpoint.cc |
| 16 | features.cc | 17 | features.cc |
| 17 | file-utils.cc | 18 | file-utils.cc |
| @@ -248,6 +249,7 @@ endif() | @@ -248,6 +249,7 @@ endif() | ||
| 248 | if(SHERPA_ONNX_ENABLE_TESTS) | 249 | if(SHERPA_ONNX_ENABLE_TESTS) |
| 249 | set(sherpa_onnx_test_srcs | 250 | set(sherpa_onnx_test_srcs |
| 250 | cat-test.cc | 251 | cat-test.cc |
| 252 | + context-graph-test.cc | ||
| 251 | packed-sequence-test.cc | 253 | packed-sequence-test.cc |
| 252 | pad-sequence-test.cc | 254 | pad-sequence-test.cc |
| 253 | slice-test.cc | 255 | slice-test.cc |
sherpa-onnx/csrc/context-graph-test.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/context-graph-test.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/context-graph.h" | ||
| 6 | + | ||
| 7 | +#include <map> | ||
| 8 | +#include <string> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "gtest/gtest.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +TEST(ContextGraph, TestBasic) { | ||
| 16 | + std::vector<std::string> contexts_str( | ||
| 17 | + {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); | ||
| 18 | + std::vector<std::vector<int32_t>> contexts; | ||
| 19 | + for (int32_t i = 0; i < contexts_str.size(); ++i) { | ||
| 20 | + contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); | ||
| 21 | + } | ||
| 22 | + auto context_graph = ContextGraph(contexts, 1); | ||
| 23 | + | ||
| 24 | + auto queries = std::map<std::string, float>{ | ||
| 25 | + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6}, | ||
| 26 | + {"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; | ||
| 27 | + | ||
| 28 | + for (const auto &iter : queries) { | ||
| 29 | + float total_scores = 0; | ||
| 30 | + auto state = context_graph.Root(); | ||
| 31 | + for (auto q : iter.first) { | ||
| 32 | + auto res = context_graph.ForwardOneStep(state, q); | ||
| 33 | + total_scores += res.first; | ||
| 34 | + state = res.second; | ||
| 35 | + } | ||
| 36 | + auto res = context_graph.Finalize(state); | ||
| 37 | + EXPECT_EQ(res.second->token, -1); | ||
| 38 | + total_scores += res.first; | ||
| 39 | + EXPECT_EQ(total_scores, iter.second); | ||
| 40 | + } | ||
| 41 | +} | ||
| 42 | + | ||
| 43 | +} // namespace sherpa_onnx |
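To see where the expected totals in this test come from, here is a hand-trace (an illustration, not code from the PR) of the query "HELL" with `context_score = 1`: each matched token adds one point, extending past a completed phrase banks that phrase's accumulated `node_score`, and `Finalize` rolls back the bonus of any dangling partial match.

```python
# Hand-computed trace for the query "HELL" against the contexts above,
# reproducing the expected total of 2 from the test. Each entry is the score
# that ForwardOneStep returns for that step.
steps = [
    1,  # 'H': root -> H, +token_score
    1,  # 'E': H -> HE, +token_score ("HE" completes; its score is banked)
    3,  # 'L': HE -> HEL, +token_score, plus HE.node_score (2) since HE ended a phrase
    1,  # 'L': HEL -> HELL, +token_score
]
finalize = -4  # "HELL" is not a full phrase: subtract its node_score (4 tokens)
assert sum(steps) + finalize == 2
```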
sherpa-onnx/csrc/context-graph.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/context-graph.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/context-graph.h" | ||
| 6 | + | ||
| 7 | +#include <cassert> | ||
| 8 | +#include <queue> | ||
| 9 | +#include <utility> | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | +void ContextGraph::Build( | ||
| 13 | + const std::vector<std::vector<int32_t>> &token_ids) const { | ||
| 14 | + for (int32_t i = 0; i < token_ids.size(); ++i) { | ||
| 15 | + auto node = root_.get(); | ||
| 16 | + for (int32_t j = 0; j < token_ids[i].size(); ++j) { | ||
| 17 | + int32_t token = token_ids[i][j]; | ||
| 18 | + if (0 == node->next.count(token)) { | ||
| 19 | + bool is_end = j == token_ids[i].size() - 1; | ||
| 20 | + node->next[token] = std::make_unique<ContextState>( | ||
| 21 | + token, context_score_, node->node_score + context_score_, | ||
| 22 | + is_end ? 0 : node->local_node_score + context_score_, is_end); | ||
| 23 | + } | ||
| 24 | + node = node->next[token].get(); | ||
| 25 | + } | ||
| 26 | + } | ||
| 27 | + FillFailOutput(); | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +std::pair<float, const ContextState *> ContextGraph::ForwardOneStep( | ||
| 31 | + const ContextState *state, int32_t token) const { | ||
| 32 | + const ContextState *node; | ||
| 33 | + float score; | ||
| 34 | + if (1 == state->next.count(token)) { | ||
| 35 | + node = state->next.at(token).get(); | ||
| 36 | + score = node->token_score; | ||
| 37 | + if (state->is_end) score += state->node_score; | ||
| 38 | + } else { | ||
| 39 | + node = state->fail; | ||
| 40 | + while (0 == node->next.count(token)) { | ||
| 41 | + node = node->fail; | ||
| 42 | + if (-1 == node->token) break; // root | ||
| 43 | + } | ||
| 44 | + if (1 == node->next.count(token)) { | ||
| 45 | + node = node->next.at(token).get(); | ||
| 46 | + } | ||
| 47 | + score = node->node_score - state->local_node_score; | ||
| 48 | + } | ||
| 49 | + SHERPA_ONNX_CHECK(nullptr != node); | ||
| 50 | + float matched_score = 0; | ||
| 51 | + auto output = node->output; | ||
| 52 | + while (nullptr != output) { | ||
| 53 | + matched_score += output->node_score; | ||
| 54 | + output = output->output; | ||
| 55 | + } | ||
| 56 | + return std::make_pair(score + matched_score, node); | ||
| 57 | +} | ||
| 58 | + | ||
| 59 | +std::pair<float, const ContextState *> ContextGraph::Finalize( | ||
| 60 | + const ContextState *state) const { | ||
| 61 | + float score = -state->node_score; | ||
| 62 | + if (state->is_end) { | ||
| 63 | + score = 0; | ||
| 64 | + } | ||
| 65 | + return std::make_pair(score, root_.get()); | ||
| 66 | +} | ||
| 67 | + | ||
| 68 | +void ContextGraph::FillFailOutput() const { | ||
| 69 | + std::queue<const ContextState *> node_queue; | ||
| 70 | + for (auto &kv : root_->next) { | ||
| 71 | + kv.second->fail = root_.get(); | ||
| 72 | + node_queue.push(kv.second.get()); | ||
| 73 | + } | ||
| 74 | + while (!node_queue.empty()) { | ||
| 75 | + auto current_node = node_queue.front(); | ||
| 76 | + node_queue.pop(); | ||
| 77 | + for (auto &kv : current_node->next) { | ||
| 78 | + auto fail = current_node->fail; | ||
| 79 | + if (1 == fail->next.count(kv.first)) { | ||
| 80 | + fail = fail->next.at(kv.first).get(); | ||
| 81 | + } else { | ||
| 82 | + fail = fail->fail; | ||
| 83 | + while (0 == fail->next.count(kv.first)) { | ||
| 84 | + fail = fail->fail; | ||
| 85 | + if (-1 == fail->token) break; | ||
| 86 | + } | ||
| 87 | + if (1 == fail->next.count(kv.first)) | ||
| 88 | + fail = fail->next.at(kv.first).get(); | ||
| 89 | + } | ||
| 90 | + kv.second->fail = fail; | ||
| 91 | + // fill the output arc | ||
| 92 | + auto output = fail; | ||
| 93 | + while (!output->is_end) { | ||
| 94 | + output = output->fail; | ||
| 95 | + if (-1 == output->token) { | ||
| 96 | + output = nullptr; | ||
| 97 | + break; | ||
| 98 | + } | ||
| 99 | + } | ||
| 100 | + kv.second->output = output; | ||
| 101 | + node_queue.push(kv.second.get()); | ||
| 102 | + } | ||
| 103 | + } | ||
| 104 | +} | ||
| 105 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/context-graph.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/context-graph.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <unordered_map> | ||
| 10 | +#include <utility> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "sherpa-onnx/csrc/log.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class ContextGraph; | ||
| 18 | +using ContextGraphPtr = std::shared_ptr<ContextGraph>; | ||
| 19 | + | ||
| 20 | +struct ContextState { | ||
| 21 | + int32_t token; | ||
| 22 | + float token_score; | ||
| 23 | + float node_score; | ||
| 24 | + float local_node_score; | ||
| 25 | + bool is_end; | ||
| 26 | + std::unordered_map<int32_t, std::unique_ptr<ContextState>> next; | ||
| 27 | + const ContextState *fail = nullptr; | ||
| 28 | + const ContextState *output = nullptr; | ||
| 29 | + | ||
| 30 | + ContextState() = default; | ||
| 31 | + ContextState(int32_t token, float token_score, float node_score, | ||
| 32 | + float local_node_score, bool is_end) | ||
| 33 | + : token(token), | ||
| 34 | + token_score(token_score), | ||
| 35 | + node_score(node_score), | ||
| 36 | + local_node_score(local_node_score), | ||
| 37 | + is_end(is_end) {} | ||
| 38 | +}; | ||
| 39 | + | ||
| 40 | +class ContextGraph { | ||
| 41 | + public: | ||
| 42 | + ContextGraph() = default; | ||
| 43 | + ContextGraph(const std::vector<std::vector<int32_t>> &token_ids, | ||
| 44 | + float context_score) | ||
| 45 | + : context_score_(context_score) { | ||
| 46 | + root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false); | ||
| 47 | + root_->fail = root_.get(); | ||
| 48 | + Build(token_ids); | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + std::pair<float, const ContextState *> ForwardOneStep( | ||
| 52 | + const ContextState *state, int32_t token_id) const; | ||
| 53 | + std::pair<float, const ContextState *> Finalize( | ||
| 54 | + const ContextState *state) const; | ||
| 55 | + | ||
| 56 | + const ContextState *Root() const { return root_.get(); } | ||
| 57 | + | ||
| 58 | + private: | ||
| 59 | + float context_score_; | ||
| 60 | + std::unique_ptr<ContextState> root_; | ||
| 61 | + void Build(const std::vector<std::vector<int32_t>> &token_ids) const; | ||
| 62 | + void FillFailOutput() const; | ||
| 63 | +}; | ||
| 64 | + | ||
| 65 | +} // namespace sherpa_onnx | ||
| 66 | +#endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ |
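To make the automaton easier to study outside the C++ build, here is a rough Python transliteration of the `ContextGraph` above: the trie build, the Aho-Corasick fail/output links, and the scoring. It mirrors the C++ files above (which remain authoritative) and is offered as an illustrative sketch, not part of the PR.

```python
# Illustrative Python port of sherpa-onnx's ContextGraph; context-graph.cc/.h
# above are authoritative. Each token on a trie path adds context_score;
# falling back over an unfinished prefix rolls that bonus back.
from collections import deque
from typing import Dict, List, Optional, Tuple


class ContextState:
    def __init__(self, token, token_score, node_score, local_node_score, is_end):
        self.token = token
        self.token_score = token_score
        self.node_score = node_score  # bonus accumulated along the trie path
        self.local_node_score = local_node_score  # bonus since the last phrase end
        self.is_end = is_end
        self.next: Dict[int, "ContextState"] = {}
        self.fail: Optional["ContextState"] = None  # longest proper suffix in the trie
        self.output: Optional["ContextState"] = None  # nearest suffix ending a phrase


class ContextGraph:
    def __init__(self, token_ids: List[List[int]], context_score: float):
        self.context_score = context_score
        self.root = ContextState(-1, 0, 0, 0, False)
        self.root.fail = self.root
        for ids in token_ids:
            node = self.root
            for j, tok in enumerate(ids):
                if tok not in node.next:
                    is_end = j == len(ids) - 1
                    node.next[tok] = ContextState(
                        tok,
                        context_score,
                        node.node_score + context_score,
                        0 if is_end else node.local_node_score + context_score,
                        is_end,
                    )
                node = node.next[tok]
        self._fill_fail_output()

    def _fill_fail_output(self) -> None:
        # Standard Aho-Corasick BFS to set the fail and output links.
        queue = deque()
        for child in self.root.next.values():
            child.fail = self.root
            queue.append(child)
        while queue:
            cur = queue.popleft()
            for tok, child in cur.next.items():
                fail = cur.fail
                while tok not in fail.next:
                    fail = fail.fail
                    if fail.token == -1:  # reached the root
                        break
                child.fail = fail.next.get(tok, self.root)
                out = child.fail
                while out is not None and not out.is_end:
                    out = None if out.token == -1 else out.fail
                child.output = out
                queue.append(child)

    def forward_one_step(self, state, token) -> Tuple[float, "ContextState"]:
        if token in state.next:  # extend the current partial match
            node = state.next[token]
            score = node.token_score
            if state.is_end:  # leaving a phrase end: bank its full score
                score += state.node_score
        else:  # follow fail links; roll back the abandoned partial match
            node = state.fail
            while token not in node.next:
                node = node.fail
                if node.token == -1:
                    break
            node = node.next.get(token, node)
            score = node.node_score - state.local_node_score
        matched = 0.0  # phrases ending at (a suffix of) the current position
        out = node.output
        while out is not None:
            matched += out.node_score
            out = out.output
        return score + matched, node

    def finalize(self, state) -> Tuple[float, "ContextState"]:
        # Undo the bonus of a dangling partial match and reset to the root.
        return (0.0 if state.is_end else -state.node_score), self.root
```

Feeding the queries from context-graph-test.cc through this port should reproduce the expected totals (for example, "SHED" scores 6), which is a convenient way to poke at the fail/output-link behavior interactively.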
| @@ -14,6 +14,7 @@ | @@ -14,6 +14,7 @@ | ||
| 14 | #include <vector> | 14 | #include <vector> |
| 15 | 15 | ||
| 16 | #include "onnxruntime_cxx_api.h" // NOLINT | 16 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 17 | +#include "sherpa-onnx/csrc/context-graph.h" | ||
| 17 | #include "sherpa-onnx/csrc/math.h" | 18 | #include "sherpa-onnx/csrc/math.h" |
| 18 | #include "sherpa-onnx/csrc/onnx-utils.h" | 19 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 19 | 20 | ||
| @@ -39,11 +40,18 @@ struct Hypothesis { | @@ -39,11 +40,18 @@ struct Hypothesis { | ||
| 39 | // the nn lm states | 40 | // the nn lm states |
| 40 | std::vector<CopyableOrtValue> nn_lm_states; | 41 | std::vector<CopyableOrtValue> nn_lm_states; |
| 41 | 42 | ||
| 43 | + const ContextState *context_state; | ||
| 44 | + | ||
| 45 | + // TODO(fangjun): Make it configurable | ||
|  | 46 | + // the minimum number of tokens in a chunk for streaming RNN LM | ||
| 47 | + int32_t lm_rescore_min_chunk = 2; // a const | ||
| 48 | + | ||
| 42 | int32_t num_trailing_blanks = 0; | 49 | int32_t num_trailing_blanks = 0; |
| 43 | 50 | ||
| 44 | Hypothesis() = default; | 51 | Hypothesis() = default; |
| 45 | - Hypothesis(const std::vector<int64_t> &ys, double log_prob) | ||
| 46 | - : ys(ys), log_prob(log_prob) {} | 52 | + Hypothesis(const std::vector<int64_t> &ys, double log_prob, |
| 53 | + const ContextState *context_state = nullptr) | ||
| 54 | + : ys(ys), log_prob(log_prob), context_state(context_state) {} | ||
| 47 | 55 | ||
| 48 | double TotalLogProb() const { return log_prob + lm_log_prob; } | 56 | double TotalLogProb() const { return log_prob + lm_log_prob; } |
| 49 | 57 |
| @@ -6,7 +6,9 @@ | @@ -6,7 +6,9 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ | 6 | #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ |
| 7 | 7 | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | +#include <vector> | ||
| 9 | 10 | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 10 | #include "sherpa-onnx/csrc/offline-recognizer.h" | 12 | #include "sherpa-onnx/csrc/offline-recognizer.h" |
| 11 | #include "sherpa-onnx/csrc/offline-stream.h" | 13 | #include "sherpa-onnx/csrc/offline-stream.h" |
| 12 | 14 | ||
| @@ -19,6 +21,12 @@ class OfflineRecognizerImpl { | @@ -19,6 +21,12 @@ class OfflineRecognizerImpl { | ||
| 19 | 21 | ||
| 20 | virtual ~OfflineRecognizerImpl() = default; | 22 | virtual ~OfflineRecognizerImpl() = default; |
| 21 | 23 | ||
| 24 | + virtual std::unique_ptr<OfflineStream> CreateStream( | ||
| 25 | + const std::vector<std::vector<int32_t>> &context_list) const { | ||
| 26 | + SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); | ||
| 27 | + exit(-1); | ||
| 28 | + } | ||
| 29 | + | ||
| 22 | virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; | 30 | virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; |
| 23 | 31 | ||
| 24 | virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; | 32 | virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; |
| @@ -10,6 +10,7 @@ | @@ -10,6 +10,7 @@ | ||
| 10 | #include <utility> | 10 | #include <utility> |
| 11 | #include <vector> | 11 | #include <vector> |
| 12 | 12 | ||
| 13 | +#include "sherpa-onnx/csrc/context-graph.h" | ||
| 13 | #include "sherpa-onnx/csrc/macros.h" | 14 | #include "sherpa-onnx/csrc/macros.h" |
| 14 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" | 15 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" |
| 15 | #include "sherpa-onnx/csrc/offline-recognizer.h" | 16 | #include "sherpa-onnx/csrc/offline-recognizer.h" |
| @@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 72 | } | 73 | } |
| 73 | } | 74 | } |
| 74 | 75 | ||
| 76 | + std::unique_ptr<OfflineStream> CreateStream( | ||
| 77 | + const std::vector<std::vector<int32_t>> &context_list) const override { | ||
|  | 78 | + // We create the context_graph at this level because we might later add | ||
|  | 79 | + // a default context_graph (if needed) that belongs to the whole | ||
|  | 80 | + // model rather than to each stream. | ||
| 81 | + auto context_graph = | ||
| 82 | + std::make_shared<ContextGraph>(context_list, config_.context_score); | ||
| 83 | + return std::make_unique<OfflineStream>(config_.feat_config, context_graph); | ||
| 84 | + } | ||
| 85 | + | ||
| 75 | std::unique_ptr<OfflineStream> CreateStream() const override { | 86 | std::unique_ptr<OfflineStream> CreateStream() const override { |
| 76 | return std::make_unique<OfflineStream>(config_.feat_config); | 87 | return std::make_unique<OfflineStream>(config_.feat_config); |
| 77 | } | 88 | } |
| @@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 117 | -23.025850929940457f); | 128 | -23.025850929940457f); |
| 118 | 129 | ||
| 119 | auto t = model_->RunEncoder(std::move(x), std::move(x_length)); | 130 | auto t = model_->RunEncoder(std::move(x), std::move(x_length)); |
| 120 | - auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); | 131 | + auto results = |
| 132 | + decoder_->Decode(std::move(t.first), std::move(t.second), ss, n); | ||
| 121 | 133 | ||
| 122 | int32_t frame_shift_ms = 10; | 134 | int32_t frame_shift_ms = 10; |
| 123 | for (int32_t i = 0; i != n; ++i) { | 135 | for (int32_t i = 0; i != n; ++i) { |
| @@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | @@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | ||
| 26 | 26 | ||
| 27 | po->Register("max-active-paths", &max_active_paths, | 27 | po->Register("max-active-paths", &max_active_paths, |
| 28 | "Used only when decoding_method is modified_beam_search"); | 28 | "Used only when decoding_method is modified_beam_search"); |
| 29 | + po->Register("context-score", &context_score, | ||
| 30 | + "The bonus score for each token in context word/phrase. " | ||
| 31 | + "Used only when decoding_method is modified_beam_search"); | ||
| 29 | } | 32 | } |
| 30 | 33 | ||
| 31 | bool OfflineRecognizerConfig::Validate() const { | 34 | bool OfflineRecognizerConfig::Validate() const { |
| @@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const { | @@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const { | ||
| 49 | os << "model_config=" << model_config.ToString() << ", "; | 52 | os << "model_config=" << model_config.ToString() << ", "; |
| 50 | os << "lm_config=" << lm_config.ToString() << ", "; | 53 | os << "lm_config=" << lm_config.ToString() << ", "; |
| 51 | os << "decoding_method=\"" << decoding_method << "\", "; | 54 | os << "decoding_method=\"" << decoding_method << "\", "; |
| 52 | - os << "max_active_paths=" << max_active_paths << ")"; | 55 | + os << "max_active_paths=" << max_active_paths << ", "; |
| 56 | + os << "context_score=" << context_score << ")"; | ||
| 53 | 57 | ||
| 54 | return os.str(); | 58 | return os.str(); |
| 55 | } | 59 | } |
| @@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) | @@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) | ||
| 59 | 63 | ||
| 60 | OfflineRecognizer::~OfflineRecognizer() = default; | 64 | OfflineRecognizer::~OfflineRecognizer() = default; |
| 61 | 65 | ||
| 66 | +std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream( | ||
| 67 | + const std::vector<std::vector<int32_t>> &context_list) const { | ||
| 68 | + return impl_->CreateStream(context_list); | ||
| 69 | +} | ||
| 70 | + | ||
| 62 | std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const { | 71 | std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const { |
| 63 | return impl_->CreateStream(); | 72 | return impl_->CreateStream(); |
| 64 | } | 73 | } |
| @@ -26,6 +26,7 @@ struct OfflineRecognizerConfig { | @@ -26,6 +26,7 @@ struct OfflineRecognizerConfig { | ||
| 26 | 26 | ||
| 27 | std::string decoding_method = "greedy_search"; | 27 | std::string decoding_method = "greedy_search"; |
| 28 | int32_t max_active_paths = 4; | 28 | int32_t max_active_paths = 4; |
| 29 | + float context_score = 1.5; | ||
| 29 | // only greedy_search is implemented | 30 | // only greedy_search is implemented |
| 30 | // TODO(fangjun): Implement modified_beam_search | 31 | // TODO(fangjun): Implement modified_beam_search |
| 31 | 32 | ||
| @@ -34,12 +35,13 @@ struct OfflineRecognizerConfig { | @@ -34,12 +35,13 @@ struct OfflineRecognizerConfig { | ||
| 34 | const OfflineModelConfig &model_config, | 35 | const OfflineModelConfig &model_config, |
| 35 | const OfflineLMConfig &lm_config, | 36 | const OfflineLMConfig &lm_config, |
| 36 | const std::string &decoding_method, | 37 | const std::string &decoding_method, |
| 37 | - int32_t max_active_paths) | 38 | + int32_t max_active_paths, float context_score) |
| 38 | : feat_config(feat_config), | 39 | : feat_config(feat_config), |
| 39 | model_config(model_config), | 40 | model_config(model_config), |
| 40 | lm_config(lm_config), | 41 | lm_config(lm_config), |
| 41 | decoding_method(decoding_method), | 42 | decoding_method(decoding_method), |
| 42 | - max_active_paths(max_active_paths) {} | 43 | + max_active_paths(max_active_paths), |
| 44 | + context_score(context_score) {} | ||
| 43 | 45 | ||
| 44 | void Register(ParseOptions *po); | 46 | void Register(ParseOptions *po); |
| 45 | bool Validate() const; | 47 | bool Validate() const; |
| @@ -58,6 +60,10 @@ class OfflineRecognizer { | @@ -58,6 +60,10 @@ class OfflineRecognizer { | ||
| 58 | /// Create a stream for decoding. | 60 | /// Create a stream for decoding. |
| 59 | std::unique_ptr<OfflineStream> CreateStream() const; | 61 | std::unique_ptr<OfflineStream> CreateStream() const; |
| 60 | 62 | ||
| 63 | + /// Create a stream for decoding. | ||
| 64 | + std::unique_ptr<OfflineStream> CreateStream( | ||
| 65 | + const std::vector<std::vector<int32_t>> &context_list) const; | ||
| 66 | + | ||
| 61 | /** Decode a single stream | 67 | /** Decode a single stream |
| 62 | * | 68 | * |
| 63 | * @param s The stream to decode. | 69 | * @param s The stream to decode. |
| @@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const { | @@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const { | ||
| 75 | 75 | ||
| 76 | class OfflineStream::Impl { | 76 | class OfflineStream::Impl { |
| 77 | public: | 77 | public: |
| 78 | - explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) { | 78 | + explicit Impl(const OfflineFeatureExtractorConfig &config, |
| 79 | + ContextGraphPtr context_graph) | ||
| 80 | + : config_(config), context_graph_(context_graph) { | ||
| 79 | opts_.frame_opts.dither = 0; | 81 | opts_.frame_opts.dither = 0; |
| 80 | opts_.frame_opts.snip_edges = false; | 82 | opts_.frame_opts.snip_edges = false; |
| 81 | opts_.frame_opts.samp_freq = config.sampling_rate; | 83 | opts_.frame_opts.samp_freq = config.sampling_rate; |
| @@ -152,6 +154,8 @@ class OfflineStream::Impl { | @@ -152,6 +154,8 @@ class OfflineStream::Impl { | ||
| 152 | 154 | ||
| 153 | const OfflineRecognitionResult &GetResult() const { return r_; } | 155 | const OfflineRecognitionResult &GetResult() const { return r_; } |
| 154 | 156 | ||
| 157 | + const ContextGraphPtr &GetContextGraph() const { return context_graph_; } | ||
| 158 | + | ||
| 155 | private: | 159 | private: |
| 156 | void NemoNormalizeFeatures(float *p, int32_t num_frames, | 160 | void NemoNormalizeFeatures(float *p, int32_t num_frames, |
| 157 | int32_t feature_dim) const { | 161 | int32_t feature_dim) const { |
| @@ -189,11 +193,13 @@ class OfflineStream::Impl { | @@ -189,11 +193,13 @@ class OfflineStream::Impl { | ||
| 189 | std::unique_ptr<knf::OnlineFbank> fbank_; | 193 | std::unique_ptr<knf::OnlineFbank> fbank_; |
| 190 | knf::FbankOptions opts_; | 194 | knf::FbankOptions opts_; |
| 191 | OfflineRecognitionResult r_; | 195 | OfflineRecognitionResult r_; |
| 196 | + ContextGraphPtr context_graph_; | ||
| 192 | }; | 197 | }; |
| 193 | 198 | ||
| 194 | OfflineStream::OfflineStream( | 199 | OfflineStream::OfflineStream( |
| 195 | - const OfflineFeatureExtractorConfig &config /*= {}*/) | ||
| 196 | - : impl_(std::make_unique<Impl>(config)) {} | 200 | + const OfflineFeatureExtractorConfig &config /*= {}*/, |
| 201 | + ContextGraphPtr context_graph /*= nullptr*/) | ||
| 202 | + : impl_(std::make_unique<Impl>(config, context_graph)) {} | ||
| 197 | 203 | ||
| 198 | OfflineStream::~OfflineStream() = default; | 204 | OfflineStream::~OfflineStream() = default; |
| 199 | 205 | ||
| @@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) { | @@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) { | ||
| 212 | impl_->SetResult(r); | 218 | impl_->SetResult(r); |
| 213 | } | 219 | } |
| 214 | 220 | ||
| 221 | +const ContextGraphPtr &OfflineStream::GetContextGraph() const { | ||
| 222 | + return impl_->GetContextGraph(); | ||
| 223 | +} | ||
| 224 | + | ||
| 215 | const OfflineRecognitionResult &OfflineStream::GetResult() const { | 225 | const OfflineRecognitionResult &OfflineStream::GetResult() const { |
| 216 | return impl_->GetResult(); | 226 | return impl_->GetResult(); |
| 217 | } | 227 | } |
| @@ -10,6 +10,7 @@ | @@ -10,6 +10,7 @@ | ||
| 10 | #include <string> | 10 | #include <string> |
| 11 | #include <vector> | 11 | #include <vector> |
| 12 | 12 | ||
| 13 | +#include "sherpa-onnx/csrc/context-graph.h" | ||
| 13 | #include "sherpa-onnx/csrc/parse-options.h" | 14 | #include "sherpa-onnx/csrc/parse-options.h" |
| 14 | 15 | ||
| 15 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| @@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig { | @@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig { | ||
| 66 | 67 | ||
| 67 | class OfflineStream { | 68 | class OfflineStream { |
| 68 | public: | 69 | public: |
| 69 | - explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}); | 70 | + explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, |
| 71 | + ContextGraphPtr context_graph = nullptr); | ||
| 70 | ~OfflineStream(); | 72 | ~OfflineStream(); |
| 71 | 73 | ||
| 72 | /** | 74 | /** |
| @@ -96,6 +98,9 @@ class OfflineStream { | @@ -96,6 +98,9 @@ class OfflineStream { | ||
| 96 | /** Get the recognition result of this stream */ | 98 | /** Get the recognition result of this stream */ |
| 97 | const OfflineRecognitionResult &GetResult() const; | 99 | const OfflineRecognitionResult &GetResult() const; |
| 98 | 100 | ||
| 101 | + /** Get the ContextGraph of this stream */ | ||
| 102 | + const ContextGraphPtr &GetContextGraph() const; | ||
| 103 | + | ||
| 99 | private: | 104 | private: |
| 100 | class Impl; | 105 | class Impl; |
| 101 | std::unique_ptr<Impl> impl_; | 106 | std::unique_ptr<Impl> impl_; |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include <vector> | 8 | #include <vector> |
| 9 | 9 | ||
| 10 | #include "onnxruntime_cxx_api.h" // NOLINT | 10 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 11 | +#include "sherpa-onnx/csrc/offline-stream.h" | ||
| 11 | 12 | ||
| 12 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 13 | 14 | ||
| @@ -33,7 +34,8 @@ class OfflineTransducerDecoder { | @@ -33,7 +34,8 @@ class OfflineTransducerDecoder { | ||
| 33 | * @return Return a vector of size `N` containing the decoded results. | 34 | * @return Return a vector of size `N` containing the decoded results. |
| 34 | */ | 35 | */ |
| 35 | virtual std::vector<OfflineTransducerDecoderResult> Decode( | 36 | virtual std::vector<OfflineTransducerDecoderResult> Decode( |
| 36 | - Ort::Value encoder_out, Ort::Value encoder_out_length) = 0; | 37 | + Ort::Value encoder_out, Ort::Value encoder_out_length, |
| 38 | + OfflineStream **ss = nullptr, int32_t n = 0) = 0; | ||
| 37 | }; | 39 | }; |
| 38 | 40 | ||
| 39 | } // namespace sherpa_onnx | 41 | } // namespace sherpa_onnx |
| @@ -16,7 +16,9 @@ namespace sherpa_onnx { | @@ -16,7 +16,9 @@ namespace sherpa_onnx { | ||
| 16 | 16 | ||
| 17 | std::vector<OfflineTransducerDecoderResult> | 17 | std::vector<OfflineTransducerDecoderResult> |
| 18 | OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, | 18 | OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, |
| 19 | - Ort::Value encoder_out_length) { | 19 | + Ort::Value encoder_out_length, |
| 20 | + OfflineStream **ss /*= nullptr*/, | ||
| 21 | + int32_t n /*= 0*/) { | ||
| 20 | PackedSequence packed_encoder_out = PackPaddedSequence( | 22 | PackedSequence packed_encoder_out = PackPaddedSequence( |
| 21 | model_->Allocator(), &encoder_out, &encoder_out_length); | 23 | model_->Allocator(), &encoder_out, &encoder_out_length); |
| 22 | 24 |
| @@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | @@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | ||
| 18 | : model_(model) {} | 18 | : model_(model) {} |
| 19 | 19 | ||
| 20 | std::vector<OfflineTransducerDecoderResult> Decode( | 20 | std::vector<OfflineTransducerDecoderResult> Decode( |
| 21 | - Ort::Value encoder_out, Ort::Value encoder_out_length) override; | 21 | + Ort::Value encoder_out, Ort::Value encoder_out_length, |
| 22 | + OfflineStream **ss = nullptr, int32_t n = 0) override; | ||
| 22 | 23 | ||
| 23 | private: | 24 | private: |
| 24 | OfflineTransducerModel *model_; // Not owned | 25 | OfflineTransducerModel *model_; // Not owned |
| @@ -8,7 +8,9 @@ | @@ -8,7 +8,9 @@ | ||
| 8 | #include <utility> | 8 | #include <utility> |
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | +#include "sherpa-onnx/csrc/context-graph.h" | ||
| 11 | #include "sherpa-onnx/csrc/hypothesis.h" | 12 | #include "sherpa-onnx/csrc/hypothesis.h" |
| 13 | +#include "sherpa-onnx/csrc/log.h" | ||
| 12 | #include "sherpa-onnx/csrc/onnx-utils.h" | 14 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 13 | #include "sherpa-onnx/csrc/packed-sequence.h" | 15 | #include "sherpa-onnx/csrc/packed-sequence.h" |
| 14 | #include "sherpa-onnx/csrc/slice.h" | 16 | #include "sherpa-onnx/csrc/slice.h" |
| @@ -17,23 +19,39 @@ namespace sherpa_onnx { | @@ -17,23 +19,39 @@ namespace sherpa_onnx { | ||
| 17 | 19 | ||
| 18 | std::vector<OfflineTransducerDecoderResult> | 20 | std::vector<OfflineTransducerDecoderResult> |
| 19 | OfflineTransducerModifiedBeamSearchDecoder::Decode( | 21 | OfflineTransducerModifiedBeamSearchDecoder::Decode( |
| 20 | - Ort::Value encoder_out, Ort::Value encoder_out_length) { | 22 | + Ort::Value encoder_out, Ort::Value encoder_out_length, |
| 23 | + OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) { | ||
| 21 | PackedSequence packed_encoder_out = PackPaddedSequence( | 24 | PackedSequence packed_encoder_out = PackPaddedSequence( |
| 22 | model_->Allocator(), &encoder_out, &encoder_out_length); | 25 | model_->Allocator(), &encoder_out, &encoder_out_length); |
| 23 | 26 | ||
| 24 | int32_t batch_size = | 27 | int32_t batch_size = |
| 25 | static_cast<int32_t>(packed_encoder_out.sorted_indexes.size()); | 28 | static_cast<int32_t>(packed_encoder_out.sorted_indexes.size()); |
| 26 | 29 | ||
| 30 | + if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n); | ||
| 31 | + | ||
| 27 | int32_t vocab_size = model_->VocabSize(); | 32 | int32_t vocab_size = model_->VocabSize(); |
| 28 | int32_t context_size = model_->ContextSize(); | 33 | int32_t context_size = model_->ContextSize(); |
| 29 | 34 | ||
| 30 | std::vector<int64_t> blanks(context_size, 0); | 35 | std::vector<int64_t> blanks(context_size, 0); |
| 31 | - Hypotheses blank_hyp({{blanks, 0}}); | ||
| 32 | 36 | ||
| 33 | std::deque<Hypotheses> finalized; | 37 | std::deque<Hypotheses> finalized; |
| 34 | - std::vector<Hypotheses> cur(batch_size, blank_hyp); | 38 | + std::vector<Hypotheses> cur; |
| 35 | std::vector<Hypothesis> prev; | 39 | std::vector<Hypothesis> prev; |
| 36 | 40 | ||
| 41 | + std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr); | ||
| 42 | + | ||
| 43 | + for (int32_t i = 0; i < batch_size; ++i) { | ||
|  | 44 | + const ContextState *context_state = nullptr; | ||
| 45 | + if (ss != nullptr) { | ||
| 46 | + context_graphs[i] = | ||
| 47 | + ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph(); | ||
| 48 | + if (context_graphs[i] != nullptr) | ||
| 49 | + context_state = context_graphs[i]->Root(); | ||
| 50 | + } | ||
| 51 | + Hypotheses blank_hyp({{blanks, 0, context_state}}); | ||
| 52 | + cur.emplace_back(std::move(blank_hyp)); | ||
| 53 | + } | ||
| 54 | + | ||
| 37 | int32_t start = 0; | 55 | int32_t start = 0; |
| 38 | int32_t t = 0; | 56 | int32_t t = 0; |
| 39 | for (auto n : packed_encoder_out.batch_sizes) { | 57 | for (auto n : packed_encoder_out.batch_sizes) { |
| @@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | @@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 106 | int32_t new_token = k % vocab_size; | 124 | int32_t new_token = k % vocab_size; |
| 107 | Hypothesis new_hyp = prev[hyp_index]; | 125 | Hypothesis new_hyp = prev[hyp_index]; |
| 108 | 126 | ||
| 127 | + float context_score = 0; | ||
| 128 | + auto context_state = new_hyp.context_state; | ||
| 109 | if (new_token != 0) { | 129 | if (new_token != 0) { |
| 110 | // blank id is fixed to 0 | 130 | // blank id is fixed to 0 |
| 111 | new_hyp.ys.push_back(new_token); | 131 | new_hyp.ys.push_back(new_token); |
| 112 | new_hyp.timestamps.push_back(t); | 132 | new_hyp.timestamps.push_back(t); |
| 133 | + if (context_graphs[i] != nullptr) { | ||
| 134 | + auto context_res = | ||
| 135 | + context_graphs[i]->ForwardOneStep(context_state, new_token); | ||
| 136 | + context_score = context_res.first; | ||
| 137 | + new_hyp.context_state = context_res.second; | ||
| 138 | + } | ||
| 113 | } | 139 | } |
| 114 | 140 | ||
| 115 | - new_hyp.log_prob = p_logprob[k]; | 141 | + new_hyp.log_prob = p_logprob[k] + context_score; |
| 116 | hyps.Add(std::move(new_hyp)); | 142 | hyps.Add(std::move(new_hyp)); |
| 117 | } // for (auto k : topk) | 143 | } // for (auto k : topk) |
| 118 | p_logprob += (end - start) * vocab_size; | 144 | p_logprob += (end - start) * vocab_size; |
| @@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | @@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 126 | cur.push_back(std::move(h)); | 152 | cur.push_back(std::move(h)); |
| 127 | } | 153 | } |
| 128 | 154 | ||
|  | 155 | + // Finalize context biasing matching. | ||
| 156 | + for (int32_t i = 0; i < cur.size(); ++i) { | ||
| 157 | + for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) { | ||
| 158 | + if (context_graphs[i] != nullptr) { | ||
| 159 | + auto context_res = | ||
| 160 | + context_graphs[i]->Finalize(iter->second.context_state); | ||
| 161 | + iter->second.log_prob += context_res.first; | ||
| 162 | + iter->second.context_state = context_res.second; | ||
| 163 | + } | ||
| 164 | + } | ||
| 165 | + } | ||
| 166 | + | ||
| 129 | if (lm_) { | 167 | if (lm_) { |
| 130 | // use LM for rescoring | 168 | // use LM for rescoring |
| 131 | lm_->ComputeLMScore(lm_scale_, context_size, &cur); | 169 | lm_->ComputeLMScore(lm_scale_, context_size, &cur); |
| @@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder | @@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder | ||
| 26 | lm_scale_(lm_scale) {} | 26 | lm_scale_(lm_scale) {} |
| 27 | 27 | ||
| 28 | std::vector<OfflineTransducerDecoderResult> Decode( | 28 | std::vector<OfflineTransducerDecoderResult> Decode( |
| 29 | - Ort::Value encoder_out, Ort::Value encoder_out_length) override; | 29 | + Ort::Value encoder_out, Ort::Value encoder_out_length, |
| 30 | + OfflineStream **ss = nullptr, int32_t n = 0) override; | ||
| 30 | 31 | ||
| 31 | private: | 32 | private: |
| 32 | OfflineTransducerModel *model_; // Not owned | 33 | OfflineTransducerModel *model_; // Not owned |
| @@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | @@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | ||
| 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") | 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") |
| 17 | .def(py::init<const OfflineFeatureExtractorConfig &, | 17 | .def(py::init<const OfflineFeatureExtractorConfig &, |
| 18 | const OfflineModelConfig &, const OfflineLMConfig &, | 18 | const OfflineModelConfig &, const OfflineLMConfig &, |
| 19 | - const std::string &, int32_t>(), | 19 | + const std::string &, int32_t, float>(), |
| 20 | py::arg("feat_config"), py::arg("model_config"), | 20 | py::arg("feat_config"), py::arg("model_config"), |
| 21 | py::arg("lm_config") = OfflineLMConfig(), | 21 | py::arg("lm_config") = OfflineLMConfig(), |
| 22 | py::arg("decoding_method") = "greedy_search", | 22 | py::arg("decoding_method") = "greedy_search", |
| 23 | - py::arg("max_active_paths") = 4) | 23 | + py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5) |
| 24 | .def_readwrite("feat_config", &PyClass::feat_config) | 24 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 25 | .def_readwrite("model_config", &PyClass::model_config) | 25 | .def_readwrite("model_config", &PyClass::model_config) |
| 26 | .def_readwrite("lm_config", &PyClass::lm_config) | 26 | .def_readwrite("lm_config", &PyClass::lm_config) |
| 27 | .def_readwrite("decoding_method", &PyClass::decoding_method) | 27 | .def_readwrite("decoding_method", &PyClass::decoding_method) |
| 28 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) | 28 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) |
| 29 | + .def_readwrite("context_score", &PyClass::context_score) | ||
| 29 | .def("__str__", &PyClass::ToString); | 30 | .def("__str__", &PyClass::ToString); |
| 30 | } | 31 | } |
| 31 | 32 | ||
| @@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) { | @@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) { | ||
| 35 | using PyClass = OfflineRecognizer; | 36 | using PyClass = OfflineRecognizer; |
| 36 | py::class_<PyClass>(*m, "OfflineRecognizer") | 37 | py::class_<PyClass>(*m, "OfflineRecognizer") |
| 37 | .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config")) | 38 | .def(py::init<const OfflineRecognizerConfig &>(), py::arg("config")) |
| 38 | - .def("create_stream", &PyClass::CreateStream) | 39 | + .def("create_stream", |
| 40 | + [](const PyClass &self) { return self.CreateStream(); }) | ||
| 41 | + .def( | ||
| 42 | + "create_stream", | ||
| 43 | + [](PyClass &self, | ||
| 44 | + const std::vector<std::vector<int32_t>> &contexts_list) { | ||
| 45 | + return self.CreateStream(contexts_list); | ||
| 46 | + }, | ||
| 47 | + py::arg("contexts_list")) | ||
| 39 | .def("decode_stream", &PyClass::DecodeStream) | 48 | .def("decode_stream", &PyClass::DecodeStream) |
| 40 | .def("decode_streams", | 49 | .def("decode_streams", |
| 41 | - [](PyClass &self, std::vector<OfflineStream *> ss) { | 50 | + [](const PyClass &self, std::vector<OfflineStream *> ss) { |
| 42 | self.DecodeStreams(ss.data(), ss.size()); | 51 | self.DecodeStreams(ss.data(), ss.size()); |
| 43 | }); | 52 | }); |
| 44 | } | 53 | } |
| 1 | +from typing import Dict, List, Optional | ||
| 2 | + | ||
| 1 | from _sherpa_onnx import Display | 3 | from _sherpa_onnx import Display |
| 2 | 4 | ||
| 3 | from .online_recognizer import OnlineRecognizer | 5 | from .online_recognizer import OnlineRecognizer |
| 4 | from .online_recognizer import OnlineStream | 6 | from .online_recognizer import OnlineStream |
| 5 | from .offline_recognizer import OfflineRecognizer | 7 | from .offline_recognizer import OfflineRecognizer |
| 8 | + | ||
| 9 | +from .utils import encode_contexts | ||
| 10 | + | ||
| 11 | + | ||
| 12 | + |
| 1 | # Copyright (c) 2023 by manyeyes | 1 | # Copyright (c) 2023 by manyeyes |
| 2 | from pathlib import Path | 2 | from pathlib import Path |
| 3 | -from typing import List | 3 | +from typing import List, Optional |
| 4 | 4 | ||
| 5 | from _sherpa_onnx import ( | 5 | from _sherpa_onnx import ( |
| 6 | OfflineFeatureExtractorConfig, | 6 | OfflineFeatureExtractorConfig, |
| @@ -39,6 +39,7 @@ class OfflineRecognizer(object): | @@ -39,6 +39,7 @@ class OfflineRecognizer(object): | ||
| 39 | sample_rate: int = 16000, | 39 | sample_rate: int = 16000, |
| 40 | feature_dim: int = 80, | 40 | feature_dim: int = 80, |
| 41 | decoding_method: str = "greedy_search", | 41 | decoding_method: str = "greedy_search", |
| 42 | + context_score: float = 1.5, | ||
| 42 | debug: bool = False, | 43 | debug: bool = False, |
| 43 | provider: str = "cpu", | 44 | provider: str = "cpu", |
| 44 | ): | 45 | ): |
| @@ -96,6 +97,7 @@ class OfflineRecognizer(object): | @@ -96,6 +97,7 @@ class OfflineRecognizer(object): | ||
| 96 | feat_config=feat_config, | 97 | feat_config=feat_config, |
| 97 | model_config=model_config, | 98 | model_config=model_config, |
| 98 | decoding_method=decoding_method, | 99 | decoding_method=decoding_method, |
| 100 | + context_score=context_score, | ||
| 99 | ) | 101 | ) |
| 100 | self.recognizer = _Recognizer(recognizer_config) | 102 | self.recognizer = _Recognizer(recognizer_config) |
| 101 | return self | 103 | return self |
| @@ -216,8 +218,11 @@ class OfflineRecognizer(object): | @@ -216,8 +218,11 @@ class OfflineRecognizer(object): | ||
| 216 | self.recognizer = _Recognizer(recognizer_config) | 218 | self.recognizer = _Recognizer(recognizer_config) |
| 217 | return self | 219 | return self |
| 218 | 220 | ||
| 219 | - def create_stream(self): | ||
| 220 | - return self.recognizer.create_stream() | 221 | + def create_stream(self, contexts_list: Optional[List[List[int]]] = None): |
| 222 | + if contexts_list is None: | ||
| 223 | + return self.recognizer.create_stream() | ||
| 224 | + else: | ||
| 225 | + return self.recognizer.create_stream(contexts_list) | ||
| 221 | 226 | ||
| 222 | def decode_stream(self, s: OfflineStream): | 227 | def decode_stream(self, s: OfflineStream): |
| 223 | self.recognizer.decode_stream(s) | 228 | self.recognizer.decode_stream(s) |
sherpa-onnx/python/sherpa_onnx/utils.py
0 → 100644
|  | 1 | +import re | ||
|  | 2 | +from typing import Dict, List, Optional | ||
| 3 | + | ||
| 4 | +def encode_contexts( | ||
| 5 | + modeling_unit: str, | ||
| 6 | + contexts: List[str], | ||
| 7 | + sp: Optional["SentencePieceProcessor"] = None, | ||
| 8 | + tokens_table: Optional[Dict[str, int]] = None, | ||
| 9 | +) -> List[List[int]]: | ||
| 10 | + """ | ||
|  | 11 | + Encode the given contexts (a list of strings) to a list of lists of token ids. | ||
| 12 | + | ||
| 13 | + Args: | ||
| 14 | + modeling_unit: | ||
| 15 | + The valid values are bpe, char, bpe+char. | ||
|  | 16 | + Note: char here means characters in CJK languages, not English-like languages. | ||
| 17 | + contexts: | ||
|  | 18 | + The given contexts list (a list of strings). | ||
| 19 | + sp: | ||
| 20 | + An instance of SentencePieceProcessor. | ||
| 21 | + tokens_table: | ||
| 22 | + The tokens_table containing the tokens and the corresponding ids. | ||
| 23 | + Returns: | ||
|  | 24 | + Return the contexts_list; it is a list of lists of token ids. | ||
| 25 | + """ | ||
| 26 | + contexts_list = [] | ||
| 27 | + if "bpe" in modeling_unit: | ||
| 28 | + assert sp is not None | ||
| 29 | + if "char" in modeling_unit: | ||
| 30 | + assert tokens_table is not None | ||
| 31 | + assert len(tokens_table) > 0, len(tokens_table) | ||
| 32 | + | ||
| 33 | + if "char" == modeling_unit: | ||
| 34 | + for context in contexts: | ||
| 35 | + assert ' ' not in context | ||
| 36 | + ids = [ | ||
| 37 | + tokens_table[txt] if txt in tokens_table else tokens_table["<unk>"] | ||
| 38 | + for txt in context | ||
| 39 | + ] | ||
| 40 | + contexts_list.append(ids) | ||
| 41 | + elif "bpe" == modeling_unit: | ||
| 42 | + contexts_list = sp.encode(contexts, out_type=int) | ||
| 43 | + else: | ||
| 44 | + assert modeling_unit == "bpe+char", modeling_unit | ||
| 45 | + | ||
| 46 | + # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: | ||
| 47 | + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | ||
| 48 | + pattern = re.compile(r"([\u4e00-\u9fff])") | ||
| 49 | + for context in contexts: | ||
| 50 | + # Example: | ||
|  | 51 | + # txt = "你好 IT'S OKAY 的" | ||
|  | 52 | + # chars = ["你", "好", " IT'S OKAY ", "的"] | ||
| 53 | + chars = pattern.split(context.upper()) | ||
| 54 | + mix_chars = [w for w in chars if len(w.strip()) > 0] | ||
| 55 | + ids = [] | ||
| 56 | + for ch_or_w in mix_chars: | ||
|  | 57 | + # ch_or_w is a single CJK character (e.g., "你"); look it up directly. | ||
| 58 | + if pattern.fullmatch(ch_or_w) is not None: | ||
| 59 | + ids.append( | ||
| 60 | + tokens_table[ch_or_w] | ||
| 61 | + if ch_or_w in tokens_table | ||
| 62 | + else tokens_table["<unk>"] | ||
| 63 | + ) | ||
|  | 64 | + # ch_or_w contains non-CJK characters (e.g., " IT'S OKAY "), | ||
| 65 | + # encode ch_or_w using bpe_model. | ||
| 66 | + else: | ||
| 67 | + for p in sp.encode_as_pieces(ch_or_w): | ||
| 68 | + ids.append( | ||
| 69 | + tokens_table[p] | ||
| 70 | + if p in tokens_table | ||
| 71 | + else tokens_table["<unk>"] | ||
| 72 | + ) | ||
| 73 | + contexts_list.append(ids) | ||
| 74 | + return contexts_list |
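As a quick sanity check, here is a hedged example of calling the new `encode_contexts` with a toy character-level tokens table; the symbols and ids below are made up for illustration (real tables come from a model's tokens.txt).

```python
from sherpa_onnx import encode_contexts

# Toy {symbol: id} table; out-of-vocabulary characters map to "<unk>".
tokens_table = {"<unk>": 0, "你": 1, "好": 2, "世": 3, "界": 4}

ids = encode_contexts(
    modeling_unit="char",
    contexts=["你好", "世界"],
    tokens_table=tokens_table,
)
print(ids)  # [[1, 2], [3, 4]]
```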