Wei Kang
Committed by GitHub

Implement context biasing with a Aho Corasick automata (#145)

* Implement context graph

* Modify the interface to support context biasing

* Support context biasing in modified beam search; add python wrapper

* Support context biasing in python api example

* Minor fixes

* Fix context graph

* Minor fixes

* Fix tests

* Fix style

* Fix style

* Fix comments

* Minor fixes

* Add missing header

* Replace std::shared_ptr with std::unique_ptr for effciency

* Build graph in constructor

* Fix comments

* Minor fixes

* Fix docs
... ... @@ -54,7 +54,7 @@ jobs:
- name: Install Python dependencies
shell: bash
run: |
python3 -m pip install --upgrade pip numpy
python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96
- name: Install sherpa-onnx
shell: bash
... ...
... ... @@ -43,9 +43,10 @@ import argparse
import time
import wave
from pathlib import Path
from typing import Tuple
from typing import List, Tuple
import numpy as np
import sentencepiece as spm
import sherpa_onnx
... ... @@ -61,6 +62,47 @@ def get_args():
)
parser.add_argument(
"--bpe-model",
type=str,
default="",
help="""
Path to bpe.model,
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--modeling-unit",
type=str,
default="char",
help="""
The type of modeling unit.
Valid values are bpe, bpe+char, char.
Note: the char here means characters in CJK languages.
""",
)
parser.add_argument(
"--contexts",
type=str,
default="",
help="""
The context list, it is a string containing some words/phrases separated
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
""",
)
parser.add_argument(
"--context-score",
type=float,
default=1.5,
help="""
The context score of each token for biasing word/phrase. Used only if
--contexts is given.
""",
)
parser.add_argument(
"--encoder",
default="",
type=str,
... ... @@ -153,6 +195,24 @@ def assert_file_exists(filename: str):
)
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
sp = None
if "bpe" in args.modeling_unit:
assert_file_exists(args.bpe_model)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
tokens = {}
with open(args.tokens, "r", encoding="utf-8") as f:
for line in f:
toks = line.strip().split()
assert len(toks) == 2, len(toks)
assert toks[0] not in tokens, f"Duplicate token: {toks} "
tokens[toks[0]] = int(toks[1])
return sherpa_onnx.encode_contexts(
modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
... ... @@ -182,10 +242,17 @@ def main():
args = get_args()
assert_file_exists(args.tokens)
assert args.num_threads > 0, args.num_threads
contexts_list = []
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
print(f"Contexts list: {contexts}")
contexts_list = encode_contexts(args, contexts)
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
... ... @@ -199,6 +266,7 @@ def main():
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
context_score=args.context_score,
debug=args.debug,
)
elif args.paraformer:
... ... @@ -238,8 +306,12 @@ def main():
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
total_duration += duration
s = recognizer.create_stream()
if contexts_list:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
s = recognizer.create_stream(contexts_list=contexts_list)
else:
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
streams.append(s)
... ...
... ... @@ -37,6 +37,7 @@ with open("sherpa-onnx/python/sherpa_onnx/__init__.py", "a") as f:
install_requires = [
"numpy",
"sentencepiece==0.1.96",
]
... ...
... ... @@ -12,6 +12,7 @@ endif()
set(sources
cat.cc
context-graph.cc
endpoint.cc
features.cc
file-utils.cc
... ... @@ -248,6 +249,7 @@ endif()
if(SHERPA_ONNX_ENABLE_TESTS)
set(sherpa_onnx_test_srcs
cat-test.cc
context-graph-test.cc
packed-sequence-test.cc
pad-sequence-test.cc
slice-test.cc
... ...
// sherpa-onnx/csrc/context-graph-test.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/context-graph.h"
#include <map>
#include <string>
#include <vector>
#include "gtest/gtest.h"
namespace sherpa_onnx {
TEST(ContextGraph, TestBasic) {
std::vector<std::string> contexts_str(
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
std::vector<std::vector<int32_t>> contexts;
for (int32_t i = 0; i < contexts_str.size(); ++i) {
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
}
auto context_graph = ContextGraph(contexts, 1);
auto queries = std::map<std::string, float>{
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, {"SHED", 6},
{"HELL", 2}, {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
for (const auto &iter : queries) {
float total_scores = 0;
auto state = context_graph.Root();
for (auto q : iter.first) {
auto res = context_graph.ForwardOneStep(state, q);
total_scores += res.first;
state = res.second;
}
auto res = context_graph.Finalize(state);
EXPECT_EQ(res.second->token, -1);
total_scores += res.first;
EXPECT_EQ(total_scores, iter.second);
}
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/context-graph.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/context-graph.h"
#include <cassert>
#include <queue>
#include <utility>
namespace sherpa_onnx {
void ContextGraph::Build(
const std::vector<std::vector<int32_t>> &token_ids) const {
for (int32_t i = 0; i < token_ids.size(); ++i) {
auto node = root_.get();
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
int32_t token = token_ids[i][j];
if (0 == node->next.count(token)) {
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? 0 : node->local_node_score + context_score_, is_end);
}
node = node->next[token].get();
}
}
FillFailOutput();
}
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
const ContextState *state, int32_t token) const {
const ContextState *node;
float score;
if (1 == state->next.count(token)) {
node = state->next.at(token).get();
score = node->token_score;
if (state->is_end) score += state->node_score;
} else {
node = state->fail;
while (0 == node->next.count(token)) {
node = node->fail;
if (-1 == node->token) break; // root
}
if (1 == node->next.count(token)) {
node = node->next.at(token).get();
}
score = node->node_score - state->local_node_score;
}
SHERPA_ONNX_CHECK(nullptr != node);
float matched_score = 0;
auto output = node->output;
while (nullptr != output) {
matched_score += output->node_score;
output = output->output;
}
return std::make_pair(score + matched_score, node);
}
std::pair<float, const ContextState *> ContextGraph::Finalize(
const ContextState *state) const {
float score = -state->node_score;
if (state->is_end) {
score = 0;
}
return std::make_pair(score, root_.get());
}
void ContextGraph::FillFailOutput() const {
std::queue<const ContextState *> node_queue;
for (auto &kv : root_->next) {
kv.second->fail = root_.get();
node_queue.push(kv.second.get());
}
while (!node_queue.empty()) {
auto current_node = node_queue.front();
node_queue.pop();
for (auto &kv : current_node->next) {
auto fail = current_node->fail;
if (1 == fail->next.count(kv.first)) {
fail = fail->next.at(kv.first).get();
} else {
fail = fail->fail;
while (0 == fail->next.count(kv.first)) {
fail = fail->fail;
if (-1 == fail->token) break;
}
if (1 == fail->next.count(kv.first))
fail = fail->next.at(kv.first).get();
}
kv.second->fail = fail;
// fill the output arc
auto output = fail;
while (!output->is_end) {
output = output->fail;
if (-1 == output->token) {
output = nullptr;
break;
}
}
kv.second->output = output;
node_queue.push(kv.second.get());
}
}
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/context-graph.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
namespace sherpa_onnx {
class ContextGraph;
using ContextGraphPtr = std::shared_ptr<ContextGraph>;
struct ContextState {
int32_t token;
float token_score;
float node_score;
float local_node_score;
bool is_end;
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
const ContextState *fail = nullptr;
const ContextState *output = nullptr;
ContextState() = default;
ContextState(int32_t token, float token_score, float node_score,
float local_node_score, bool is_end)
: token(token),
token_score(token_score),
node_score(node_score),
local_node_score(local_node_score),
is_end(is_end) {}
};
class ContextGraph {
public:
ContextGraph() = default;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float context_score)
: context_score_(context_score) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
root_->fail = root_.get();
Build(token_ids);
}
std::pair<float, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id) const;
std::pair<float, const ContextState *> Finalize(
const ContextState *state) const;
const ContextState *Root() const { return root_.get(); }
private:
float context_score_;
std::unique_ptr<ContextState> root_;
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
void FillFailOutput() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
... ...
... ... @@ -14,6 +14,7 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
... ... @@ -39,11 +40,18 @@ struct Hypothesis {
// the nn lm states
std::vector<CopyableOrtValue> nn_lm_states;
const ContextState *context_state;
// TODO(fangjun): Make it configurable
// the minimum of tokens in a chunk for streaming RNN LM
int32_t lm_rescore_min_chunk = 2; // a const
int32_t num_trailing_blanks = 0;
Hypothesis() = default;
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
: ys(ys), log_prob(log_prob) {}
Hypothesis(const std::vector<int64_t> &ys, double log_prob,
const ContextState *context_state = nullptr)
: ys(ys), log_prob(log_prob), context_state(context_state) {}
double TotalLogProb() const { return log_prob + lm_log_prob; }
... ...
... ... @@ -6,7 +6,9 @@
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_
#include <memory>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
... ... @@ -19,6 +21,12 @@ class OfflineRecognizerImpl {
virtual ~OfflineRecognizerImpl() = default;
virtual std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
exit(-1);
}
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
... ...
... ... @@ -10,6 +10,7 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
... ... @@ -72,6 +73,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
}
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const override {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
auto context_graph =
std::make_shared<ContextGraph>(context_list, config_.context_score);
return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
}
... ... @@ -117,7 +128,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
-23.025850929940457f);
auto t = model_->RunEncoder(std::move(x), std::move(x_length));
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
auto results =
decoder_->Decode(std::move(t.first), std::move(t.second), ss, n);
int32_t frame_shift_ms = 10;
for (int32_t i = 0; i != n; ++i) {
... ...
... ... @@ -26,6 +26,9 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register("max-active-paths", &max_active_paths,
"Used only when decoding_method is modified_beam_search");
po->Register("context-score", &context_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
}
bool OfflineRecognizerConfig::Validate() const {
... ... @@ -49,7 +52,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "model_config=" << model_config.ToString() << ", ";
os << "lm_config=" << lm_config.ToString() << ", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "max_active_paths=" << max_active_paths << ")";
os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ")";
return os.str();
}
... ... @@ -59,6 +63,11 @@ OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
OfflineRecognizer::~OfflineRecognizer() = default;
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
}
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
return impl_->CreateStream();
}
... ...
... ... @@ -26,6 +26,7 @@ struct OfflineRecognizerConfig {
std::string decoding_method = "greedy_search";
int32_t max_active_paths = 4;
float context_score = 1.5;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
... ... @@ -34,12 +35,13 @@ struct OfflineRecognizerConfig {
const OfflineModelConfig &model_config,
const OfflineLMConfig &lm_config,
const std::string &decoding_method,
int32_t max_active_paths)
int32_t max_active_paths, float context_score)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
decoding_method(decoding_method),
max_active_paths(max_active_paths) {}
max_active_paths(max_active_paths),
context_score(context_score) {}
void Register(ParseOptions *po);
bool Validate() const;
... ... @@ -58,6 +60,10 @@ class OfflineRecognizer {
/// Create a stream for decoding.
std::unique_ptr<OfflineStream> CreateStream() const;
/// Create a stream for decoding.
std::unique_ptr<OfflineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const;
/** Decode a single stream
*
* @param s The stream to decode.
... ...
... ... @@ -75,7 +75,9 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
class OfflineStream::Impl {
public:
explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
explicit Impl(const OfflineFeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: config_(config), context_graph_(context_graph) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
... ... @@ -152,6 +154,8 @@ class OfflineStream::Impl {
const OfflineRecognitionResult &GetResult() const { return r_; }
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
private:
void NemoNormalizeFeatures(float *p, int32_t num_frames,
int32_t feature_dim) const {
... ... @@ -189,11 +193,13 @@ class OfflineStream::Impl {
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
};
OfflineStream::OfflineStream(
const OfflineFeatureExtractorConfig &config /*= {}*/)
: impl_(std::make_unique<Impl>(config)) {}
const OfflineFeatureExtractorConfig &config /*= {}*/,
ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::~OfflineStream() = default;
... ... @@ -212,6 +218,10 @@ void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
impl_->SetResult(r);
}
const ContextGraphPtr &OfflineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
const OfflineRecognitionResult &OfflineStream::GetResult() const {
return impl_->GetResult();
}
... ...
... ... @@ -10,6 +10,7 @@
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
... ... @@ -66,7 +67,8 @@ struct OfflineFeatureExtractorConfig {
class OfflineStream {
public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {});
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = nullptr);
~OfflineStream();
/**
... ... @@ -96,6 +98,9 @@ class OfflineStream {
/** Get the recognition result of this stream */
const OfflineRecognitionResult &GetResult() const;
/** Get the ContextGraph of this stream */
const ContextGraphPtr &GetContextGraph() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
... ... @@ -8,6 +8,7 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-stream.h"
namespace sherpa_onnx {
... ... @@ -33,7 +34,8 @@ class OfflineTransducerDecoder {
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss = nullptr, int32_t n = 0) = 0;
};
} // namespace sherpa_onnx
... ...
... ... @@ -16,7 +16,9 @@ namespace sherpa_onnx {
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
Ort::Value encoder_out_length) {
Ort::Value encoder_out_length,
OfflineStream **ss /*= nullptr*/,
int32_t n /*= 0*/) {
PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length);
... ...
... ... @@ -18,7 +18,8 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
: model_(model) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss = nullptr, int32_t n = 0) override;
private:
OfflineTransducerModel *model_; // Not owned
... ...
... ... @@ -8,7 +8,9 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/packed-sequence.h"
#include "sherpa-onnx/csrc/slice.h"
... ... @@ -17,23 +19,39 @@ namespace sherpa_onnx {
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) {
Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) {
PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length);
int32_t batch_size =
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n);
int32_t vocab_size = model_->VocabSize();
int32_t context_size = model_->ContextSize();
std::vector<int64_t> blanks(context_size, 0);
Hypotheses blank_hyp({{blanks, 0}});
std::deque<Hypotheses> finalized;
std::vector<Hypotheses> cur(batch_size, blank_hyp);
std::vector<Hypotheses> cur;
std::vector<Hypothesis> prev;
std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr);
for (int32_t i = 0; i < batch_size; ++i) {
const ContextState *context_state;
if (ss != nullptr) {
context_graphs[i] =
ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph();
if (context_graphs[i] != nullptr)
context_state = context_graphs[i]->Root();
}
Hypotheses blank_hyp({{blanks, 0, context_state}});
cur.emplace_back(std::move(blank_hyp));
}
int32_t start = 0;
int32_t t = 0;
for (auto n : packed_encoder_out.batch_sizes) {
... ... @@ -106,13 +124,21 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
int32_t new_token = k % vocab_size;
Hypothesis new_hyp = prev[hyp_index];
float context_score = 0;
auto context_state = new_hyp.context_state;
if (new_token != 0) {
// blank id is fixed to 0
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t);
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
}
}
new_hyp.log_prob = p_logprob[k];
new_hyp.log_prob = p_logprob[k] + context_score;
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
p_logprob += (end - start) * vocab_size;
... ... @@ -126,6 +152,18 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
cur.push_back(std::move(h));
}
// Finalize context biasing matching..
for (int32_t i = 0; i < cur.size(); ++i) {
for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) {
if (context_graphs[i] != nullptr) {
auto context_res =
context_graphs[i]->Finalize(iter->second.context_state);
iter->second.log_prob += context_res.first;
iter->second.context_state = context_res.second;
}
}
}
if (lm_) {
// use LM for rescoring
lm_->ComputeLMScore(lm_scale_, context_size, &cur);
... ...
... ... @@ -26,7 +26,8 @@ class OfflineTransducerModifiedBeamSearchDecoder
lm_scale_(lm_scale) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
Ort::Value encoder_out, Ort::Value encoder_out_length,
OfflineStream **ss = nullptr, int32_t n = 0) override;
private:
OfflineTransducerModel *model_; // Not owned
... ...
... ... @@ -16,16 +16,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
.def(py::init<const OfflineFeatureExtractorConfig &,
const OfflineModelConfig &, const OfflineLMConfig &,
const std::string &, int32_t>(),
const std::string &, int32_t, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4)
py::arg("max_active_paths") = 4, py::arg("context_score") = 1.5)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("context_score", &PyClass::context_score)
.def("__str__", &PyClass::ToString);
}
... ... @@ -35,10 +36,18 @@ void PybindOfflineRecognizer(py::module *m) {
using PyClass = OfflineRecognizer;
py::class_<PyClass>(*m, "OfflineRecognizer")
.def(py::init<const OfflineRecognizerConfig &>(), py::arg("config"))
.def("create_stream", &PyClass::CreateStream)
.def("create_stream",
[](const PyClass &self) { return self.CreateStream(); })
.def(
"create_stream",
[](PyClass &self,
const std::vector<std::vector<int32_t>> &contexts_list) {
return self.CreateStream(contexts_list);
},
py::arg("contexts_list"))
.def("decode_stream", &PyClass::DecodeStream)
.def("decode_streams",
[](PyClass &self, std::vector<OfflineStream *> ss) {
[](const PyClass &self, std::vector<OfflineStream *> ss) {
self.DecodeStreams(ss.data(), ss.size());
});
}
... ...
from typing import Dict, List, Optional
from _sherpa_onnx import Display
from .online_recognizer import OnlineRecognizer
from .online_recognizer import OnlineStream
from .offline_recognizer import OfflineRecognizer
from .utils import encode_contexts
... ...
# Copyright (c) 2023 by manyeyes
from pathlib import Path
from typing import List
from typing import List, Optional
from _sherpa_onnx import (
OfflineFeatureExtractorConfig,
... ... @@ -39,6 +39,7 @@ class OfflineRecognizer(object):
sample_rate: int = 16000,
feature_dim: int = 80,
decoding_method: str = "greedy_search",
context_score: float = 1.5,
debug: bool = False,
provider: str = "cpu",
):
... ... @@ -96,6 +97,7 @@ class OfflineRecognizer(object):
feat_config=feat_config,
model_config=model_config,
decoding_method=decoding_method,
context_score=context_score,
)
self.recognizer = _Recognizer(recognizer_config)
return self
... ... @@ -216,8 +218,11 @@ class OfflineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config)
return self
def create_stream(self):
return self.recognizer.create_stream()
def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()
else:
return self.recognizer.create_stream(contexts_list)
def decode_stream(self, s: OfflineStream):
self.recognizer.decode_stream(s)
... ...
from typing import Dict, List, Optional
def encode_contexts(
modeling_unit: str,
contexts: List[str],
sp: Optional["SentencePieceProcessor"] = None,
tokens_table: Optional[Dict[str, int]] = None,
) -> List[List[int]]:
"""
Encode the given contexts (a list of string) to a list of a list of token ids.
Args:
modeling_unit:
The valid values are bpe, char, bpe+char.
Note: char here means characters in CJK languages, not English like languages.
contexts:
The given contexts list (a list of string).
sp:
An instance of SentencePieceProcessor.
tokens_table:
The tokens_table containing the tokens and the corresponding ids.
Returns:
Return the contexts_list, it is a list of a list of token ids.
"""
contexts_list = []
if "bpe" in modeling_unit:
assert sp is not None
if "char" in modeling_unit:
assert tokens_table is not None
assert len(tokens_table) > 0, len(tokens_table)
if "char" == modeling_unit:
for context in contexts:
assert ' ' not in context
ids = [
tokens_table[txt] if txt in tokens_table else tokens_table["<unk>"]
for txt in context
]
contexts_list.append(ids)
elif "bpe" == modeling_unit:
contexts_list = sp.encode(contexts, out_type=int)
else:
assert modeling_unit == "bpe+char", modeling_unit
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
pattern = re.compile(r"([\u4e00-\u9fff])")
for context in contexts:
# Example:
# txt = "你好 ITS'S OKAY 的"
# chars = ["你", "好", " ITS'S OKAY ", "的"]
chars = pattern.split(context.upper())
mix_chars = [w for w in chars if len(w.strip()) > 0]
ids = []
for ch_or_w in mix_chars:
# ch_or_w is a single CJK charater(i.e., "你"), do nothing.
if pattern.fullmatch(ch_or_w) is not None:
ids.append(
tokens_table[ch_or_w]
if ch_or_w in tokens_table
else tokens_table["<unk>"]
)
# ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "),
# encode ch_or_w using bpe_model.
else:
for p in sp.encode_as_pieces(ch_or_w):
ids.append(
tokens_table[p]
if p in tokens_table
else tokens_table["<unk>"]
)
contexts_list.append(ids)
return contexts_list
... ...