Askars Salimbajevs
Committed by GitHub

Add LODR support to online and offline recognizers (#2026)

This PR integrates LODR (Low-Order Density Ratio) support from Icefall into both the online and offline recognizers, enabling LODR during LM shallow fusion as well as during LM rescoring.

- Extended OnlineLMConfig and OfflineLMConfig to include lodr_fst, lodr_scale, and lodr_backoff_id.
- Implemented LodrFst and LodrStateCost classes and wired them into RNN LM scoring in both online and offline code paths.
- Updated Python bindings, CLI entry points, examples, and CI test scripts to accept and exercise the new LODR options.
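
LODR subtracts a low-order (here bi-gram) token-level estimate of the transducer's internal LM while the external LM is applied, which is why the examples below pass a negative --lodr-scale. A minimal sketch of the new options through the Python API follows; the from_transducer constructor and the file names are assumptions based on the existing sherpa-onnx examples, not part of this diff:

import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
    tokens="tokens.txt",
    encoder="encoder-epoch-99-avg-1.onnx",
    decoder="decoder-epoch-99-avg-1.onnx",
    joiner="joiner-epoch-99-avg-1.onnx",
    decoding_method="modified_beam_search",
    lm="exp/no-state-epoch-99-avg-1.onnx",  # RNN LM used for rescoring
    lm_scale=0.1,
    lodr_fst="2gram.fst",  # token-level bi-gram FST
    lodr_scale=-0.5,       # negative: the bi-gram estimate is subtracted
)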
... ... @@ -281,7 +281,39 @@ time $EXE \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav
lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
lm_repo=$(basename $lm_repo_url)
pushd $lm_repo
git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx"
popd
bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
log "Download bi-gram LM from ${bigram_repo_url}"
GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
bigramlm_repo=$(basename $bigram_repo_url)
pushd $bigramlm_repo
git lfs pull --include "2gram.fst"
popd
log "Start testing with LM and bi-gram LODR"
# TODO: find test examples that change with the LODR
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
--num-threads=2 \
--decoding-method=modified_beam_search \
--lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \
--lodr-fst=$bigramlm_repo/2gram.fst \
--lodr-scale=-0.5 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav
rm -rf $repo $lm_repo $bigramlm_repo
log "------------------------------------------------------------"
log "Run Paraformer (Chinese)"
... ...
... ... @@ -174,7 +174,60 @@ for wave in ${waves[@]}; do
$wave
done
lm_repo_url=https://huggingface.co/vsd-vector/icefall-librispeech-rnn-lm
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
lm_repo=$(basename $lm_repo_url)
pushd $lm_repo
git lfs pull --include "with-state-epoch-99-avg-1.onnx"
popd
bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
log "Download bi-gram LM from ${bigram_repo_url}"
GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
bigramlm_repo=$(basename $bigram_repo_url)
pushd $bigramlm_repo
git lfs pull --include "2gram.fst"
popd
log "Start testing LODR"
waves=(
$repo/test_wavs/0.wav
$repo/test_wavs/1.wav
$repo/test_wavs/8k.wav
)
for wave in ${waves[@]}; do
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
--num-threads=2 \
--decoding-method=modified_beam_search \
--lm=$lm_repo/with-state-epoch-99-avg-1.onnx \
--lodr-fst=$bigramlm_repo/2gram.fst \
--lodr-scale=-0.5 \
$wave
done
for wave in ${waves[@]}; do
time $EXE \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
--num-threads=2 \
--decoding-method=modified_beam_search \
--lm=$lm_repo/with-state-epoch-99-avg-1.onnx \
--lodr-fst=$bigramlm_repo/2gram.fst \
--lodr-scale=-0.5 \
--lm-shallow-fusion=true \
$wave
done
rm -rf $repo $bigramlm_repo $lm_repo
log "------------------------------------------------------------"
log "Run streaming Zipformer transducer (Bilingual, Chinese + English)"
... ...
... ... @@ -562,9 +562,39 @@ python3 ./python-api-examples/offline-decode-files.py \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav
lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
log "Download pre-trained RNN-LM model from ${lm_repo_url}"
GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
lm_repo=$(basename $lm_repo_url)
pushd $lm_repo
git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx"
popd
bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
log "Download bi-gram LM from ${bigram_repo_url}"
GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
bigramlm_repo=$(basename $bigram_repo_url)
pushd $bigramlm_repo
git lfs pull --include "2gram.fst"
popd
log "Perform offline decoding with RNN-LM and LODR"
python3 ./python-api-examples/offline-decode-files.py \
--tokens=$repo/tokens.txt \
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
--decoding-method=modified_beam_search \
--lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \
--lodr-fst=$bigramlm_repo/2gram.fst \
--lodr-scale=-0.5 \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
rm -rf $repo $lm_repo $bigramlm_repo
log "Test non-streaming paraformer models"
... ...
... ... @@ -35,6 +35,25 @@ file(s) with a non-streaming model.
/path/to/0.wav \
/path/to/1.wav
or with RNN LM rescoring and LODR (optional):
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--num-threads=2 \
--decoding-method=modified_beam_search \
--debug=false \
--sample-rate=16000 \
--feature-dim=80 \
--lm=/path/to/lm.onnx \
--lm-scale=0.1 \
--lodr-fst=/path/to/lodr.fst \
--lodr-scale=-0.1 \
/path/to/0.wav \
/path/to/1.wav
(3) For CTC models from NeMo
python3 ./python-api-examples/offline-decode-files.py \
... ... @@ -269,6 +288,39 @@ def get_args():
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
parser.add_argument(
"--lm",
metavar="file",
type=str,
default="",
help="Path to RNN LM model",
)
parser.add_argument(
"--lm-scale",
metavar="lm_scale",
type=float,
default=0.1,
help="LM model scale for rescoring",
)
parser.add_argument(
"--lodr-fst",
metavar="file",
type=str,
default="",
help="Path to LODR FST model. Used only when --lm is given.",
)
parser.add_argument(
"--lodr-scale",
metavar="lodr_scale",
type=float,
default=-0.1,
help="LODR scale for rescoring.Used only when --lodr_fst is given.",
)
parser.add_argument(
"--debug",
type=bool,
... ... @@ -364,6 +416,10 @@ def main():
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
lm=args.lm,
lm_scale=args.lm_scale,
lodr_fst=args.lodr_fst,
lodr_scale=args.lodr_scale,
decoding_method=args.decoding_method,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
... ...
... ... @@ -21,6 +21,22 @@ rm sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
or with RNN LM rescoring and LODR:
./python-api-examples/online-decode-files.py \
--tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \
--encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
--decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
--joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
--decoding-method=modified_beam_search \
--lm=/path/to/lm.onnx \
--lm-scale=0.1 \
--lodr-fst=/path/to/lodr.fst \
--lodr-scale=-0.1 \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
(2) Streaming paraformer
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
... ... @@ -187,6 +203,22 @@ def get_args():
)
parser.add_argument(
"--lodr-fst",
metavar="file",
type=str,
default="",
help="Path to LODR FST model. Used only when --lm is given.",
)
parser.add_argument(
"--lodr-scale",
metavar="lodr_scale",
type=float,
default=-0.1,
help="LODR scale for rescoring.Used only when --lodr_fst is given.",
)
parser.add_argument(
"--provider",
type=str,
default="cpu",
... ... @@ -320,6 +352,8 @@ def main():
max_active_paths=args.max_active_paths,
lm=args.lm,
lm_scale=args.lm_scale,
lodr_fst=args.lodr_fst,
lodr_scale=args.lodr_scale,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
modeling_unit=args.modeling_unit,
... ...
... ... @@ -25,6 +25,7 @@ set(sources
jieba.cc
keyword-spotter-impl.cc
keyword-spotter.cc
lodr-fst.cc
offline-canary-model-config.cc
offline-canary-model.cc
offline-ctc-fst-decoder-config.cc
... ...
... ... @@ -12,9 +12,11 @@
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/lodr-fst.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
... ... @@ -61,6 +63,9 @@ struct Hypothesis {
// the nn lm states
std::vector<CopyableOrtValue> nn_lm_states;
// the LODR states
std::shared_ptr<LodrStateCost> lodr_state;
const ContextState *context_state;
// TODO(fangjun): Make it configurable
... ...
// sherpa-onnx/csrc/lodr-fst.cc
//
// Contains code copied from icefall/utils/ngram_lm.py
// Copyright (c) 2023 Xiaomi Corporation
//
// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs)
#include <algorithm>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/lodr-fst.h"
#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
int32_t LodrFst::FindBackoffId() {
// assume that the backoff id is the only input label with epsilon output
for (int32_t state = 0; state < fst_->NumStates(); ++state) {
fst::ArcIterator<fst::StdConstFst> arc_iter(*fst_, state);
for (; !arc_iter.Done(); arc_iter.Next()) {
const auto& arc = arc_iter.Value();
if (arc.olabel == 0) { // Check if the output label is epsilon (0)
return arc.ilabel; // Return the input label
}
}
}
return -1; // Return -1 if no such input symbol is found
}
LodrFst::LodrFst(const std::string &fst_path, int32_t backoff_id)
: backoff_id_(backoff_id) {
fst_ = std::unique_ptr<fst::StdConstFst>(
CastOrConvertToConstFst(fst::StdVectorFst::Read(fst_path)));
if (backoff_id < 0) {
// backoff_id_ is not provided, find it automatically
backoff_id_ = FindBackoffId();
if (backoff_id_ < 0) {
std::string err_msg = "Failed to initialize LODR: No backoff arc found";
SHERPA_ONNX_LOGE("%s", err_msg.c_str());
SHERPA_ONNX_EXIT(-1);
}
}
}
std::vector<std::tuple<int32_t, float>> LodrFst::ProcessBackoffArcs(
int32_t state, float cost) {
std::vector<std::tuple<int32_t, float>> ans;
auto next = GetNextStatesCostsNoBackoff(state, backoff_id_);
if (!next.has_value()) {
return ans;
}
auto [next_state, next_cost] = next.value();
ans.emplace_back(next_state, next_cost + cost);
auto recursive_result = ProcessBackoffArcs(next_state, next_cost + cost);
ans.insert(ans.end(), recursive_result.begin(), recursive_result.end());
return ans;
}
std::optional<std::tuple<int32_t, float>> LodrFst::GetNextStatesCostsNoBackoff(
int32_t state, int32_t label) {
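// Note: the binary search below assumes the arcs of this state are sorted
// by input label (ilabel), e.g. via fst::ArcSort with fst::ILabelCompare.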
fst::ArcIterator<fst::StdConstFst> arc_iter(*fst_, state);
int32_t num_arcs = fst_->NumArcs(state);
int32_t left = 0, right = num_arcs - 1;
while (left <= right) {
int32_t mid = (left + right) / 2;
arc_iter.Seek(mid);
auto arc = arc_iter.Value();
if (arc.ilabel < label) {
left = mid + 1;
} else if (arc.ilabel > label) {
right = mid - 1;
} else {
return std::make_tuple(arc.nextstate, arc.weight.Value());
}
}
return std::nullopt;
}
std::pair<std::vector<int32_t>, std::vector<float>> LodrFst::GetNextStateCosts(
int32_t state, int32_t label) {
std::vector<int32_t> states = {state};
std::vector<float> costs = {0};
auto extra_states_costs = ProcessBackoffArcs(state, 0);
for (const auto& [s, c] : extra_states_costs) {
states.push_back(s);
costs.push_back(c);
}
std::vector<int32_t> next_states;
std::vector<float> next_costs;
for (size_t i = 0; i < states.size(); ++i) {
auto next = GetNextStatesCostsNoBackoff(states[i], label);
if (next.has_value()) {
auto [ns, nc] = next.value();
next_states.push_back(ns);
next_costs.push_back(costs[i] + nc);
}
}
return std::make_pair(next_states, next_costs);
}
void LodrFst::ComputeScore(float scale, Hypothesis *hyp, int32_t offset) {
if (scale == 0) {
return;
}
hyp->lodr_state = std::make_unique<LodrStateCost>(this);
// Walk through the FST with the input text from the hypothesis
for (size_t i = offset; i < hyp->ys.size(); ++i) {
*hyp->lodr_state = hyp->lodr_state->ForwardOneStep(hyp->ys[i]);
}
float lodr_score = hyp->lodr_state->FinalScore();
if (lodr_score == -std::numeric_limits<float>::infinity()) {
SHERPA_ONNX_LOGE("Failed to compute LODR. Empty or mismatched FST?");
return;
}
// Update the hyp score
hyp->log_prob += scale * lodr_score;
}
float LodrFst::GetFinalCost(int32_t state) {
auto final_weight = fst_->Final(state);
if (final_weight == fst::StdArc::Weight::Zero()) {
return 0.0;
}
return final_weight.Value();
}
LodrStateCost::LodrStateCost(
LodrFst* fst, const std::unordered_map<int32_t, float> &state_cost)
: fst_(fst) {
if (state_cost.empty()) {
state_cost_[0] = 0.0;
} else {
state_cost_ = state_cost;
}
}
LodrStateCost LodrStateCost::ForwardOneStep(int32_t label) {
std::unordered_map<int32_t, float> state_cost;
for (const auto& [s, c] : state_cost_) {
auto [next_states, next_costs] = fst_->GetNextStateCosts(s, label);
for (size_t i = 0; i < next_states.size(); ++i) {
int32_t ns = next_states[i];
float nc = next_costs[i];
if (state_cost.find(ns) == state_cost.end()) {
state_cost[ns] = std::numeric_limits<float>::infinity();
}
state_cost[ns] = std::min(state_cost[ns], c + nc);
}
}
return LodrStateCost(fst_, state_cost);
}
float LodrStateCost::Score() const {
if (state_cost_.empty()) {
return -std::numeric_limits<float>::infinity();
}
auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(),
[](const auto& a, const auto& b) {
return a.second < b.second;
});
return -min_cost->second;
}
float LodrStateCost::FinalScore() const {
if (state_cost_.empty()) {
return -std::numeric_limits<float>::infinity();
}
auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(),
[](const auto& a, const auto& b) {
return a.second < b.second;
});
return -(min_cost->second +
fst_->GetFinalCost(min_cost->first));
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/lodr-fst.h
//
// Contains code copied from icefall/utils/ngram_lm.py
// Copyright (c) 2023 Xiaomi Corporation
//
// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs)
#ifndef SHERPA_ONNX_CSRC_LODR_FST_H_
#define SHERPA_ONNX_CSRC_LODR_FST_H_
#include <algorithm>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "kaldifst/csrc/kaldi-fst-io.h"
namespace sherpa_onnx {
class Hypothesis;
class LodrFst {
public:
explicit LodrFst(const std::string &fst_path, int32_t backoff_id = -1);
std::pair<std::vector<int32_t>, std::vector<float>> GetNextStateCosts(
int32_t state, int32_t label);
float GetFinalCost(int32_t state);
void ComputeScore(float scale, Hypothesis *hyp, int32_t offset);
private:
fst::StdVectorFst YsToFst(const std::vector<int64_t> &ys, int32_t offset);
std::vector<std::tuple<int32_t, float>> ProcessBackoffArcs(
int32_t state, float cost);
std::optional<std::tuple<int32_t, float>> GetNextStatesCostsNoBackoff(
int32_t state, int32_t label);
int32_t FindBackoffId();
int32_t backoff_id_ = -1;
std::unique_ptr<fst::StdConstFst> fst_; // owned by this class
};
class LodrStateCost {
public:
explicit LodrStateCost(
LodrFst* fst,
const std::unordered_map<int32_t, float> &state_cost = {});
LodrStateCost ForwardOneStep(int32_t label);
float Score() const;
float FinalScore() const;
private:
// The fst_ is not owned by this class and borrowed from the caller
// (e.g. OnlineRnnLM).
LodrFst* fst_;
std::unordered_map<int32_t, float> state_cost_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_LODR_FST_H_
... ...
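For reference, the state tracking above mirrors NgramLmStateCost in icefall/utils/ngram_lm.py: a hypothesis keeps a map from FST state to accumulated cost, each new token advances every state (expanding backoff arcs on the way), and the score is the negation of the cheapest surviving path. A rough Python paraphrase, with next_state_costs standing in for LodrFst::GetNextStateCosts (illustrative only, not part of this diff):

import math

def forward_one_step(state_cost, label, next_state_costs):
    # state_cost: dict mapping FST state -> accumulated cost (a negative
    # log-prob). next_state_costs(state, label) yields (next_state, cost)
    # pairs already expanded through backoff arcs, playing the role of
    # LodrFst::GetNextStateCosts.
    new_cost = {}
    for s, c in state_cost.items():
        for ns, nc in next_state_costs(s, label):
            # Keep only the cheapest path into each next state.
            new_cost[ns] = min(new_cost.get(ns, math.inf), c + nc)
    return new_cost

def score(state_cost):
    # FST weights are costs (negative log-probs), so negate the minimum.
    return -min(state_cost.values()) if state_cost else -math.inf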
... ... @@ -18,6 +18,10 @@ void OfflineLMConfig::Register(ParseOptions *po) {
"Number of threads to run the neural network of LM model");
po->Register("lm-provider", &lm_provider,
"Specify a provider to LM model use: cpu, cuda, coreml");
po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model.");
po->Register("lodr-scale", &lodr_scale, "LODR scale.");
po->Register("lodr-backoff-id", &lodr_backoff_id,
"ID of the backoff in the LODR FST. -1 means autodetect");
}
bool OfflineLMConfig::Validate() const {
... ... @@ -26,6 +30,11 @@ bool OfflineLMConfig::Validate() const {
return false;
}
if (!lodr_fst.empty() && !FileExists(lodr_fst)) {
SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str());
return false;
}
return true;
}
... ... @@ -34,7 +43,10 @@ std::string OfflineLMConfig::ToString() const {
os << "OfflineLMConfig(";
os << "model=\"" << model << "\", ";
os << "scale=" << scale << ")";
os << "scale=" << scale << ", ";
os << "lodr_scale=" << lodr_scale << ", ";
os << "lodr_fst=\"" << lodr_fst << "\", ";
os << "lodr_backoff_id=" << lodr_backoff_id << ")";
return os.str();
}
... ...
... ... @@ -19,14 +19,23 @@ struct OfflineLMConfig {
int32_t lm_num_threads = 1;
std::string lm_provider = "cpu";
// LODR
std::string lodr_fst;
float lodr_scale = 0.01;
int32_t lodr_backoff_id = -1; // -1 means not set
OfflineLMConfig() = default;
OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
const std::string &lm_provider)
const std::string &lm_provider, const std::string &lodr_fst,
float lodr_scale, int32_t lodr_backoff_id)
: model(model),
scale(scale),
lm_num_threads(lm_num_threads),
lm_provider(lm_provider),
lodr_fst(lodr_fst),
lodr_scale(lodr_scale),
lodr_backoff_id(lodr_backoff_id) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -17,6 +17,7 @@
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/lodr-fst.h"
#include "sherpa-onnx/csrc/offline-rnn-lm.h"
namespace sherpa_onnx {
... ... @@ -74,11 +75,17 @@ void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
}
auto negative_loglike = Rescore(std::move(x), std::move(x_lens));
const float *p_nll = negative_loglike.GetTensorData<float>();
// We multiply the LODR scale by the LM scale to match the Icefall code
auto lodr_scale = config_.lodr_scale * scale;
for (auto &h : *hyps) {
for (auto &t : h) {
// Use -scale here since we want to change negative loglike to loglike.
t.second.lm_log_prob = -scale * (*p_nll);
++p_nll;
// apply LODR to hyp score
if (lodr_fst_ != nullptr) {
lodr_fst_->ComputeScore(lodr_scale, &t.second, context_size);
}
}
}
}
... ...
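In this rescore path the LODR term rides on the LM scale: with lm_scale = 0.5 and lodr_scale = -0.5, the RNN LM contributes -0.5 * nll per hypothesis while the bi-gram FST contributes 0.5 * (-0.5) * lodr_score, so the low-order estimate of the internal LM is subtracted. A toy sketch of the bookkeeping in ComputeLMScore (names are illustrative, not the actual API):

def rescore(log_prob, nn_lm_nll, lodr_score, lm_scale, lodr_scale):
    # Mirrors the per-hypothesis update in OfflineLM::ComputeLMScore.
    lm_log_prob = -lm_scale * nn_lm_nll             # negate the NLL, apply LM scale
    log_prob += lodr_scale * lm_scale * lodr_score  # LODR scaled by the LM scale
    return log_prob, lm_log_prob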
... ... @@ -10,12 +10,24 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/lodr-fst.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
namespace sherpa_onnx {
class OfflineLM {
public:
explicit OfflineLM(const OfflineLMConfig &config) : config_(config) {
if (!config_.lodr_fst.empty()) {
try {
lodr_fst_ = std::make_unique<LodrFst>(config_.lodr_fst,
config_.lodr_backoff_id);
} catch (const std::exception& e) {
throw std::runtime_error("Failed to load LODR FST from: " +
config_.lodr_fst + ". Error: " + e.what());
}
}
}
virtual ~OfflineLM() = default;
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
... ... @@ -43,6 +55,11 @@ class OfflineLM {
// @param hyps It is changed in-place.
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps);
private:
std::unique_ptr<LodrFst> lodr_fst_;
OfflineLMConfig config_;
};
} // namespace sherpa_onnx
... ...
... ... @@ -83,11 +83,11 @@ class OfflineRnnLM::Impl {
};
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
: OfflineLM(config), impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
OfflineRnnLM::OfflineRnnLM(Manager *mgr, const OfflineLMConfig &config)
: OfflineLM(config), impl_(std::make_unique<Impl>(mgr, config)) {}
OfflineRnnLM::~OfflineRnnLM() = default;
... ...
... ... @@ -20,6 +20,10 @@ void OnlineLMConfig::Register(ParseOptions *po) {
"Specify a provider to LM model use: cpu, cuda, coreml");
po->Register("lm-shallow-fusion", &shallow_fusion,
"Boolean whether to use shallow fusion or rescore.");
po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model.");
po->Register("lodr-scale", &lodr_scale, "LODR scale.");
po->Register("lodr-backoff-id", &lodr_backoff_id,
"ID of the backoff in the LODR FST. -1 means autodetect");
}
bool OnlineLMConfig::Validate() const {
... ... @@ -28,6 +32,11 @@ bool OnlineLMConfig::Validate() const {
return false;
}
if (!lodr_fst.empty() && !FileExists(lodr_fst)) {
SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str());
return false;
}
return true;
}
... ... @@ -37,6 +46,9 @@ std::string OnlineLMConfig::ToString() const {
os << "OnlineLMConfig(";
os << "model=\"" << model << "\", ";
os << "scale=" << scale << ", ";
os << "lodr_scale=" << lodr_scale << ", ";
os << "lodr_fst=\"" << lodr_fst << "\", ";
os << "lodr_backoff_id=" << lodr_backoff_id << ", ";
os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")";
return os.str();
... ...
... ... @@ -18,18 +18,26 @@ struct OnlineLMConfig {
float scale = 0.5;
int32_t lm_num_threads = 1;
std::string lm_provider = "cpu";
std::string lodr_fst;
float lodr_scale = 0.01;
int32_t lodr_backoff_id = -1; // -1 means not set
// enable shallow fusion
bool shallow_fusion = true;
OnlineLMConfig() = default;
OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
const std::string &lm_provider, bool shallow_fusion,
const std::string &lodr_fst, float lodr_scale,
int32_t lodr_backoff_id)
: model(model),
scale(scale),
lm_num_threads(lm_num_threads),
lm_provider(lm_provider),
lodr_fst(lodr_fst),
lodr_scale(lodr_scale),
lodr_backoff_id(lodr_backoff_id),
shallow_fusion(shallow_fusion) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -12,6 +12,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/lodr-fst.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
... ... @@ -35,12 +36,27 @@ class OnlineRnnLM::Impl {
auto init_states = GetInitStatesSF();
hyp->nn_lm_scores.value = std::move(init_states.first);
hyp->nn_lm_states = Convert(std::move(init_states.second));
// If LODR is enabled, initialize the LODR state
if (lodr_fst_ != nullptr) {
hyp->lodr_state = std::make_unique<LodrStateCost>(lodr_fst_.get());
}
}
// get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob
const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>();
hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale;
// If LODR is enabled, update the LODR state
if (lodr_fst_ != nullptr) {
auto next_lodr_state = std::make_unique<LodrStateCost>(
hyp->lodr_state->ForwardOneStep(hyp->ys.back()));
// calculate the score of the latest token
auto score = next_lodr_state->Score() - hyp->lodr_state->Score();
hyp->lodr_state = std::move(next_lodr_state);
// apply LODR to hyp score
hyp->lm_log_prob += score * config_.lodr_scale;
}
// get lm scores for next tokens given the hyp->ys[:] and save to
// nn_lm_scores
std::array<int64_t, 2> x_shape{1, 1};
... ... @@ -89,6 +105,12 @@ class OnlineRnnLM::Impl {
const float *p_nll = out.first.GetTensorData<float>();
h.lm_log_prob = -scale * (*p_nll);
// apply LODR to hyp score
if (lodr_fst_ != nullptr) {
// We multiply the LODR scale by the LM scale to match the Icefall code
lodr_fst_->ComputeScore(config_.lodr_scale * scale, &h, context_size);
}
// update NN LM states in hyp
h.nn_lm_states = Convert(std::move(out.second));
... ... @@ -154,6 +176,11 @@ class OnlineRnnLM::Impl {
SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id");
ComputeInitStates();
if (!config_.lodr_fst.empty()) {
lodr_fst_ = std::make_unique<LodrFst>(config_.lodr_fst,
config_.lodr_backoff_id);
}
}
void ComputeInitStates() {
... ... @@ -203,6 +230,8 @@ class OnlineRnnLM::Impl {
int32_t rnn_num_layers_ = 2;
int32_t rnn_hidden_size_ = 512;
int32_t sos_id_ = 1;
std::unique_ptr<LodrFst> lodr_fst_;
};
OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
... ...
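In the shallow-fusion path above, LODR is applied incrementally: each accepted token advances lodr_state and only the difference between the new and old LODR scores is added, so the running total matches a full-sequence rescore at every step. A self-contained Python sketch of that update (illustrative names, not this file's API):

import math

def lodr_fuse_token(lm_log_prob, state_cost, token, lodr_scale, step):
    # state_cost: dict FST state -> cost; step(state_cost, token) advances
    # it, playing the role of LodrStateCost::ForwardOneStep.
    def best(sc):  # like LodrStateCost::Score: negate the cheapest cost
        return -min(sc.values()) if sc else -math.inf
    next_cost = step(state_cost, token)
    # Add only the newest token's score delta; summing deltas reproduces
    # the full-sequence LODR score.
    lm_log_prob += lodr_scale * (best(next_cost) - best(state_cost))
    return lm_log_prob, next_cost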
... ... @@ -13,13 +13,19 @@ namespace sherpa_onnx {
void PybindOfflineLMConfig(py::module *m) {
using PyClass = OfflineLMConfig;
py::class_<PyClass>(*m, "OfflineLMConfig")
.def(py::init<const std::string &, float, int32_t, const std::string &,
const std::string &, float, int32_t>(),
py::arg("model"), py::arg("scale") = 0.5f,
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu")
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
py::arg("lodr_fst") = "", py::arg("lodr_scale") = 0.0f,
py::arg("lodr_backoff_id") = -1)
.def_readwrite("model", &PyClass::model)
.def_readwrite("scale", &PyClass::scale)
.def_readwrite("lm_provider", &PyClass::lm_provider)
.def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
.def_readwrite("lodr_fst", &PyClass::lodr_fst)
.def_readwrite("lodr_scale", &PyClass::lodr_scale)
.def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -14,15 +14,21 @@ void PybindOnlineLMConfig(py::module *m) {
using PyClass = OnlineLMConfig;
py::class_<PyClass>(*m, "OnlineLMConfig")
.def(py::init<const std::string &, float, int32_t,
const std::string &, bool, const std::string &,
float, int32_t>(),
py::arg("model") = "", py::arg("scale") = 0.5f,
py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
py::arg("shallow_fusion") = true)
py::arg("shallow_fusion") = true, py::arg("lodr_fst") = "",
py::arg("lodr_scale") = 0.0f, py::arg("lodr_backoff_id") = -1)
.def_readwrite("model", &PyClass::model)
.def_readwrite("scale", &PyClass::scale)
.def_readwrite("lm_provider", &PyClass::lm_provider)
.def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
.def_readwrite("shallow_fusion", &PyClass::shallow_fusion)
.def_readwrite("lodr_fst", &PyClass::lodr_fst)
.def_readwrite("lodr_scale", &PyClass::lodr_scale)
.def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -69,6 +69,8 @@ class OfflineRecognizer(object):
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
lodr_fst: str = "",
lodr_scale: float = 0.0,
):
"""
Please refer to
... ... @@ -133,6 +135,10 @@ class OfflineRecognizer(object):
rule_fars:
If not empty, it specifies fst archives for inverse text normalization.
If there are multiple archives, they are separated by a comma.
lodr_fst:
Path to the LODR FST file in binary format. If empty, LODR is disabled.
lodr_scale:
Scale factor for LODR rescoring. Only used when lodr_fst is provided.
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
... ... @@ -173,6 +179,8 @@ class OfflineRecognizer(object):
scale=lm_scale,
lm_num_threads=num_threads,
lm_provider=provider,
lodr_fst=lodr_fst,
lodr_scale=lodr_scale,
)
recognizer_config = OfflineRecognizerConfig(
... ...
... ... @@ -89,6 +89,8 @@ class OnlineRecognizer(object):
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
lodr_fst: str = "",
lodr_scale: float = 0.0,
):
"""
Please refer to
... ... @@ -216,6 +218,10 @@ class OnlineRecognizer(object):
"Set path for storing timing cache." TensorRT EP
trt_dump_subgraphs: bool = False,
"Dump optimized subgraphs for debugging." TensorRT EP
lodr_fst:
Path to the LODR FST file in binary format. If empty, LODR is disabled.
lodr_scale:
Scale factor for LODR rescoring. Only used when lodr_fst is provided.
"""
self = cls.__new__(cls)
_assert_file_exists(tokens)
... ... @@ -298,6 +304,8 @@ class OnlineRecognizer(object):
model=lm,
scale=lm_scale,
shallow_fusion=lm_shallow_fusion,
lodr_fst=lodr_fst,
lodr_scale=lodr_scale,
)
recognizer_config = OnlineRecognizerConfig(
... ...