Askars Salimbajevs
Committed by GitHub

Add LODR support to online and offline recognizers (#2026)

This PR integrates LODR (Low-Order Density Ratio) support from Icefall into both online and offline recognizers, enabling LODR during both LM shallow fusion and LM rescoring.

- Extended OnlineLMConfig and OfflineLMConfig to include lodr_fst, lodr_scale, and lodr_backoff_id.
- Implemented LodrFst and LodrStateCost classes and wired them into RNN LM scoring in both online and offline code paths.
- Updated Python bindings, CLI entry points, examples, and CI test scripts to accept and exercise the new LODR options.
@@ -281,7 +281,39 @@ time $EXE \ @@ -281,7 +281,39 @@ time $EXE \
281 $repo/test_wavs/1.wav \ 281 $repo/test_wavs/1.wav \
282 $repo/test_wavs/8k.wav 282 $repo/test_wavs/8k.wav
283 283
284 -rm -rf $repo 284 +lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
  285 +log "Download pre-trained RNN-LM model from ${lm_repo_url}"
  286 +GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
  287 +lm_repo=$(basename $lm_repo_url)
  288 +pushd $lm_repo
  289 +git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx"
  290 +popd
  291 +
  292 +bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
  293 +log "Download bi-gram LM from ${bigram_repo_url}"
  294 +GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
  295 +bigramlm_repo=$(basename $bigram_repo_url)
  296 +pushd $bigramlm_repo
  297 +git lfs pull --include "2gram.fst"
  298 +popd
  299 +
  300 +log "Start testing with LM and bi-gram LODR"
  301 +# TODO: find test examples that change with the LODR
  302 +time $EXE \
  303 + --tokens=$repo/tokens.txt \
  304 + --encoder=$repo/encoder-epoch-99-avg-1.onnx \
  305 + --decoder=$repo/decoder-epoch-99-avg-1.onnx \
  306 + --joiner=$repo/joiner-epoch-99-avg-1.onnx \
  307 + --num-threads=2 \
  308 + --decoding_method="modified_beam_search" \
  309 + --lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \
  310 + --lodr-fst=$bigramlm_repo/2gram.fst \
  311 + --lodr-scale=-0.5 \
  312 + $repo/test_wavs/0.wav \
  313 + $repo/test_wavs/1.wav \
  314 + $repo/test_wavs/8k.wav
  315 +
  316 +rm -rf $repo $lm_repo $bigramlm_repo
285 317
286 log "------------------------------------------------------------" 318 log "------------------------------------------------------------"
287 log "Run Paraformer (Chinese)" 319 log "Run Paraformer (Chinese)"
@@ -174,7 +174,60 @@ for wave in ${waves[@]}; do @@ -174,7 +174,60 @@ for wave in ${waves[@]}; do
174 $wave 174 $wave
175 done 175 done
176 176
177 -rm -rf $repo 177 +lm_repo_url=https://huggingface.co/vsd-vector/icefall-librispeech-rnn-lm
  178 +log "Download pre-trained RNN-LM model from ${lm_repo_url}"
  179 +GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
  180 +lm_repo=$(basename $lm_repo_url)
  181 +pushd $lm_repo
  182 +git lfs pull --include "with-state-epoch-99-avg-1.onnx"
  183 +popd
  184 +
  185 +bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
  186 +log "Download bi-gram LM from ${bigram_repo_url}"
  187 +GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
  188 +bigramlm_repo=$(basename $bigram_repo_url)
  189 +pushd $bigramlm_repo
  190 +git lfs pull --include "2gram.fst"
  191 +popd
  192 +
  193 +log "Start testing LODR"
  194 +
  195 +waves=(
  196 +$repo/test_wavs/0.wav
  197 +$repo/test_wavs/1.wav
  198 +$repo/test_wavs/8k.wav
  199 +)
  200 +
  201 +for wave in ${waves[@]}; do
  202 + time $EXE \
  203 + --tokens=$repo/tokens.txt \
  204 + --encoder=$repo/encoder-epoch-99-avg-1.onnx \
  205 + --decoder=$repo/decoder-epoch-99-avg-1.onnx \
  206 + --joiner=$repo/joiner-epoch-99-avg-1.onnx \
  207 + --num-threads=2 \
  208 + --decoding_method="modified_beam_search" \
  209 + --lm=$lm_repo/with-state-epoch-99-avg-1.onnx \
  210 + --lodr-fst=$bigramlm_repo/2gram.fst \
  211 + --lodr-scale=-0.5 \
  212 + $wave
  213 +done
  214 +
  215 +for wave in ${waves[@]}; do
  216 + time $EXE \
  217 + --tokens=$repo/tokens.txt \
  218 + --encoder=$repo/encoder-epoch-99-avg-1.onnx \
  219 + --decoder=$repo/decoder-epoch-99-avg-1.onnx \
  220 + --joiner=$repo/joiner-epoch-99-avg-1.onnx \
  221 + --num-threads=2 \
  222 + --decoding_method="modified_beam_search" \
  223 + --lm=$lm_repo/with-state-epoch-99-avg-1.onnx \
  224 + --lodr-fst=$bigramlm_repo/2gram.fst \
  225 + --lodr-scale=-0.5 \
  226 + --lm-shallow-fusion=true \
  227 + $wave
  228 +done
  229 +
  230 +rm -rf $repo $bigramlm_repo $lm_repo
178 231
179 log "------------------------------------------------------------" 232 log "------------------------------------------------------------"
180 log "Run streaming Zipformer transducer (Bilingual, Chinese + English)" 233 log "Run streaming Zipformer transducer (Bilingual, Chinese + English)"
@@ -562,9 +562,39 @@ python3 ./python-api-examples/offline-decode-files.py \ @@ -562,9 +562,39 @@ python3 ./python-api-examples/offline-decode-files.py \
562 $repo/test_wavs/1.wav \ 562 $repo/test_wavs/1.wav \
563 $repo/test_wavs/8k.wav 563 $repo/test_wavs/8k.wav
564 564
  565 +lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
  566 +log "Download pre-trained RNN-LM model from ${lm_repo_url}"
  567 +GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url
  568 +lm_repo=$(basename $lm_repo_url)
  569 +pushd $lm_repo
  570 +git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx"
  571 +popd
  572 +
  573 +bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26
  574 +log "Download bi-gram LM from ${bigram_repo_url}"
  575 +GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url
  576 +bigramlm_repo=$(basename $bigram_repo_url)
  577 +pushd $bigramlm_repo
  578 +git lfs pull --include "2gram.fst"
  579 +popd
  580 +
  581 +log "Perform offline decoding with RNN-LM and LODR"
  582 +python3 ./python-api-examples/offline-decode-files.py \
  583 + --tokens=$repo/tokens.txt \
  584 + --encoder=$repo/encoder-epoch-99-avg-1.onnx \
  585 + --decoder=$repo/decoder-epoch-99-avg-1.onnx \
  586 + --joiner=$repo/joiner-epoch-99-avg-1.onnx \
  587 + --decoding-method=modified_beam_search \
  588 + --lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \
  589 + --lodr-fst=$bigramlm_repo/2gram.fst \
  590 + --lodr-scale=-0.5 \
  591 + $repo/test_wavs/0.wav \
  592 + $repo/test_wavs/1.wav \
  593 + $repo/test_wavs/8k.wav
  594 +
565 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose 595 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
566 596
567 -rm -rf $repo 597 +rm -rf $repo $lm_repo $bigramlm_repo
568 598
569 log "Test non-streaming paraformer models" 599 log "Test non-streaming paraformer models"
570 600
@@ -35,6 +35,25 @@ file(s) with a non-streaming model. @@ -35,6 +35,25 @@ file(s) with a non-streaming model.
35 /path/to/0.wav \ 35 /path/to/0.wav \
36 /path/to/1.wav 36 /path/to/1.wav
37 37
  38 + also with RNN LM rescoring and LODR (optional):
  39 +
  40 + ./python-api-examples/offline-decode-files.py \
  41 + --tokens=/path/to/tokens.txt \
  42 + --encoder=/path/to/encoder.onnx \
  43 + --decoder=/path/to/decoder.onnx \
  44 + --joiner=/path/to/joiner.onnx \
  45 + --num-threads=2 \
  46 + --decoding-method=modified_beam_search \
  47 + --debug=false \
  48 + --sample-rate=16000 \
  49 + --feature-dim=80 \
  50 + --lm=/path/to/lm.onnx \
  51 + --lm-scale=0.1 \
  52 + --lodr-fst=/path/to/lodr.fst \
  53 + --lodr-scale=-0.1 \
  54 + /path/to/0.wav \
  55 + /path/to/1.wav
  56 +
38 (3) For CTC models from NeMo 57 (3) For CTC models from NeMo
39 58
40 python3 ./python-api-examples/offline-decode-files.py \ 59 python3 ./python-api-examples/offline-decode-files.py \
@@ -269,6 +288,39 @@ def get_args(): @@ -269,6 +288,39 @@ def get_args():
269 default="greedy_search", 288 default="greedy_search",
270 help="Valid values are greedy_search and modified_beam_search", 289 help="Valid values are greedy_search and modified_beam_search",
271 ) 290 )
  291 +
  292 + parser.add_argument(
  293 + "--lm",
  294 + metavar="file",
  295 + type=str,
  296 + default="",
  297 + help="Path to RNN LM model",
  298 + )
  299 +
  300 + parser.add_argument(
  301 + "--lm-scale",
  302 + metavar="lm_scale",
  303 + type=float,
  304 + default=0.1,
  305 + help="LM model scale for rescoring",
  306 + )
  307 +
  308 + parser.add_argument(
  309 + "--lodr-fst",
  310 + metavar="file",
  311 + type=str,
  312 + default="",
  313 + help="Path to LODR FST model. Used only when --lm is given.",
  314 + )
  315 +
  316 + parser.add_argument(
  317 + "--lodr-scale",
  318 + metavar="lodr_scale",
  319 + type=float,
  320 + default=-0.1,
  321 + help="LODR scale for rescoring.Used only when --lodr_fst is given.",
  322 + )
  323 +
272 parser.add_argument( 324 parser.add_argument(
273 "--debug", 325 "--debug",
274 type=bool, 326 type=bool,
@@ -364,6 +416,10 @@ def main(): @@ -364,6 +416,10 @@ def main():
364 num_threads=args.num_threads, 416 num_threads=args.num_threads,
365 sample_rate=args.sample_rate, 417 sample_rate=args.sample_rate,
366 feature_dim=args.feature_dim, 418 feature_dim=args.feature_dim,
  419 + lm=args.lm,
  420 + lm_scale=args.lm_scale,
  421 + lodr_fst=args.lodr_fst,
  422 + lodr_scale=args.lodr_scale,
367 decoding_method=args.decoding_method, 423 decoding_method=args.decoding_method,
368 hotwords_file=args.hotwords_file, 424 hotwords_file=args.hotwords_file,
369 hotwords_score=args.hotwords_score, 425 hotwords_score=args.hotwords_score,
@@ -21,6 +21,22 @@ rm sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2 @@ -21,6 +21,22 @@ rm sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2
21 ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \ 21 ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
22 ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav 22 ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
23 23
  24 +or with RNN LM rescoring and LODR:
  25 +
  26 +./python-api-examples/online-decode-files.py \
  27 + --tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \
  28 + --encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \
  29 + --decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \
  30 + --joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \
  31 + --decoding-method=modified_beam_search \
  32 + --lm=/path/to/lm.onnx \
  33 + --lm-scale=0.1 \
  34 + --lodr-fst=/path/to/lodr.fst \
  35 + --lodr-scale=-0.1 \
  36 + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \
  37 + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \
  38 + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav
  39 +
24 (2) Streaming paraformer 40 (2) Streaming paraformer
25 41
26 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 42 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2
@@ -187,6 +203,22 @@ def get_args(): @@ -187,6 +203,22 @@ def get_args():
187 ) 203 )
188 204
189 parser.add_argument( 205 parser.add_argument(
  206 + "--lodr-fst",
  207 + metavar="file",
  208 + type=str,
  209 + default="",
  210 + help="Path to LODR FST model. Used only when --lm is given.",
  211 + )
  212 +
  213 + parser.add_argument(
  214 + "--lodr-scale",
  215 + metavar="lodr_scale",
  216 + type=float,
  217 + default=-0.1,
  218 + help="LODR scale for rescoring.Used only when --lodr_fst is given.",
  219 + )
  220 +
  221 + parser.add_argument(
190 "--provider", 222 "--provider",
191 type=str, 223 type=str,
192 default="cpu", 224 default="cpu",
@@ -320,6 +352,8 @@ def main(): @@ -320,6 +352,8 @@ def main():
320 max_active_paths=args.max_active_paths, 352 max_active_paths=args.max_active_paths,
321 lm=args.lm, 353 lm=args.lm,
322 lm_scale=args.lm_scale, 354 lm_scale=args.lm_scale,
  355 + lodr_fst=args.lodr_fst,
  356 + lodr_scale=args.lodr_scale,
323 hotwords_file=args.hotwords_file, 357 hotwords_file=args.hotwords_file,
324 hotwords_score=args.hotwords_score, 358 hotwords_score=args.hotwords_score,
325 modeling_unit=args.modeling_unit, 359 modeling_unit=args.modeling_unit,
@@ -25,6 +25,7 @@ set(sources @@ -25,6 +25,7 @@ set(sources
25 jieba.cc 25 jieba.cc
26 keyword-spotter-impl.cc 26 keyword-spotter-impl.cc
27 keyword-spotter.cc 27 keyword-spotter.cc
  28 + lodr-fst.cc
28 offline-canary-model-config.cc 29 offline-canary-model-config.cc
29 offline-canary-model.cc 30 offline-canary-model.cc
30 offline-ctc-fst-decoder-config.cc 31 offline-ctc-fst-decoder-config.cc
@@ -12,9 +12,11 @@ @@ -12,9 +12,11 @@
12 #include <unordered_map> 12 #include <unordered_map>
13 #include <utility> 13 #include <utility>
14 #include <vector> 14 #include <vector>
  15 +#include <memory>
15 16
16 #include "onnxruntime_cxx_api.h" // NOLINT 17 #include "onnxruntime_cxx_api.h" // NOLINT
17 #include "sherpa-onnx/csrc/context-graph.h" 18 #include "sherpa-onnx/csrc/context-graph.h"
  19 +#include "sherpa-onnx/csrc/lodr-fst.h"
18 #include "sherpa-onnx/csrc/math.h" 20 #include "sherpa-onnx/csrc/math.h"
19 #include "sherpa-onnx/csrc/onnx-utils.h" 21 #include "sherpa-onnx/csrc/onnx-utils.h"
20 22
@@ -61,6 +63,9 @@ struct Hypothesis { @@ -61,6 +63,9 @@ struct Hypothesis {
61 // the nn lm states 63 // the nn lm states
62 std::vector<CopyableOrtValue> nn_lm_states; 64 std::vector<CopyableOrtValue> nn_lm_states;
63 65
  66 + // the LODR states
  67 + std::shared_ptr<LodrStateCost> lodr_state;
  68 +
64 const ContextState *context_state; 69 const ContextState *context_state;
65 70
66 // TODO(fangjun): Make it configurable 71 // TODO(fangjun): Make it configurable
  1 +// sherpa-onnx/csrc/lodr-fst.cc
  2 +//
  3 +// Contains code copied from icefall/utils/ngram_lm.py
  4 +// Copyright (c) 2023 Xiaomi Corporation
  5 +//
  6 +// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs)
  7 +
  8 +#include <algorithm>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/lodr-fst.h"
  13 +#include "sherpa-onnx/csrc/log.h"
  14 +#include "sherpa-onnx/csrc/hypothesis.h"
  15 +#include "sherpa-onnx/csrc/macros.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
  19 +int32_t LodrFst::FindBackoffId() {
  20 + // assume that the backoff id is the only input label with epsilon output
  21 +
  22 + for (int32_t state = 0; state < fst_->NumStates(); ++state) {
  23 + fst::ArcIterator<fst::StdConstFst> arc_iter(*fst_, state);
  24 + for ( ; !arc_iter.Done(); arc_iter.Next()) {
  25 + const auto& arc = arc_iter.Value();
  26 + if (arc.olabel == 0) { // Check if the output label is epsilon (0)
  27 + return arc.ilabel; // Return the input label
  28 + }
  29 + }
  30 + }
  31 +
  32 + return -1; // Return -1 if no such input symbol is found
  33 +}
  34 +
  35 +LodrFst::LodrFst(const std::string &fst_path, int32_t backoff_id)
  36 + : backoff_id_(backoff_id) {
  37 + fst_ = std::unique_ptr<fst::StdConstFst>(
  38 + CastOrConvertToConstFst(fst::StdVectorFst::Read(fst_path)));
  39 +
  40 + if (backoff_id < 0) {
  41 + // backoff_id_ is not provided, find it automatically
  42 + backoff_id_ = FindBackoffId();
  43 + if (backoff_id_ < 0) {
  44 + std::string err_msg = "Failed to initialize LODR: No backoff arc found";
  45 + SHERPA_ONNX_LOGE("%s", err_msg.c_str());
  46 + SHERPA_ONNX_EXIT(-1);
  47 + }
  48 + }
  49 +}
  50 +
  51 +std::vector<std::tuple<int32_t, float>> LodrFst::ProcessBackoffArcs(
  52 + int32_t state, float cost) {
  53 + std::vector<std::tuple<int32_t, float>> ans;
  54 + auto next = GetNextStatesCostsNoBackoff(state, backoff_id_);
  55 + if (!next.has_value()) {
  56 + return ans;
  57 + }
  58 + auto [next_state, next_cost] = next.value();
  59 + ans.emplace_back(next_state, next_cost + cost);
  60 + auto recursive_result = ProcessBackoffArcs(next_state, next_cost + cost);
  61 + ans.insert(ans.end(), recursive_result.begin(), recursive_result.end());
  62 + return ans;
  63 +}
  64 +
  65 +std::optional<std::tuple<int32_t, float>> LodrFst::GetNextStatesCostsNoBackoff(
  66 + int32_t state, int32_t label) {
  67 + fst::ArcIterator<fst::StdConstFst> arc_iter(*fst_, state);
  68 + int32_t num_arcs = fst_->NumArcs(state);
  69 +
  70 + int32_t left = 0, right = num_arcs - 1;
  71 + while (left <= right) {
  72 + int32_t mid = (left + right) / 2;
  73 + arc_iter.Seek(mid);
  74 + auto arc = arc_iter.Value();
  75 + if (arc.ilabel < label) {
  76 + left = mid + 1;
  77 + } else if (arc.ilabel > label) {
  78 + right = mid - 1;
  79 + } else {
  80 + return std::make_tuple(arc.nextstate, arc.weight.Value());
  81 + }
  82 + }
  83 + return std::nullopt;
  84 +}
  85 +
  86 +std::pair<std::vector<int32_t>, std::vector<float>> LodrFst::GetNextStateCosts(
  87 + int32_t state, int32_t label) {
  88 + std::vector<int32_t> states = {state};
  89 + std::vector<float> costs = {0};
  90 +
  91 + auto extra_states_costs = ProcessBackoffArcs(state, 0);
  92 + for (const auto& [s, c] : extra_states_costs) {
  93 + states.push_back(s);
  94 + costs.push_back(c);
  95 + }
  96 +
  97 + std::vector<int32_t> next_states;
  98 + std::vector<float> next_costs;
  99 + for (size_t i = 0; i < states.size(); ++i) {
  100 + auto next = GetNextStatesCostsNoBackoff(states[i], label);
  101 + if (next.has_value()) {
  102 + auto [ns, nc] = next.value();
  103 + next_states.push_back(ns);
  104 + next_costs.push_back(costs[i] + nc);
  105 + }
  106 + }
  107 +
  108 + return std::make_pair(next_states, next_costs);
  109 +}
  110 +
  111 +void LodrFst::ComputeScore(float scale, Hypothesis *hyp, int32_t offset) {
  112 + if (scale == 0) {
  113 + return;
  114 + }
  115 +
  116 + hyp->lodr_state = std::make_unique<LodrStateCost>(this);
  117 +
  118 + // Walk through the FST with the input text from the hypothesis
  119 + for (size_t i = offset; i < hyp->ys.size(); ++i) {
  120 + *hyp->lodr_state = hyp->lodr_state->ForwardOneStep(hyp->ys[i]);
  121 + }
  122 +
  123 + float lodr_score = hyp->lodr_state->FinalScore();
  124 +
  125 + if (lodr_score == -std::numeric_limits<float>::infinity()) {
  126 + SHERPA_ONNX_LOGE("Failed to compute LODR. Empty or mismatched FST?");
  127 + return;
  128 + }
  129 +
  130 + // Update the hyp score
  131 + hyp->log_prob += scale * lodr_score;
  132 +}
  133 +
  134 +float LodrFst::GetFinalCost(int32_t state) {
  135 + auto final_weight = fst_->Final(state);
  136 + if (final_weight == fst::StdArc::Weight::Zero()) {
  137 + return 0.0;
  138 + }
  139 + return final_weight.Value();
  140 +}
  141 +
  142 +LodrStateCost::LodrStateCost(
  143 + LodrFst* fst, const std::unordered_map<int32_t, float> &state_cost)
  144 + : fst_(fst) {
  145 + if (state_cost.empty()) {
  146 + state_cost_[0] = 0.0;
  147 + } else {
  148 + state_cost_ = state_cost;
  149 + }
  150 +}
  151 +
  152 +LodrStateCost LodrStateCost::ForwardOneStep(int32_t label) {
  153 + std::unordered_map<int32_t, float> state_cost;
  154 + for (const auto& [s, c] : state_cost_) {
  155 + auto [next_states, next_costs] = fst_->GetNextStateCosts(s, label);
  156 + for (size_t i = 0; i < next_states.size(); ++i) {
  157 + int32_t ns = next_states[i];
  158 + float nc = next_costs[i];
  159 + if (state_cost.find(ns) == state_cost.end()) {
  160 + state_cost[ns] = std::numeric_limits<float>::infinity();
  161 + }
  162 + state_cost[ns] = std::min(state_cost[ns], c + nc);
  163 + }
  164 + }
  165 + return LodrStateCost(fst_, state_cost);
  166 +}
  167 +
  168 +float LodrStateCost::Score() const {
  169 + if (state_cost_.empty()) {
  170 + return -std::numeric_limits<float>::infinity();
  171 + }
  172 + auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(),
  173 + [](const auto& a, const auto& b) {
  174 + return a.second < b.second;
  175 + });
  176 + return -min_cost->second;
  177 +}
  178 +
  179 +float LodrStateCost::FinalScore() const {
  180 + if (state_cost_.empty()) {
  181 + return -std::numeric_limits<float>::infinity();
  182 + }
  183 + auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(),
  184 + [](const auto& a, const auto& b) {
  185 + return a.second < b.second;
  186 + });
  187 + return -(min_cost->second +
  188 + fst_->GetFinalCost(min_cost->first));
  189 +}
  190 +
  191 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/lodr-fst.h
  2 +//
  3 +// Contains code copied from icefall/utils/ngram_lm.py
  4 +// Copyright (c) 2023 Xiaomi Corporation
  5 +//
  6 +// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs)
  7 +
  8 +
  9 +#ifndef SHERPA_ONNX_CSRC_LODR_FST_H_
  10 +#define SHERPA_ONNX_CSRC_LODR_FST_H_
  11 +
  12 +#include <memory>
  13 +#include <string>
  14 +#include <vector>
  15 +#include <optional>
  16 +#include <tuple>
  17 +#include <unordered_map>
  18 +#include <limits>
  19 +#include <algorithm>
  20 +#include <utility>
  21 +
  22 +#include "kaldifst/csrc/kaldi-fst-io.h"
  23 +
  24 +namespace sherpa_onnx {
  25 +
  26 +class Hypothesis;
  27 +
  28 +class LodrFst {
  29 + public:
  30 + explicit LodrFst(const std::string &fst_path, int32_t backoff_id = -1);
  31 +
  32 + std::pair<std::vector<int32_t>, std::vector<float>> GetNextStateCosts(
  33 + int32_t state, int32_t label);
  34 +
  35 + float GetFinalCost(int32_t state);
  36 +
  37 + void ComputeScore(float scale, Hypothesis *hyp, int32_t offset);
  38 +
  39 + private:
  40 + fst::StdVectorFst YsToFst(const std::vector<int64_t> &ys, int32_t offset);
  41 +
  42 + std::vector<std::tuple<int32_t, float>> ProcessBackoffArcs(
  43 + int32_t state, float cost);
  44 +
  45 + std::optional<std::tuple<int32_t, float>> GetNextStatesCostsNoBackoff(
  46 + int32_t state, int32_t label);
  47 +
  48 + int32_t FindBackoffId();
  49 +
  50 +
  51 + int32_t backoff_id_ = -1;
  52 + std::unique_ptr<fst::StdConstFst> fst_; // owned by this class
  53 +};
  54 +
  55 +class LodrStateCost {
  56 + public:
  57 + explicit LodrStateCost(
  58 + LodrFst* fst,
  59 + const std::unordered_map<int32_t, float> &state_cost = {});
  60 +
  61 + LodrStateCost ForwardOneStep(int32_t label);
  62 +
  63 + float Score() const;
  64 + float FinalScore() const;
  65 +
  66 + private:
  67 + // The fst_ is not owned by this class and borrowed from the caller
  68 + // (e.g. OnlineRnnLM).
  69 + LodrFst* fst_;
  70 + std::unordered_map<int32_t, float> state_cost_;
  71 +};
  72 +
  73 +} // namespace sherpa_onnx
  74 +
  75 +#endif // SHERPA_ONNX_CSRC_LODR_FST_H_
@@ -18,6 +18,10 @@ void OfflineLMConfig::Register(ParseOptions *po) { @@ -18,6 +18,10 @@ void OfflineLMConfig::Register(ParseOptions *po) {
18 "Number of threads to run the neural network of LM model"); 18 "Number of threads to run the neural network of LM model");
19 po->Register("lm-provider", &lm_provider, 19 po->Register("lm-provider", &lm_provider,
20 "Specify a provider to LM model use: cpu, cuda, coreml"); 20 "Specify a provider to LM model use: cpu, cuda, coreml");
  21 + po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model.");
  22 + po->Register("lodr-scale", &lodr_scale, "LODR scale.");
  23 + po->Register("lodr-backoff-id", &lodr_backoff_id,
  24 + "ID of the backoff in the LODR FST. -1 means autodetect");
21 } 25 }
22 26
23 bool OfflineLMConfig::Validate() const { 27 bool OfflineLMConfig::Validate() const {
@@ -26,6 +30,11 @@ bool OfflineLMConfig::Validate() const { @@ -26,6 +30,11 @@ bool OfflineLMConfig::Validate() const {
26 return false; 30 return false;
27 } 31 }
28 32
  33 + if (!lodr_fst.empty() && !FileExists(lodr_fst)) {
  34 + SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str());
  35 + return false;
  36 + }
  37 +
29 return true; 38 return true;
30 } 39 }
31 40
@@ -34,7 +43,10 @@ std::string OfflineLMConfig::ToString() const { @@ -34,7 +43,10 @@ std::string OfflineLMConfig::ToString() const {
34 43
35 os << "OfflineLMConfig("; 44 os << "OfflineLMConfig(";
36 os << "model=\"" << model << "\", "; 45 os << "model=\"" << model << "\", ";
37 - os << "scale=" << scale << ")"; 46 + os << "scale=" << scale << ", ";
  47 + os << "lodr_scale=" << lodr_scale << ", ";
  48 + os << "lodr_fst=\"" << lodr_fst << "\", ";
  49 + os << "lodr_backoff_id=" << lodr_backoff_id << ")";
38 50
39 return os.str(); 51 return os.str();
40 } 52 }
@@ -19,14 +19,23 @@ struct OfflineLMConfig { @@ -19,14 +19,23 @@ struct OfflineLMConfig {
19 int32_t lm_num_threads = 1; 19 int32_t lm_num_threads = 1;
20 std::string lm_provider = "cpu"; 20 std::string lm_provider = "cpu";
21 21
  22 + // LODR
  23 + std::string lodr_fst;
  24 + float lodr_scale = 0.01;
  25 + int32_t lodr_backoff_id = -1; // -1 means not set
  26 +
22 OfflineLMConfig() = default; 27 OfflineLMConfig() = default;
23 28
24 OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, 29 OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
25 - const std::string &lm_provider) 30 + const std::string &lm_provider, const std::string &lodr_fst,
  31 + float lodr_scale, int32_t lodr_backoff_id)
26 : model(model), 32 : model(model),
27 scale(scale), 33 scale(scale),
28 lm_num_threads(lm_num_threads), 34 lm_num_threads(lm_num_threads),
29 - lm_provider(lm_provider) {} 35 + lm_provider(lm_provider),
  36 + lodr_fst(lodr_fst),
  37 + lodr_scale(lodr_scale),
  38 + lodr_backoff_id(lodr_backoff_id) {}
30 39
31 void Register(ParseOptions *po); 40 void Register(ParseOptions *po);
32 bool Validate() const; 41 bool Validate() const;
@@ -17,6 +17,7 @@ @@ -17,6 +17,7 @@
17 #include "rawfile/raw_file_manager.h" 17 #include "rawfile/raw_file_manager.h"
18 #endif 18 #endif
19 19
  20 +#include "sherpa-onnx/csrc/lodr-fst.h"
20 #include "sherpa-onnx/csrc/offline-rnn-lm.h" 21 #include "sherpa-onnx/csrc/offline-rnn-lm.h"
21 22
22 namespace sherpa_onnx { 23 namespace sherpa_onnx {
@@ -74,11 +75,17 @@ void OfflineLM::ComputeLMScore(float scale, int32_t context_size, @@ -74,11 +75,17 @@ void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
74 } 75 }
75 auto negative_loglike = Rescore(std::move(x), std::move(x_lens)); 76 auto negative_loglike = Rescore(std::move(x), std::move(x_lens));
76 const float *p_nll = negative_loglike.GetTensorData<float>(); 77 const float *p_nll = negative_loglike.GetTensorData<float>();
  78 + // We scale LODR scale with LM scale to replicate Icefall code
  79 + auto lodr_scale = config_.lodr_scale * scale;
77 for (auto &h : *hyps) { 80 for (auto &h : *hyps) {
78 for (auto &t : h) { 81 for (auto &t : h) {
79 // Use -scale here since we want to change negative loglike to loglike. 82 // Use -scale here since we want to change negative loglike to loglike.
80 t.second.lm_log_prob = -scale * (*p_nll); 83 t.second.lm_log_prob = -scale * (*p_nll);
81 ++p_nll; 84 ++p_nll;
  85 + // apply LODR to hyp score
  86 + if (lodr_fst_ != nullptr) {
  87 + lodr_fst_->ComputeScore(lodr_scale, &t.second, context_size);
  88 + }
82 } 89 }
83 } 90 }
84 } 91 }
@@ -10,12 +10,24 @@ @@ -10,12 +10,24 @@
10 10
11 #include "onnxruntime_cxx_api.h" // NOLINT 11 #include "onnxruntime_cxx_api.h" // NOLINT
12 #include "sherpa-onnx/csrc/hypothesis.h" 12 #include "sherpa-onnx/csrc/hypothesis.h"
  13 +#include "sherpa-onnx/csrc/lodr-fst.h"
13 #include "sherpa-onnx/csrc/offline-lm-config.h" 14 #include "sherpa-onnx/csrc/offline-lm-config.h"
14 15
15 namespace sherpa_onnx { 16 namespace sherpa_onnx {
16 17
17 class OfflineLM { 18 class OfflineLM {
18 public: 19 public:
  20 + explicit OfflineLM(const OfflineLMConfig &config) : config_(config) {
  21 + if (!config_.lodr_fst.empty()) {
  22 + try {
  23 + lodr_fst_ = std::make_unique<LodrFst>(LodrFst(config_.lodr_fst,
  24 + config_.lodr_backoff_id));
  25 + } catch (const std::exception& e) {
  26 + throw std::runtime_error("Failed to load LODR FST from: " +
  27 + config_.lodr_fst + ". Error: " + e.what());
  28 + }
  29 + }
  30 + }
19 virtual ~OfflineLM() = default; 31 virtual ~OfflineLM() = default;
20 32
21 static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config); 33 static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
@@ -43,6 +55,11 @@ class OfflineLM { @@ -43,6 +55,11 @@ class OfflineLM {
43 // @param hyps It is changed in-place. 55 // @param hyps It is changed in-place.
44 void ComputeLMScore(float scale, int32_t context_size, 56 void ComputeLMScore(float scale, int32_t context_size,
45 std::vector<Hypotheses> *hyps); 57 std::vector<Hypotheses> *hyps);
  58 +
  59 + private:
  60 + std::unique_ptr<LodrFst> lodr_fst_;
  61 + float lodr_scale_;
  62 + OfflineLMConfig config_;
46 }; 63 };
47 64
48 } // namespace sherpa_onnx 65 } // namespace sherpa_onnx
@@ -83,11 +83,11 @@ class OfflineRnnLM::Impl { @@ -83,11 +83,11 @@ class OfflineRnnLM::Impl {
83 }; 83 };
84 84
85 OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) 85 OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
86 - : impl_(std::make_unique<Impl>(config)) {} 86 + : impl_(std::make_unique<Impl>(config)), OfflineLM(config) {}
87 87
88 template <typename Manager> 88 template <typename Manager>
89 OfflineRnnLM::OfflineRnnLM(Manager *mgr, const OfflineLMConfig &config) 89 OfflineRnnLM::OfflineRnnLM(Manager *mgr, const OfflineLMConfig &config)
90 - : impl_(std::make_unique<Impl>(mgr, config)) {} 90 + : impl_(std::make_unique<Impl>(mgr, config)), OfflineLM(config) {}
91 91
92 OfflineRnnLM::~OfflineRnnLM() = default; 92 OfflineRnnLM::~OfflineRnnLM() = default;
93 93
@@ -20,6 +20,10 @@ void OnlineLMConfig::Register(ParseOptions *po) { @@ -20,6 +20,10 @@ void OnlineLMConfig::Register(ParseOptions *po) {
20 "Specify a provider to LM model use: cpu, cuda, coreml"); 20 "Specify a provider to LM model use: cpu, cuda, coreml");
21 po->Register("lm-shallow-fusion", &shallow_fusion, 21 po->Register("lm-shallow-fusion", &shallow_fusion,
22 "Boolean whether to use shallow fusion or rescore."); 22 "Boolean whether to use shallow fusion or rescore.");
  23 + po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model.");
  24 + po->Register("lodr-scale", &lodr_scale, "LODR scale.");
  25 + po->Register("lodr-backoff-id", &lodr_backoff_id,
  26 + "ID of the backoff in the LODR FST. -1 means autodetect");
23 } 27 }
24 28
25 bool OnlineLMConfig::Validate() const { 29 bool OnlineLMConfig::Validate() const {
@@ -28,6 +32,11 @@ bool OnlineLMConfig::Validate() const { @@ -28,6 +32,11 @@ bool OnlineLMConfig::Validate() const {
28 return false; 32 return false;
29 } 33 }
30 34
  35 + if (!lodr_fst.empty() && !FileExists(lodr_fst)) {
  36 + SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str());
  37 + return false;
  38 + }
  39 +
31 return true; 40 return true;
32 } 41 }
33 42
@@ -37,6 +46,9 @@ std::string OnlineLMConfig::ToString() const { @@ -37,6 +46,9 @@ std::string OnlineLMConfig::ToString() const {
37 os << "OnlineLMConfig("; 46 os << "OnlineLMConfig(";
38 os << "model=\"" << model << "\", "; 47 os << "model=\"" << model << "\", ";
39 os << "scale=" << scale << ", "; 48 os << "scale=" << scale << ", ";
  49 + os << "lodr_scale=" << lodr_scale << ", ";
  50 + os << "lodr_fst=\"" << lodr_fst << "\", ";
  51 + os << "lodr_backoff_id=" << lodr_backoff_id << ", ";
40 os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")"; 52 os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")";
41 53
42 return os.str(); 54 return os.str();
@@ -18,18 +18,26 @@ struct OnlineLMConfig { @@ -18,18 +18,26 @@ struct OnlineLMConfig {
18 float scale = 0.5; 18 float scale = 0.5;
19 int32_t lm_num_threads = 1; 19 int32_t lm_num_threads = 1;
20 std::string lm_provider = "cpu"; 20 std::string lm_provider = "cpu";
  21 + std::string lodr_fst;
  22 + float lodr_scale = 0.01;
  23 + int32_t lodr_backoff_id = -1; // -1 means autodetect from the FST
21 // enable shallow fusion 24 // enable shallow fusion
22 bool shallow_fusion = true; 25 bool shallow_fusion = true;
23 26
24 OnlineLMConfig() = default; 27 OnlineLMConfig() = default;
25 28
26 OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, 29 OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
27 - const std::string &lm_provider, bool shallow_fusion) 30 + const std::string &lm_provider, bool shallow_fusion,
  31 + const std::string &lodr_fst, float lodr_scale,
  32 + int32_t lodr_backoff_id)
28 : model(model), 33 : model(model),
29 scale(scale), 34 scale(scale),
30 lm_num_threads(lm_num_threads), 35 lm_num_threads(lm_num_threads),
31 lm_provider(lm_provider), 36 lm_provider(lm_provider),
32 - shallow_fusion(shallow_fusion) {} 37 + shallow_fusion(shallow_fusion),
  38 + lodr_fst(lodr_fst),
  39 + lodr_scale(lodr_scale),
  40 + lodr_backoff_id(lodr_backoff_id) {}
33 41
34 void Register(ParseOptions *po); 42 void Register(ParseOptions *po);
35 bool Validate() const; 43 bool Validate() const;
@@ -12,6 +12,7 @@ @@ -12,6 +12,7 @@
12 12
13 #include "onnxruntime_cxx_api.h" // NOLINT 13 #include "onnxruntime_cxx_api.h" // NOLINT
14 #include "sherpa-onnx/csrc/file-utils.h" 14 #include "sherpa-onnx/csrc/file-utils.h"
  15 +#include "sherpa-onnx/csrc/lodr-fst.h"
15 #include "sherpa-onnx/csrc/macros.h" 16 #include "sherpa-onnx/csrc/macros.h"
16 #include "sherpa-onnx/csrc/onnx-utils.h" 17 #include "sherpa-onnx/csrc/onnx-utils.h"
17 #include "sherpa-onnx/csrc/session.h" 18 #include "sherpa-onnx/csrc/session.h"
@@ -35,12 +36,27 @@ class OnlineRnnLM::Impl { @@ -35,12 +36,27 @@ class OnlineRnnLM::Impl {
35 auto init_states = GetInitStatesSF(); 36 auto init_states = GetInitStatesSF();
36 hyp->nn_lm_scores.value = std::move(init_states.first); 37 hyp->nn_lm_scores.value = std::move(init_states.first);
37 hyp->nn_lm_states = Convert(std::move(init_states.second)); 38 hyp->nn_lm_states = Convert(std::move(init_states.second));
  39 + // If LODR is enabled, initialize the LODR state for this hypothesis
  40 + if (lodr_fst_ != nullptr) {
  41 + hyp->lodr_state = std::make_unique<LodrStateCost>(lodr_fst_.get());
  42 + }
38 } 43 }
39 44
40 // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob 45 // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob
41 const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>(); 46 const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>();
42 hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale; 47 hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale;
43 48
  49 + // If LODR is enabled, advance the LODR state with the new token
  50 + if (lodr_fst_ != nullptr) {
  51 + auto next_lodr_state = std::make_unique<LodrStateCost>(
  52 + hyp->lodr_state->ForwardOneStep(hyp->ys.back()));
  53 + // calculate the score of the latest token
  54 + auto score = next_lodr_state->Score() - hyp->lodr_state->Score();
  55 + hyp->lodr_state = std::move(next_lodr_state);
  56 + // apply LODR to hyp score
  57 + hyp->lm_log_prob += score * config_.lodr_scale;
  58 + }
  59 +
44 // get lm scores for next tokens given the hyp->ys[:] and save to 60 // get lm scores for next tokens given the hyp->ys[:] and save to
45 // nn_lm_scores 61 // nn_lm_scores
46 std::array<int64_t, 2> x_shape{1, 1}; 62 std::array<int64_t, 2> x_shape{1, 1};
@@ -89,6 +105,12 @@ class OnlineRnnLM::Impl { @@ -89,6 +105,12 @@ class OnlineRnnLM::Impl {
89 const float *p_nll = out.first.GetTensorData<float>(); 105 const float *p_nll = out.first.GetTensorData<float>();
90 h.lm_log_prob = -scale * (*p_nll); 106 h.lm_log_prob = -scale * (*p_nll);
91 107
  108 + // apply LODR to hyp score
  109 + if (lodr_fst_ != nullptr) {
  110 + // We scale LODR scale with LM scale to replicate Icefall code
  111 + lodr_fst_->ComputeScore(config_.lodr_scale*scale, &h, context_size);
  112 + }
  113 +
92 // update NN LM states in hyp 114 // update NN LM states in hyp
93 h.nn_lm_states = Convert(std::move(out.second)); 115 h.nn_lm_states = Convert(std::move(out.second));
94 116
@@ -154,6 +176,11 @@ class OnlineRnnLM::Impl { @@ -154,6 +176,11 @@ class OnlineRnnLM::Impl {
154 SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id"); 176 SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id");
155 177
156 ComputeInitStates(); 178 ComputeInitStates();
  179 +
  180 + if (!config_.lodr_fst.empty()) {
  181 + lodr_fst_ = std::make_unique<LodrFst>(LodrFst(config_.lodr_fst,
  182 + config_.lodr_backoff_id));
  183 + }
157 } 184 }
158 185
159 void ComputeInitStates() { 186 void ComputeInitStates() {
@@ -203,6 +230,8 @@ class OnlineRnnLM::Impl { @@ -203,6 +230,8 @@ class OnlineRnnLM::Impl {
203 int32_t rnn_num_layers_ = 2; 230 int32_t rnn_num_layers_ = 2;
204 int32_t rnn_hidden_size_ = 512; 231 int32_t rnn_hidden_size_ = 512;
205 int32_t sos_id_ = 1; 232 int32_t sos_id_ = 1;
  233 +
  234 + std::unique_ptr<LodrFst> lodr_fst_;
206 }; 235 };
207 236
208 OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) 237 OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
@@ -13,13 +13,19 @@ namespace sherpa_onnx { @@ -13,13 +13,19 @@ namespace sherpa_onnx {
13 void PybindOfflineLMConfig(py::module *m) { 13 void PybindOfflineLMConfig(py::module *m) {
14 using PyClass = OfflineLMConfig; 14 using PyClass = OfflineLMConfig;
15 py::class_<PyClass>(*m, "OfflineLMConfig") 15 py::class_<PyClass>(*m, "OfflineLMConfig")
16 - .def(py::init<const std::string &, float, int32_t, const std::string &>(), 16 + .def(py::init<const std::string &, float, int32_t, const std::string &,
  17 + const std::string &, float, int32_t>(),
17 py::arg("model"), py::arg("scale") = 0.5f, 18 py::arg("model"), py::arg("scale") = 0.5f,
18 - py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu") 19 + py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
  20 + py::arg("lodr_fst") = "", py::arg("lodr_scale") = 0.0f,
  21 + py::arg("lodr_backoff_id") = -1)
19 .def_readwrite("model", &PyClass::model) 22 .def_readwrite("model", &PyClass::model)
20 .def_readwrite("scale", &PyClass::scale) 23 .def_readwrite("scale", &PyClass::scale)
21 .def_readwrite("lm_provider", &PyClass::lm_provider) 24 .def_readwrite("lm_provider", &PyClass::lm_provider)
22 .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) 25 .def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
  26 + .def_readwrite("lodr_fst", &PyClass::lodr_fst)
  27 + .def_readwrite("lodr_scale", &PyClass::lodr_scale)
  28 + .def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id)
23 .def("__str__", &PyClass::ToString); 29 .def("__str__", &PyClass::ToString);
24 } 30 }
25 31
@@ -14,15 +14,21 @@ void PybindOnlineLMConfig(py::module *m) { @@ -14,15 +14,21 @@ void PybindOnlineLMConfig(py::module *m) {
14 using PyClass = OnlineLMConfig; 14 using PyClass = OnlineLMConfig;
15 py::class_<PyClass>(*m, "OnlineLMConfig") 15 py::class_<PyClass>(*m, "OnlineLMConfig")
16 .def(py::init<const std::string &, float, int32_t, 16 .def(py::init<const std::string &, float, int32_t,
17 - const std::string &, bool>(), 17 + const std::string &, bool, const std::string &,
  18 + float, int>(),
18 py::arg("model") = "", py::arg("scale") = 0.5f, 19 py::arg("model") = "", py::arg("scale") = 0.5f,
19 py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu", 20 py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu",
20 - py::arg("shallow_fusion") = true) 21 + py::arg("shallow_fusion") = true, py::arg("lodr_fst") = "",
  22 + py::arg("lodr_scale") = 0.0f, py::arg("lodr_backoff_id") = -1)
21 .def_readwrite("model", &PyClass::model) 23 .def_readwrite("model", &PyClass::model)
22 .def_readwrite("scale", &PyClass::scale) 24 .def_readwrite("scale", &PyClass::scale)
23 .def_readwrite("lm_provider", &PyClass::lm_provider) 25 .def_readwrite("lm_provider", &PyClass::lm_provider)
24 .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) 26 .def_readwrite("lm_num_threads", &PyClass::lm_num_threads)
25 .def_readwrite("shallow_fusion", &PyClass::shallow_fusion) 27 .def_readwrite("shallow_fusion", &PyClass::shallow_fusion)
  28 + .def_readwrite("lodr_fst", &PyClass::lodr_fst)
  29 + .def_readwrite("lodr_scale", &PyClass::lodr_scale)
  30 + .def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id)
  31 +
26 .def("__str__", &PyClass::ToString); 32 .def("__str__", &PyClass::ToString);
27 } 33 }
28 34
@@ -69,6 +69,8 @@ class OfflineRecognizer(object): @@ -69,6 +69,8 @@ class OfflineRecognizer(object):
69 hr_dict_dir: str = "", 69 hr_dict_dir: str = "",
70 hr_rule_fsts: str = "", 70 hr_rule_fsts: str = "",
71 hr_lexicon: str = "", 71 hr_lexicon: str = "",
  72 + lodr_fst: str = "",
  73 + lodr_scale: float = 0.0,
72 ): 74 ):
73 """ 75 """
74 Please refer to 76 Please refer to
@@ -133,6 +135,10 @@ class OfflineRecognizer(object): @@ -133,6 +135,10 @@ class OfflineRecognizer(object):
133 rule_fars: 135 rule_fars:
134 If not empty, it specifies fst archives for inverse text normalization. 136 If not empty, it specifies fst archives for inverse text normalization.
135 If there are multiple archives, they are separated by a comma. 137 If there are multiple archives, they are separated by a comma.
  138 + lodr_fst:
  139 + Path to the LODR FST file in binary format. If empty, LODR is disabled.
  140 + lodr_scale:
  141 + Scale factor for LODR rescoring. Only used when lodr_fst is provided.
136 """ 142 """
137 self = cls.__new__(cls) 143 self = cls.__new__(cls)
138 model_config = OfflineModelConfig( 144 model_config = OfflineModelConfig(
@@ -173,6 +179,8 @@ class OfflineRecognizer(object): @@ -173,6 +179,8 @@ class OfflineRecognizer(object):
173 scale=lm_scale, 179 scale=lm_scale,
174 lm_num_threads=num_threads, 180 lm_num_threads=num_threads,
175 lm_provider=provider, 181 lm_provider=provider,
  182 + lodr_fst=lodr_fst,
  183 + lodr_scale=lodr_scale,
176 ) 184 )
177 185
178 recognizer_config = OfflineRecognizerConfig( 186 recognizer_config = OfflineRecognizerConfig(
@@ -89,6 +89,8 @@ class OnlineRecognizer(object): @@ -89,6 +89,8 @@ class OnlineRecognizer(object):
89 hr_dict_dir: str = "", 89 hr_dict_dir: str = "",
90 hr_rule_fsts: str = "", 90 hr_rule_fsts: str = "",
91 hr_lexicon: str = "", 91 hr_lexicon: str = "",
  92 + lodr_fst: str = "",
  93 + lodr_scale: float = 0.0,
92 ): 94 ):
93 """ 95 """
94 Please refer to 96 Please refer to
@@ -216,6 +218,10 @@ class OnlineRecognizer(object): @@ -216,6 +218,10 @@ class OnlineRecognizer(object):
216 "Set path for storing timing cache." TensorRT EP 218 "Set path for storing timing cache." TensorRT EP
217 trt_dump_subgraphs: bool = False, 219 trt_dump_subgraphs: bool = False,
218 "Dump optimized subgraphs for debugging." TensorRT EP 220 "Dump optimized subgraphs for debugging." TensorRT EP
  221 + lodr_fst:
  222 + Path to the LODR FST file in binary format. If empty, LODR is disabled.
  223 + lodr_scale:
  224 + Scale factor for LODR rescoring. Only used when lodr_fst is provided.
219 """ 225 """
220 self = cls.__new__(cls) 226 self = cls.__new__(cls)
221 _assert_file_exists(tokens) 227 _assert_file_exists(tokens)
@@ -298,6 +304,8 @@ class OnlineRecognizer(object): @@ -298,6 +304,8 @@ class OnlineRecognizer(object):
298 model=lm, 304 model=lm,
299 scale=lm_scale, 305 scale=lm_scale,
300 shallow_fusion=lm_shallow_fusion, 306 shallow_fusion=lm_shallow_fusion,
  307 + lodr_fst=lodr_fst,
  308 + lodr_scale=lodr_scale,
301 ) 309 )
302 310
303 recognizer_config = OnlineRecognizerConfig( 311 recognizer_config = OnlineRecognizerConfig(