Wei Kang
Committed by GitHub

Support contextual-biasing for streaming model (#184)

* Support contextual-biasing for streaming model

* The whole pipeline runs normally

* Fix comments
... ... @@ -20,9 +20,10 @@ import argparse
import time
import wave
from pathlib import Path
from typing import Tuple
from typing import List, Tuple
import numpy as np
import sentencepiece as spm
import sherpa_onnx
... ... @@ -70,6 +71,59 @@ def get_args():
)
parser.add_argument(
"--max-active-paths",
type=int,
default=4,
help="""Used only when --decoding-method is modified_beam_search.
It specifies number of active paths to keep during decoding.
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="",
help="""
Path to bpe.model, it will be used to tokenize contexts biasing phrases.
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--modeling-unit",
type=str,
default="char",
help="""
The type of modeling unit, it will be used to tokenize contexts biasing phrases.
Valid values are bpe, bpe+char, char.
Note: the char here means characters in CJK languages.
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--contexts",
type=str,
default="",
help="""
The context list, it is a string containing some words/phrases separated
with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY".
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"--context-score",
type=float,
default=1.5,
help="""
The context score of each token for biasing word/phrase. Used only if
--contexts is given.
Used only when --decoding-method=modified_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
... ... @@ -116,6 +170,27 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
return samples_float32, f.getframerate()
def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
sp = None
if "bpe" in args.modeling_unit:
assert_file_exists(args.bpe_model)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
tokens = {}
with open(args.tokens, "r", encoding="utf-8") as f:
for line in f:
toks = line.strip().split()
assert len(toks) == 2, len(toks)
assert toks[0] not in tokens, f"Duplicate token: {toks} "
tokens[toks[0]] = int(toks[1])
return sherpa_onnx.encode_contexts(
modeling_unit=args.modeling_unit,
contexts=contexts,
sp=sp,
tokens_table=tokens,
)
def main():
args = get_args()
assert_file_exists(args.encoder)
... ... @@ -132,11 +207,20 @@ def main():
sample_rate=16000,
feature_dim=80,
decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths,
context_score=args.context_score,
)
print("Started!")
start_time = time.time()
contexts_list = []
contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
if contexts:
print(f"Contexts list: {contexts}")
contexts_list = encode_contexts(args, contexts)
streams = []
total_duration = 0
for wave_filename in args.sound_files:
... ... @@ -145,7 +229,11 @@ def main():
duration = len(samples) / sample_rate
total_duration += duration
s = recognizer.create_stream()
if contexts_list:
s = recognizer.create_stream(contexts_list=contexts_list)
else:
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
... ...
... ... @@ -88,6 +88,9 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"True to enable endpoint detection. False to disable it.");
po->Register("max-active-paths", &max_active_paths,
"beam size used in modified beam search.");
po->Register("context-score", &context_score,
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
... ... @@ -115,6 +118,7 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "context_score=" << context_score << ", ";
os << "decoding_method=\"" << decoding_method << "\")";
return os.str();
... ... @@ -166,10 +170,37 @@ class OnlineRecognizer::Impl {
}
#endif
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
if (config_.decoding_method == "modified_beam_search" &&
nullptr != stream->GetContextGraph()) {
// r.hyps has only one element.
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
it->second.context_state = stream->GetContextGraph()->Root();
}
}
stream->SetResult(r);
stream->SetStates(model_->GetEncoderInitStates());
}
std::unique_ptr<OnlineStream> CreateStream() const {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
stream->SetResult(decoder_->GetEmptyResult());
stream->SetStates(model_->GetEncoderInitStates());
InitOnlineStream(stream.get());
return stream;
}
std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &contexts) const {
// We create context_graph at this level, because we might have default
// context_graph(will be added later if needed) that belongs to the whole
// model rather than each stream.
auto context_graph =
std::make_shared<ContextGraph>(contexts, config_.context_score);
auto stream =
std::make_unique<OnlineStream>(config_.feat_config, context_graph);
InitOnlineStream(stream.get());
return stream;
}
... ... @@ -188,8 +219,12 @@ class OnlineRecognizer::Impl {
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<int64_t> all_processed_frames(n);
bool has_context_graph = false;
for (int32_t i = 0; i != n; ++i) {
if (!has_context_graph && ss[i]->GetContextGraph())
has_context_graph = true;
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
ss[i]->GetFrames(num_processed_frames, chunk_size);
... ... @@ -226,7 +261,11 @@ class OnlineRecognizer::Impl {
auto pair = model_->RunEncoder(std::move(x), std::move(states),
std::move(processed_frames));
decoder_->Decode(std::move(pair.first), &results);
if (has_context_graph) {
decoder_->Decode(std::move(pair.first), ss, &results);
} else {
decoder_->Decode(std::move(pair.first), &results);
}
std::vector<std::vector<Ort::Value>> next_states =
model_->UnStackStates(pair.second);
... ... @@ -297,6 +336,11 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
return impl_->CreateStream();
}
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
}
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
... ...
... ... @@ -75,7 +75,10 @@ struct OnlineRecognizerConfig {
std::string decoding_method = "greedy_search";
// now support modified_beam_search and greedy_search
int32_t max_active_paths = 4; // used only for modified_beam_search
// used only for modified_beam_search
int32_t max_active_paths = 4;
/// used only for modified_beam_search
float context_score = 1.5;
OnlineRecognizerConfig() = default;
... ... @@ -85,13 +88,14 @@ struct OnlineRecognizerConfig {
const EndpointConfig &endpoint_config,
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths)
int32_t max_active_paths, float context_score)
: feat_config(feat_config),
model_config(model_config),
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths) {}
max_active_paths(max_active_paths),
context_score(context_score) {}
void Register(ParseOptions *po);
bool Validate() const;
... ... @@ -112,6 +116,10 @@ class OnlineRecognizer {
/// Create a stream for decoding.
std::unique_ptr<OnlineStream> CreateStream() const;
// Create a stream with context phrases
std::unique_ptr<OnlineStream> CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const;
/**
* Return true if the given stream has enough frames for decoding.
* Return false otherwise
... ...
... ... @@ -13,8 +13,9 @@ namespace sherpa_onnx {
class OnlineStream::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config)
: feat_extractor_(config) {}
explicit Impl(const FeatureExtractorConfig &config,
ContextGraphPtr context_graph)
: feat_extractor_(config), context_graph_(context_graph) {}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
... ... @@ -54,16 +55,21 @@ class OnlineStream::Impl {
std::vector<Ort::Value> &GetStates() { return states_; }
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
private:
FeatureExtractor feat_extractor_;
/// For contextual-biasing
ContextGraphPtr context_graph_;
int32_t num_processed_frames_ = 0; // before subsampling
int32_t start_frame_index_ = 0; // never reset
OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_;
};
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
: impl_(std::make_unique<Impl>(config)) {}
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
ContextGraphPtr context_graph /*= nullptr */)
: impl_(std::make_unique<Impl>(config, context_graph)) {}
OnlineStream::~OnlineStream() = default;
... ... @@ -109,4 +115,8 @@ std::vector<Ort::Value> &OnlineStream::GetStates() {
return impl_->GetStates();
}
const ContextGraphPtr &OnlineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
} // namespace sherpa_onnx
... ...
... ... @@ -9,6 +9,7 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
... ... @@ -16,7 +17,8 @@ namespace sherpa_onnx {
class OnlineStream {
public:
explicit OnlineStream(const FeatureExtractorConfig &config = {});
explicit OnlineStream(const FeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = nullptr);
~OnlineStream();
/**
... ... @@ -71,6 +73,13 @@ class OnlineStream {
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();
/**
* Get the context graph corresponding to this stream.
*
* @return Return the context graph for this stream.
*/
const ContextGraphPtr &GetContextGraph() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
... ... @@ -9,6 +9,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
... ... @@ -45,6 +46,7 @@ struct OnlineTransducerDecoderResult {
OnlineTransducerDecoderResult &&other);
};
class OnlineStream;
class OnlineTransducerDecoder {
public:
virtual ~OnlineTransducerDecoder() = default;
... ... @@ -76,6 +78,26 @@ class OnlineTransducerDecoder {
virtual void Decode(Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) = 0;
/** Run transducer beam search given the output from the encoder model.
*
* Note: Currently this interface is for contextual-biasing feature which
* needs a ContextGraph owned by the OnlineStream.
*
* @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
* @param ss A list of OnlineStreams.
* @param result It is modified in-place.
*
* @note There is no need to pass encoder_out_length here since for the
* online decoding case, each utterance has the same number of frames
* and there are no paddings.
*/
virtual void Decode(Ort::Value encoder_out, OnlineStream **ss,
std::vector<OnlineTransducerDecoderResult> *result) {
SHERPA_ONNX_LOGE(
"This interface is for OnlineTransducerModifiedBeamSearchDecoder.");
exit(-1);
}
// used for endpointing. We need to keep decoder_out after reset
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
};
... ...
... ... @@ -9,6 +9,7 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
... ... @@ -62,6 +63,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) {
Decode(std::move(encoder_out), nullptr, result);
}
void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value encoder_out, OnlineStream **ss,
std::vector<OnlineTransducerDecoderResult> *result) {
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
... ... @@ -74,6 +81,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
}
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
int32_t vocab_size = model_->VocabSize();
... ... @@ -142,18 +150,27 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Hypothesis new_hyp = prev[hyp_index];
const float prev_lm_log_prob = new_hyp.lm_log_prob;
float context_score = 0;
auto context_state = new_hyp.context_state;
if (new_token != 0) {
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset);
new_hyp.num_trailing_blanks = 0;
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
}
if (lm_) {
lm_->ComputeLMScore(lm_scale_, &new_hyp);
}
} else {
++new_hyp.num_trailing_blanks;
}
new_hyp.log_prob =
p_logprob[k] - prev_lm_log_prob; // log_prob only includes the
new_hyp.log_prob = p_logprob[k] + context_score -
prev_lm_log_prob; // log_prob only includes the
// score of the transducer
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
... ...
... ... @@ -9,6 +9,7 @@
#include <vector>
#include "sherpa-onnx/csrc/online-lm.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
... ... @@ -33,6 +34,9 @@ class OnlineTransducerModifiedBeamSearchDecoder
void Decode(Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) override;
void Decode(Ort::Value encoder_out, OnlineStream **ss,
std::vector<OnlineTransducerDecoderResult> *result) override;
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
private:
... ...
... ... @@ -22,18 +22,19 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const OnlineLMConfig &,
const EndpointConfig &, bool, const std::string &,
int32_t>(),
const EndpointConfig &, bool, const std::string &, int32_t,
float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths"))
py::arg("max_active_paths"), py::arg("context_score"))
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("context_score", &PyClass::context_score)
.def("__str__", &PyClass::ToString);
}
... ... @@ -44,7 +45,15 @@ void PybindOnlineRecognizer(py::module *m) {
using PyClass = OnlineRecognizer;
py::class_<PyClass>(*m, "OnlineRecognizer")
.def(py::init<const OnlineRecognizerConfig &>(), py::arg("config"))
.def("create_stream", &PyClass::CreateStream)
.def("create_stream",
[](const PyClass &self) { return self.CreateStream(); })
.def(
"create_stream",
[](PyClass &self,
const std::vector<std::vector<int32_t>> &contexts_list) {
return self.CreateStream(contexts_list);
},
py::arg("contexts_list"))
.def("is_ready", &PyClass::IsReady)
.def("decode_stream", &PyClass::DecodeStream)
.def("decode_streams",
... ...
# Copyright (c) 2023 Xiaomi Corporation
from pathlib import Path
from typing import List
from typing import List, Optional
from _sherpa_onnx import (
EndpointConfig,
... ... @@ -39,6 +39,7 @@ class OnlineRecognizer(object):
rule3_min_utterance_length: float = 20.0,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
context_score: float = 1.5,
provider: str = "cpu",
):
"""
... ... @@ -124,13 +125,17 @@ class OnlineRecognizer(object):
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
max_active_paths=max_active_paths,
context_score=context_score,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
def create_stream(self):
return self.recognizer.create_stream()
def create_stream(self, contexts_list : Optional[List[List[int]]] = None):
if contexts_list is None:
return self.recognizer.create_stream()
else:
return self.recognizer.create_stream(contexts_list)
def decode_stream(self, s: OnlineStream):
self.recognizer.decode_stream(s)
... ...