Committed by
GitHub
Add LODR support to online and offline recognizers (#2026)
This PR integrates LODR (Low-Order Density Ratio) support from Icefall into both online and offline recognizers, enabling LODR for LM shallow fusion and LM rescore. - Extended OnlineLMConfig and OfflineLMConfig to include lodr_fst, lodr_scale, and lodr_backoff_id. - Implemented LodrFst and LodrStateCost classes and wired them into RNN LM scoring in both online and offline code paths. - Updated Python bindings, CLI entry points, examples, and CI test scripts to accept and exercise the new LODR options.
Showing 21 changed files with 613 additions and 14 deletions.
| @@ -281,7 +281,39 @@ time $EXE \ | @@ -281,7 +281,39 @@ time $EXE \ | ||
| 281 | $repo/test_wavs/1.wav \ | 281 | $repo/test_wavs/1.wav \ |
| 282 | $repo/test_wavs/8k.wav | 282 | $repo/test_wavs/8k.wav |
| 283 | 283 | ||
| 284 | -rm -rf $repo | 284 | +lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm |
| 285 | +log "Download pre-trained RNN-LM model from ${lm_repo_url}" | ||
| 286 | +GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url | ||
| 287 | +lm_repo=$(basename $lm_repo_url) | ||
| 288 | +pushd $lm_repo | ||
| 289 | +git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx" | ||
| 290 | +popd | ||
| 291 | + | ||
| 292 | +bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26 | ||
| 293 | +log "Download bi-gram LM from ${bigram_repo_url}" | ||
| 294 | +GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url | ||
| 295 | +bigramlm_repo=$(basename $bigram_repo_url) | ||
| 296 | +pushd $bigramlm_repo | ||
| 297 | +git lfs pull --include "2gram.fst" | ||
| 298 | +popd | ||
| 299 | + | ||
| 300 | +log "Start testing with LM and bi-gram LODR" | ||
| 301 | +# TODO: find test examples that change with the LODR | ||
| 302 | +time $EXE \ | ||
| 303 | + --tokens=$repo/tokens.txt \ | ||
| 304 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 305 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 306 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 307 | + --num-threads=2 \ | ||
| 308 | + --decoding_method="modified_beam_search" \ | ||
| 309 | + --lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \ | ||
| 310 | + --lodr-fst=$bigramlm_repo/2gram.fst \ | ||
| 311 | + --lodr-scale=-0.5 \ | ||
| 312 | + $repo/test_wavs/0.wav \ | ||
| 313 | + $repo/test_wavs/1.wav \ | ||
| 314 | + $repo/test_wavs/8k.wav | ||
| 315 | + | ||
| 316 | +rm -rf $repo $lm_repo $bigramlm_repo | ||
| 285 | 317 | ||
| 286 | log "------------------------------------------------------------" | 318 | log "------------------------------------------------------------" |
| 287 | log "Run Paraformer (Chinese)" | 319 | log "Run Paraformer (Chinese)" |
| @@ -174,7 +174,60 @@ for wave in ${waves[@]}; do | @@ -174,7 +174,60 @@ for wave in ${waves[@]}; do | ||
| 174 | $wave | 174 | $wave |
| 175 | done | 175 | done |
| 176 | 176 | ||
| 177 | -rm -rf $repo | 177 | +lm_repo_url=https://huggingface.co/vsd-vector/icefall-librispeech-rnn-lm |
| 178 | +log "Download pre-trained RNN-LM model from ${lm_repo_url}" | ||
| 179 | +GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url | ||
| 180 | +lm_repo=$(basename $lm_repo_url) | ||
| 181 | +pushd $lm_repo | ||
| 182 | +git lfs pull --include "with-state-epoch-99-avg-1.onnx" | ||
| 183 | +popd | ||
| 184 | + | ||
| 185 | +bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26 | ||
| 186 | +log "Download bi-gram LM from ${bigram_repo_url}" | ||
| 187 | +GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url | ||
| 188 | +bigramlm_repo=$(basename $bigram_repo_url) | ||
| 189 | +pushd $bigramlm_repo | ||
| 190 | +git lfs pull --include "2gram.fst" | ||
| 191 | +popd | ||
| 192 | + | ||
| 193 | +log "Start testing LODR" | ||
| 194 | + | ||
| 195 | +waves=( | ||
| 196 | +$repo/test_wavs/0.wav | ||
| 197 | +$repo/test_wavs/1.wav | ||
| 198 | +$repo/test_wavs/8k.wav | ||
| 199 | +) | ||
| 200 | + | ||
| 201 | +for wave in ${waves[@]}; do | ||
| 202 | + time $EXE \ | ||
| 203 | + --tokens=$repo/tokens.txt \ | ||
| 204 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 205 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 206 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 207 | + --num-threads=2 \ | ||
| 208 | + --decoding_method="modified_beam_search" \ | ||
| 209 | + --lm=$lm_repo/with-state-epoch-99-avg-1.onnx \ | ||
| 210 | + --lodr-fst=$bigramlm_repo/2gram.fst \ | ||
| 211 | + --lodr-scale=-0.5 \ | ||
| 212 | + $wave | ||
| 213 | +done | ||
| 214 | + | ||
| 215 | +for wave in ${waves[@]}; do | ||
| 216 | + time $EXE \ | ||
| 217 | + --tokens=$repo/tokens.txt \ | ||
| 218 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 219 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 220 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 221 | + --num-threads=2 \ | ||
| 222 | + --decoding_method="modified_beam_search" \ | ||
| 223 | + --lm=$lm_repo/with-state-epoch-99-avg-1.onnx \ | ||
| 224 | + --lodr-fst=$bigramlm_repo/2gram.fst \ | ||
| 225 | + --lodr-scale=-0.5 \ | ||
| 226 | + --lm-shallow-fusion=true \ | ||
| 227 | + $wave | ||
| 228 | +done | ||
| 229 | + | ||
| 230 | +rm -rf $repo $bigramlm_repo $lm_repo | ||
| 178 | 231 | ||
| 179 | log "------------------------------------------------------------" | 232 | log "------------------------------------------------------------" |
| 180 | log "Run streaming Zipformer transducer (Bilingual, Chinese + English)" | 233 | log "Run streaming Zipformer transducer (Bilingual, Chinese + English)" |
| @@ -562,9 +562,39 @@ python3 ./python-api-examples/offline-decode-files.py \ | @@ -562,9 +562,39 @@ python3 ./python-api-examples/offline-decode-files.py \ | ||
| 562 | $repo/test_wavs/1.wav \ | 562 | $repo/test_wavs/1.wav \ |
| 563 | $repo/test_wavs/8k.wav | 563 | $repo/test_wavs/8k.wav |
| 564 | 564 | ||
| 565 | +lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm | ||
| 566 | +log "Download pre-trained RNN-LM model from ${lm_repo_url}" | ||
| 567 | +GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url | ||
| 568 | +lm_repo=$(basename $lm_repo_url) | ||
| 569 | +pushd $lm_repo | ||
| 570 | +git lfs pull --include "exp/no-state-epoch-99-avg-1.onnx" | ||
| 571 | +popd | ||
| 572 | + | ||
| 573 | +bigram_repo_url=https://huggingface.co/vsd-vector/librispeech_bigram_sherpa-onnx-zipformer-large-en-2023-06-26 | ||
| 574 | +log "Download bi-gram LM from ${bigram_repo_url}" | ||
| 575 | +GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url | ||
| 576 | +bigramlm_repo=$(basename $bigram_repo_url) | ||
| 577 | +pushd $bigramlm_repo | ||
| 578 | +git lfs pull --include "2gram.fst" | ||
| 579 | +popd | ||
| 580 | + | ||
| 581 | +log "Perform offline decoding with RNN-LM and LODR" | ||
| 582 | +python3 ./python-api-examples/offline-decode-files.py \ | ||
| 583 | + --tokens=$repo/tokens.txt \ | ||
| 584 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 585 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 586 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 587 | + --decoding-method=modified_beam_search \ | ||
| 588 | + --lm=$lm_repo/exp/no-state-epoch-99-avg-1.onnx \ | ||
| 589 | + --lodr-fst=$bigramlm_repo/2gram.fst \ | ||
| 590 | + --lodr-scale=-0.5 \ | ||
| 591 | + $repo/test_wavs/0.wav \ | ||
| 592 | + $repo/test_wavs/1.wav \ | ||
| 593 | + $repo/test_wavs/8k.wav | ||
| 594 | + | ||
| 565 | python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose | 595 | python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose |
| 566 | 596 | ||
| 567 | -rm -rf $repo | 597 | +rm -rf $repo $lm_repo $bigramlm_repo |
| 568 | 598 | ||
| 569 | log "Test non-streaming paraformer models" | 599 | log "Test non-streaming paraformer models" |
| 570 | 600 |
| @@ -35,6 +35,25 @@ file(s) with a non-streaming model. | @@ -35,6 +35,25 @@ file(s) with a non-streaming model. | ||
| 35 | /path/to/0.wav \ | 35 | /path/to/0.wav \ |
| 36 | /path/to/1.wav | 36 | /path/to/1.wav |
| 37 | 37 | ||
| 38 | + also with RNN LM rescoring and LODR (optional): | ||
| 39 | + | ||
| 40 | + ./python-api-examples/offline-decode-files.py \ | ||
| 41 | + --tokens=/path/to/tokens.txt \ | ||
| 42 | + --encoder=/path/to/encoder.onnx \ | ||
| 43 | + --decoder=/path/to/decoder.onnx \ | ||
| 44 | + --joiner=/path/to/joiner.onnx \ | ||
| 45 | + --num-threads=2 \ | ||
| 46 | + --decoding-method=modified_beam_search \ | ||
| 47 | + --debug=false \ | ||
| 48 | + --sample-rate=16000 \ | ||
| 49 | + --feature-dim=80 \ | ||
| 50 | + --lm=/path/to/lm.onnx \ | ||
| 51 | + --lm-scale=0.1 \ | ||
| 52 | + --lodr-fst=/path/to/lodr.fst \ | ||
| 53 | + --lodr-scale=-0.1 \ | ||
| 54 | + /path/to/0.wav \ | ||
| 55 | + /path/to/1.wav | ||
| 56 | + | ||
| 38 | (3) For CTC models from NeMo | 57 | (3) For CTC models from NeMo |
| 39 | 58 | ||
| 40 | python3 ./python-api-examples/offline-decode-files.py \ | 59 | python3 ./python-api-examples/offline-decode-files.py \ |
| @@ -269,6 +288,39 @@ def get_args(): | @@ -269,6 +288,39 @@ def get_args(): | ||
| 269 | default="greedy_search", | 288 | default="greedy_search", |
| 270 | help="Valid values are greedy_search and modified_beam_search", | 289 | help="Valid values are greedy_search and modified_beam_search", |
| 271 | ) | 290 | ) |
| 291 | + | ||
| 292 | + parser.add_argument( | ||
| 293 | + "--lm", | ||
| 294 | + metavar="file", | ||
| 295 | + type=str, | ||
| 296 | + default="", | ||
| 297 | + help="Path to RNN LM model", | ||
| 298 | + ) | ||
| 299 | + | ||
| 300 | + parser.add_argument( | ||
| 301 | + "--lm-scale", | ||
| 302 | + metavar="lm_scale", | ||
| 303 | + type=float, | ||
| 304 | + default=0.1, | ||
| 305 | + help="LM model scale for rescoring", | ||
| 306 | + ) | ||
| 307 | + | ||
| 308 | + parser.add_argument( | ||
| 309 | + "--lodr-fst", | ||
| 310 | + metavar="file", | ||
| 311 | + type=str, | ||
| 312 | + default="", | ||
| 313 | + help="Path to LODR FST model. Used only when --lm is given.", | ||
| 314 | + ) | ||
| 315 | + | ||
| 316 | + parser.add_argument( | ||
| 317 | + "--lodr-scale", | ||
| 318 | + metavar="lodr_scale", | ||
| 319 | + type=float, | ||
| 320 | + default=-0.1, | ||
| 321 | + help="LODR scale for rescoring.Used only when --lodr_fst is given.", | ||
| 322 | + ) | ||
| 323 | + | ||
| 272 | parser.add_argument( | 324 | parser.add_argument( |
| 273 | "--debug", | 325 | "--debug", |
| 274 | type=bool, | 326 | type=bool, |
| @@ -364,6 +416,10 @@ def main(): | @@ -364,6 +416,10 @@ def main(): | ||
| 364 | num_threads=args.num_threads, | 416 | num_threads=args.num_threads, |
| 365 | sample_rate=args.sample_rate, | 417 | sample_rate=args.sample_rate, |
| 366 | feature_dim=args.feature_dim, | 418 | feature_dim=args.feature_dim, |
| 419 | + lm=args.lm, | ||
| 420 | + lm_scale=args.lm_scale, | ||
| 421 | + lodr_fst=args.lodr_fst, | ||
| 422 | + lodr_scale=args.lodr_scale, | ||
| 367 | decoding_method=args.decoding_method, | 423 | decoding_method=args.decoding_method, |
| 368 | hotwords_file=args.hotwords_file, | 424 | hotwords_file=args.hotwords_file, |
| 369 | hotwords_score=args.hotwords_score, | 425 | hotwords_score=args.hotwords_score, |
| @@ -21,6 +21,22 @@ rm sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2 | @@ -21,6 +21,22 @@ rm sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2 | ||
| 21 | ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \ | 21 | ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \ |
| 22 | ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav | 22 | ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav |
| 23 | 23 | ||
| 24 | +or with RNN LM rescoring and LODR: | ||
| 25 | + | ||
| 26 | +./python-api-examples/online-decode-files.py \ | ||
| 27 | + --tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \ | ||
| 28 | + --encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \ | ||
| 29 | + --decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \ | ||
| 30 | + --joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \ | ||
| 31 | + --decoding-method=modified_beam_search \ | ||
| 32 | + --lm=/path/to/lm.onnx \ | ||
| 33 | + --lm-scale=0.1 \ | ||
| 34 | + --lodr-fst=/path/to/lodr.fst \ | ||
| 35 | + --lodr-scale=-0.1 \ | ||
| 36 | + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \ | ||
| 37 | + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \ | ||
| 38 | + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav | ||
| 39 | + | ||
| 24 | (2) Streaming paraformer | 40 | (2) Streaming paraformer |
| 25 | 41 | ||
| 26 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 | 42 | curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 |
| @@ -187,6 +203,22 @@ def get_args(): | @@ -187,6 +203,22 @@ def get_args(): | ||
| 187 | ) | 203 | ) |
| 188 | 204 | ||
| 189 | parser.add_argument( | 205 | parser.add_argument( |
| 206 | + "--lodr-fst", | ||
| 207 | + metavar="file", | ||
| 208 | + type=str, | ||
| 209 | + default="", | ||
| 210 | + help="Path to LODR FST model. Used only when --lm is given.", | ||
| 211 | + ) | ||
| 212 | + | ||
| 213 | + parser.add_argument( | ||
| 214 | + "--lodr-scale", | ||
| 215 | + metavar="lodr_scale", | ||
| 216 | + type=float, | ||
| 217 | + default=-0.1, | ||
| 218 | + help="LODR scale for rescoring.Used only when --lodr_fst is given.", | ||
| 219 | + ) | ||
| 220 | + | ||
| 221 | + parser.add_argument( | ||
| 190 | "--provider", | 222 | "--provider", |
| 191 | type=str, | 223 | type=str, |
| 192 | default="cpu", | 224 | default="cpu", |
| @@ -320,6 +352,8 @@ def main(): | @@ -320,6 +352,8 @@ def main(): | ||
| 320 | max_active_paths=args.max_active_paths, | 352 | max_active_paths=args.max_active_paths, |
| 321 | lm=args.lm, | 353 | lm=args.lm, |
| 322 | lm_scale=args.lm_scale, | 354 | lm_scale=args.lm_scale, |
| 355 | + lodr_fst=args.lodr_fst, | ||
| 356 | + lodr_scale=args.lodr_scale, | ||
| 323 | hotwords_file=args.hotwords_file, | 357 | hotwords_file=args.hotwords_file, |
| 324 | hotwords_score=args.hotwords_score, | 358 | hotwords_score=args.hotwords_score, |
| 325 | modeling_unit=args.modeling_unit, | 359 | modeling_unit=args.modeling_unit, |
| @@ -25,6 +25,7 @@ set(sources | @@ -25,6 +25,7 @@ set(sources | ||
| 25 | jieba.cc | 25 | jieba.cc |
| 26 | keyword-spotter-impl.cc | 26 | keyword-spotter-impl.cc |
| 27 | keyword-spotter.cc | 27 | keyword-spotter.cc |
| 28 | + lodr-fst.cc | ||
| 28 | offline-canary-model-config.cc | 29 | offline-canary-model-config.cc |
| 29 | offline-canary-model.cc | 30 | offline-canary-model.cc |
| 30 | offline-ctc-fst-decoder-config.cc | 31 | offline-ctc-fst-decoder-config.cc |
| @@ -12,9 +12,11 @@ | @@ -12,9 +12,11 @@ | ||
| 12 | #include <unordered_map> | 12 | #include <unordered_map> |
| 13 | #include <utility> | 13 | #include <utility> |
| 14 | #include <vector> | 14 | #include <vector> |
| 15 | +#include <memory> | ||
| 15 | 16 | ||
| 16 | #include "onnxruntime_cxx_api.h" // NOLINT | 17 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 17 | #include "sherpa-onnx/csrc/context-graph.h" | 18 | #include "sherpa-onnx/csrc/context-graph.h" |
| 19 | +#include "sherpa-onnx/csrc/lodr-fst.h" | ||
| 18 | #include "sherpa-onnx/csrc/math.h" | 20 | #include "sherpa-onnx/csrc/math.h" |
| 19 | #include "sherpa-onnx/csrc/onnx-utils.h" | 21 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 20 | 22 | ||
| @@ -61,6 +63,9 @@ struct Hypothesis { | @@ -61,6 +63,9 @@ struct Hypothesis { | ||
| 61 | // the nn lm states | 63 | // the nn lm states |
| 62 | std::vector<CopyableOrtValue> nn_lm_states; | 64 | std::vector<CopyableOrtValue> nn_lm_states; |
| 63 | 65 | ||
| 66 | + // the LODR states | ||
| 67 | + std::shared_ptr<LodrStateCost> lodr_state; | ||
| 68 | + | ||
| 64 | const ContextState *context_state; | 69 | const ContextState *context_state; |
| 65 | 70 | ||
| 66 | // TODO(fangjun): Make it configurable | 71 | // TODO(fangjun): Make it configurable |
sherpa-onnx/csrc/lodr-fst.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/lodr-fst.cc | ||
| 2 | +// | ||
| 3 | +// Contains code copied from icefall/utils/ngram_lm.py | ||
| 4 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 5 | +// | ||
| 6 | +// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs) | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/lodr-fst.h" | ||
| 13 | +#include "sherpa-onnx/csrc/log.h" | ||
| 14 | +#include "sherpa-onnx/csrc/hypothesis.h" | ||
| 15 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +int32_t LodrFst::FindBackoffId() { | ||
| 20 | + // assume that the backoff id is the only input label with epsilon output | ||
| 21 | + | ||
| 22 | + for (int32_t state = 0; state < fst_->NumStates(); ++state) { | ||
| 23 | + fst::ArcIterator<fst::StdConstFst> arc_iter(*fst_, state); | ||
| 24 | + for ( ; !arc_iter.Done(); arc_iter.Next()) { | ||
| 25 | + const auto& arc = arc_iter.Value(); | ||
| 26 | + if (arc.olabel == 0) { // Check if the output label is epsilon (0) | ||
| 27 | + return arc.ilabel; // Return the input label | ||
| 28 | + } | ||
| 29 | + } | ||
| 30 | + } | ||
| 31 | + | ||
| 32 | + return -1; // Return -1 if no such input symbol is found | ||
| 33 | +} | ||
| 34 | + | ||
| 35 | +LodrFst::LodrFst(const std::string &fst_path, int32_t backoff_id) | ||
| 36 | + : backoff_id_(backoff_id) { | ||
| 37 | + fst_ = std::unique_ptr<fst::StdConstFst>( | ||
| 38 | + CastOrConvertToConstFst(fst::StdVectorFst::Read(fst_path))); | ||
| 39 | + | ||
| 40 | + if (backoff_id < 0) { | ||
| 41 | + // backoff_id_ is not provided, find it automatically | ||
| 42 | + backoff_id_ = FindBackoffId(); | ||
| 43 | + if (backoff_id_ < 0) { | ||
| 44 | + std::string err_msg = "Failed to initialize LODR: No backoff arc found"; | ||
| 45 | + SHERPA_ONNX_LOGE("%s", err_msg.c_str()); | ||
| 46 | + SHERPA_ONNX_EXIT(-1); | ||
| 47 | + } | ||
| 48 | + } | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +std::vector<std::tuple<int32_t, float>> LodrFst::ProcessBackoffArcs( | ||
| 52 | + int32_t state, float cost) { | ||
| 53 | + std::vector<std::tuple<int32_t, float>> ans; | ||
| 54 | + auto next = GetNextStatesCostsNoBackoff(state, backoff_id_); | ||
| 55 | + if (!next.has_value()) { | ||
| 56 | + return ans; | ||
| 57 | + } | ||
| 58 | + auto [next_state, next_cost] = next.value(); | ||
| 59 | + ans.emplace_back(next_state, next_cost + cost); | ||
| 60 | + auto recursive_result = ProcessBackoffArcs(next_state, next_cost + cost); | ||
| 61 | + ans.insert(ans.end(), recursive_result.begin(), recursive_result.end()); | ||
| 62 | + return ans; | ||
| 63 | +} | ||
| 64 | + | ||
| 65 | +std::optional<std::tuple<int32_t, float>> LodrFst::GetNextStatesCostsNoBackoff( | ||
| 66 | + int32_t state, int32_t label) { | ||
| 67 | + fst::ArcIterator<fst::StdConstFst> arc_iter(*fst_, state); | ||
| 68 | + int32_t num_arcs = fst_->NumArcs(state); | ||
| 69 | + | ||
| 70 | + int32_t left = 0, right = num_arcs - 1; | ||
| 71 | + while (left <= right) { | ||
| 72 | + int32_t mid = (left + right) / 2; | ||
| 73 | + arc_iter.Seek(mid); | ||
| 74 | + auto arc = arc_iter.Value(); | ||
| 75 | + if (arc.ilabel < label) { | ||
| 76 | + left = mid + 1; | ||
| 77 | + } else if (arc.ilabel > label) { | ||
| 78 | + right = mid - 1; | ||
| 79 | + } else { | ||
| 80 | + return std::make_tuple(arc.nextstate, arc.weight.Value()); | ||
| 81 | + } | ||
| 82 | + } | ||
| 83 | + return std::nullopt; | ||
| 84 | +} | ||
| 85 | + | ||
| 86 | +std::pair<std::vector<int32_t>, std::vector<float>> LodrFst::GetNextStateCosts( | ||
| 87 | + int32_t state, int32_t label) { | ||
| 88 | + std::vector<int32_t> states = {state}; | ||
| 89 | + std::vector<float> costs = {0}; | ||
| 90 | + | ||
| 91 | + auto extra_states_costs = ProcessBackoffArcs(state, 0); | ||
| 92 | + for (const auto& [s, c] : extra_states_costs) { | ||
| 93 | + states.push_back(s); | ||
| 94 | + costs.push_back(c); | ||
| 95 | + } | ||
| 96 | + | ||
| 97 | + std::vector<int32_t> next_states; | ||
| 98 | + std::vector<float> next_costs; | ||
| 99 | + for (size_t i = 0; i < states.size(); ++i) { | ||
| 100 | + auto next = GetNextStatesCostsNoBackoff(states[i], label); | ||
| 101 | + if (next.has_value()) { | ||
| 102 | + auto [ns, nc] = next.value(); | ||
| 103 | + next_states.push_back(ns); | ||
| 104 | + next_costs.push_back(costs[i] + nc); | ||
| 105 | + } | ||
| 106 | + } | ||
| 107 | + | ||
| 108 | + return std::make_pair(next_states, next_costs); | ||
| 109 | +} | ||
| 110 | + | ||
| 111 | +void LodrFst::ComputeScore(float scale, Hypothesis *hyp, int32_t offset) { | ||
| 112 | + if (scale == 0) { | ||
| 113 | + return; | ||
| 114 | + } | ||
| 115 | + | ||
| 116 | + hyp->lodr_state = std::make_unique<LodrStateCost>(this); | ||
| 117 | + | ||
| 118 | + // Walk through the FST with the input text from the hypothesis | ||
| 119 | + for (size_t i = offset; i < hyp->ys.size(); ++i) { | ||
| 120 | + *hyp->lodr_state = hyp->lodr_state->ForwardOneStep(hyp->ys[i]); | ||
| 121 | + } | ||
| 122 | + | ||
| 123 | + float lodr_score = hyp->lodr_state->FinalScore(); | ||
| 124 | + | ||
| 125 | + if (lodr_score == -std::numeric_limits<float>::infinity()) { | ||
| 126 | + SHERPA_ONNX_LOGE("Failed to compute LODR. Empty or mismatched FST?"); | ||
| 127 | + return; | ||
| 128 | + } | ||
| 129 | + | ||
| 130 | + // Update the hyp score | ||
| 131 | + hyp->log_prob += scale * lodr_score; | ||
| 132 | +} | ||
| 133 | + | ||
| 134 | +float LodrFst::GetFinalCost(int32_t state) { | ||
| 135 | + auto final_weight = fst_->Final(state); | ||
| 136 | + if (final_weight == fst::StdArc::Weight::Zero()) { | ||
| 137 | + return 0.0; | ||
| 138 | + } | ||
| 139 | + return final_weight.Value(); | ||
| 140 | +} | ||
| 141 | + | ||
| 142 | +LodrStateCost::LodrStateCost( | ||
| 143 | + LodrFst* fst, const std::unordered_map<int32_t, float> &state_cost) | ||
| 144 | + : fst_(fst) { | ||
| 145 | + if (state_cost.empty()) { | ||
| 146 | + state_cost_[0] = 0.0; | ||
| 147 | + } else { | ||
| 148 | + state_cost_ = state_cost; | ||
| 149 | + } | ||
| 150 | +} | ||
| 151 | + | ||
| 152 | +LodrStateCost LodrStateCost::ForwardOneStep(int32_t label) { | ||
| 153 | + std::unordered_map<int32_t, float> state_cost; | ||
| 154 | + for (const auto& [s, c] : state_cost_) { | ||
| 155 | + auto [next_states, next_costs] = fst_->GetNextStateCosts(s, label); | ||
| 156 | + for (size_t i = 0; i < next_states.size(); ++i) { | ||
| 157 | + int32_t ns = next_states[i]; | ||
| 158 | + float nc = next_costs[i]; | ||
| 159 | + if (state_cost.find(ns) == state_cost.end()) { | ||
| 160 | + state_cost[ns] = std::numeric_limits<float>::infinity(); | ||
| 161 | + } | ||
| 162 | + state_cost[ns] = std::min(state_cost[ns], c + nc); | ||
| 163 | + } | ||
| 164 | + } | ||
| 165 | + return LodrStateCost(fst_, state_cost); | ||
| 166 | +} | ||
| 167 | + | ||
| 168 | +float LodrStateCost::Score() const { | ||
| 169 | + if (state_cost_.empty()) { | ||
| 170 | + return -std::numeric_limits<float>::infinity(); | ||
| 171 | + } | ||
| 172 | + auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(), | ||
| 173 | + [](const auto& a, const auto& b) { | ||
| 174 | + return a.second < b.second; | ||
| 175 | + }); | ||
| 176 | + return -min_cost->second; | ||
| 177 | +} | ||
| 178 | + | ||
| 179 | +float LodrStateCost::FinalScore() const { | ||
| 180 | + if (state_cost_.empty()) { | ||
| 181 | + return -std::numeric_limits<float>::infinity(); | ||
| 182 | + } | ||
| 183 | + auto min_cost = std::min_element(state_cost_.begin(), state_cost_.end(), | ||
| 184 | + [](const auto& a, const auto& b) { | ||
| 185 | + return a.second < b.second; | ||
| 186 | + }); | ||
| 187 | + return -(min_cost->second + | ||
| 188 | + fst_->GetFinalCost(min_cost->first)); | ||
| 189 | +} | ||
| 190 | + | ||
| 191 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/lodr-fst.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/lodr-fst.h | ||
| 2 | +// | ||
| 3 | +// Contains code copied from icefall/utils/ngram_lm.py | ||
| 4 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 5 | +// | ||
| 6 | +// Copyright (c) 2025 Tilde SIA (Askars Salimbajevs) | ||
| 7 | + | ||
| 8 | + | ||
| 9 | +#ifndef SHERPA_ONNX_CSRC_LODR_FST_H_ | ||
| 10 | +#define SHERPA_ONNX_CSRC_LODR_FST_H_ | ||
| 11 | + | ||
| 12 | +#include <memory> | ||
| 13 | +#include <string> | ||
| 14 | +#include <vector> | ||
| 15 | +#include <optional> | ||
| 16 | +#include <tuple> | ||
| 17 | +#include <unordered_map> | ||
| 18 | +#include <limits> | ||
| 19 | +#include <algorithm> | ||
| 20 | +#include <utility> | ||
| 21 | + | ||
| 22 | +#include "kaldifst/csrc/kaldi-fst-io.h" | ||
| 23 | + | ||
| 24 | +namespace sherpa_onnx { | ||
| 25 | + | ||
| 26 | +class Hypothesis; | ||
| 27 | + | ||
| 28 | +class LodrFst { | ||
| 29 | + public: | ||
| 30 | + explicit LodrFst(const std::string &fst_path, int32_t backoff_id = -1); | ||
| 31 | + | ||
| 32 | + std::pair<std::vector<int32_t>, std::vector<float>> GetNextStateCosts( | ||
| 33 | + int32_t state, int32_t label); | ||
| 34 | + | ||
| 35 | + float GetFinalCost(int32_t state); | ||
| 36 | + | ||
| 37 | + void ComputeScore(float scale, Hypothesis *hyp, int32_t offset); | ||
| 38 | + | ||
| 39 | + private: | ||
| 40 | + fst::StdVectorFst YsToFst(const std::vector<int64_t> &ys, int32_t offset); | ||
| 41 | + | ||
| 42 | + std::vector<std::tuple<int32_t, float>> ProcessBackoffArcs( | ||
| 43 | + int32_t state, float cost); | ||
| 44 | + | ||
| 45 | + std::optional<std::tuple<int32_t, float>> GetNextStatesCostsNoBackoff( | ||
| 46 | + int32_t state, int32_t label); | ||
| 47 | + | ||
| 48 | + int32_t FindBackoffId(); | ||
| 49 | + | ||
| 50 | + | ||
| 51 | + int32_t backoff_id_ = -1; | ||
| 52 | + std::unique_ptr<fst::StdConstFst> fst_; // owned by this class | ||
| 53 | +}; | ||
| 54 | + | ||
| 55 | +class LodrStateCost { | ||
| 56 | + public: | ||
| 57 | + explicit LodrStateCost( | ||
| 58 | + LodrFst* fst, | ||
| 59 | + const std::unordered_map<int32_t, float> &state_cost = {}); | ||
| 60 | + | ||
| 61 | + LodrStateCost ForwardOneStep(int32_t label); | ||
| 62 | + | ||
| 63 | + float Score() const; | ||
| 64 | + float FinalScore() const; | ||
| 65 | + | ||
| 66 | + private: | ||
| 67 | + // The fst_ is not owned by this class and borrowed from the caller | ||
| 68 | + // (e.g. OnlineRnnLM). | ||
| 69 | + LodrFst* fst_; | ||
| 70 | + std::unordered_map<int32_t, float> state_cost_; | ||
| 71 | +}; | ||
| 72 | + | ||
| 73 | +} // namespace sherpa_onnx | ||
| 74 | + | ||
| 75 | +#endif // SHERPA_ONNX_CSRC_LODR_FST_H_ |
| @@ -18,6 +18,10 @@ void OfflineLMConfig::Register(ParseOptions *po) { | @@ -18,6 +18,10 @@ void OfflineLMConfig::Register(ParseOptions *po) { | ||
| 18 | "Number of threads to run the neural network of LM model"); | 18 | "Number of threads to run the neural network of LM model"); |
| 19 | po->Register("lm-provider", &lm_provider, | 19 | po->Register("lm-provider", &lm_provider, |
| 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); | 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); |
| 21 | + po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model."); | ||
| 22 | + po->Register("lodr-scale", &lodr_scale, "LODR scale."); | ||
| 23 | + po->Register("lodr-backoff-id", &lodr_backoff_id, | ||
| 24 | + "ID of the backoff in the LODR FST. -1 means autodetect"); | ||
| 21 | } | 25 | } |
| 22 | 26 | ||
| 23 | bool OfflineLMConfig::Validate() const { | 27 | bool OfflineLMConfig::Validate() const { |
| @@ -26,6 +30,11 @@ bool OfflineLMConfig::Validate() const { | @@ -26,6 +30,11 @@ bool OfflineLMConfig::Validate() const { | ||
| 26 | return false; | 30 | return false; |
| 27 | } | 31 | } |
| 28 | 32 | ||
| 33 | + if (!lodr_fst.empty() && !FileExists(lodr_fst)) { | ||
| 34 | + SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str()); | ||
| 35 | + return false; | ||
| 36 | + } | ||
| 37 | + | ||
| 29 | return true; | 38 | return true; |
| 30 | } | 39 | } |
| 31 | 40 | ||
| @@ -34,7 +43,10 @@ std::string OfflineLMConfig::ToString() const { | @@ -34,7 +43,10 @@ std::string OfflineLMConfig::ToString() const { | ||
| 34 | 43 | ||
| 35 | os << "OfflineLMConfig("; | 44 | os << "OfflineLMConfig("; |
| 36 | os << "model=\"" << model << "\", "; | 45 | os << "model=\"" << model << "\", "; |
| 37 | - os << "scale=" << scale << ")"; | 46 | + os << "scale=" << scale << ", "; |
| 47 | + os << "lodr_scale=" << lodr_scale << ", "; | ||
| 48 | + os << "lodr_fst=\"" << lodr_fst << "\", "; | ||
| 49 | + os << "lodr_backoff_id=" << lodr_backoff_id << ")"; | ||
| 38 | 50 | ||
| 39 | return os.str(); | 51 | return os.str(); |
| 40 | } | 52 | } |
| @@ -19,14 +19,23 @@ struct OfflineLMConfig { | @@ -19,14 +19,23 @@ struct OfflineLMConfig { | ||
| 19 | int32_t lm_num_threads = 1; | 19 | int32_t lm_num_threads = 1; |
| 20 | std::string lm_provider = "cpu"; | 20 | std::string lm_provider = "cpu"; |
| 21 | 21 | ||
| 22 | + // LODR | ||
| 23 | + std::string lodr_fst; | ||
| 24 | + float lodr_scale = 0.01; | ||
| 25 | + int32_t lodr_backoff_id = -1; // -1 means not set | ||
| 26 | + | ||
| 22 | OfflineLMConfig() = default; | 27 | OfflineLMConfig() = default; |
| 23 | 28 | ||
| 24 | OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, | 29 | OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, |
| 25 | - const std::string &lm_provider) | 30 | + const std::string &lm_provider, const std::string &lodr_fst, |
| 31 | + float lodr_scale, int32_t lodr_backoff_id) | ||
| 26 | : model(model), | 32 | : model(model), |
| 27 | scale(scale), | 33 | scale(scale), |
| 28 | lm_num_threads(lm_num_threads), | 34 | lm_num_threads(lm_num_threads), |
| 29 | - lm_provider(lm_provider) {} | 35 | + lm_provider(lm_provider), |
| 36 | + lodr_fst(lodr_fst), | ||
| 37 | + lodr_scale(lodr_scale), | ||
| 38 | + lodr_backoff_id(lodr_backoff_id) {} | ||
| 30 | 39 | ||
| 31 | void Register(ParseOptions *po); | 40 | void Register(ParseOptions *po); |
| 32 | bool Validate() const; | 41 | bool Validate() const; |
| @@ -17,6 +17,7 @@ | @@ -17,6 +17,7 @@ | ||
| 17 | #include "rawfile/raw_file_manager.h" | 17 | #include "rawfile/raw_file_manager.h" |
| 18 | #endif | 18 | #endif |
| 19 | 19 | ||
| 20 | +#include "sherpa-onnx/csrc/lodr-fst.h" | ||
| 20 | #include "sherpa-onnx/csrc/offline-rnn-lm.h" | 21 | #include "sherpa-onnx/csrc/offline-rnn-lm.h" |
| 21 | 22 | ||
| 22 | namespace sherpa_onnx { | 23 | namespace sherpa_onnx { |
| @@ -74,11 +75,17 @@ void OfflineLM::ComputeLMScore(float scale, int32_t context_size, | @@ -74,11 +75,17 @@ void OfflineLM::ComputeLMScore(float scale, int32_t context_size, | ||
| 74 | } | 75 | } |
| 75 | auto negative_loglike = Rescore(std::move(x), std::move(x_lens)); | 76 | auto negative_loglike = Rescore(std::move(x), std::move(x_lens)); |
| 76 | const float *p_nll = negative_loglike.GetTensorData<float>(); | 77 | const float *p_nll = negative_loglike.GetTensorData<float>(); |
| 78 | + // We scale LODR scale with LM scale to replicate Icefall code | ||
| 79 | + auto lodr_scale = config_.lodr_scale * scale; | ||
| 77 | for (auto &h : *hyps) { | 80 | for (auto &h : *hyps) { |
| 78 | for (auto &t : h) { | 81 | for (auto &t : h) { |
| 79 | // Use -scale here since we want to change negative loglike to loglike. | 82 | // Use -scale here since we want to change negative loglike to loglike. |
| 80 | t.second.lm_log_prob = -scale * (*p_nll); | 83 | t.second.lm_log_prob = -scale * (*p_nll); |
| 81 | ++p_nll; | 84 | ++p_nll; |
| 85 | + // apply LODR to hyp score | ||
| 86 | + if (lodr_fst_ != nullptr) { | ||
| 87 | + lodr_fst_->ComputeScore(lodr_scale, &t.second, context_size); | ||
| 88 | + } | ||
| 82 | } | 89 | } |
| 83 | } | 90 | } |
| 84 | } | 91 | } |
| @@ -10,12 +10,24 @@ | @@ -10,12 +10,24 @@ | ||
| 10 | 10 | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 11 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | #include "sherpa-onnx/csrc/hypothesis.h" | 12 | #include "sherpa-onnx/csrc/hypothesis.h" |
| 13 | +#include "sherpa-onnx/csrc/lodr-fst.h" | ||
| 13 | #include "sherpa-onnx/csrc/offline-lm-config.h" | 14 | #include "sherpa-onnx/csrc/offline-lm-config.h" |
| 14 | 15 | ||
| 15 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 16 | 17 | ||
| 17 | class OfflineLM { | 18 | class OfflineLM { |
| 18 | public: | 19 | public: |
| 20 | + explicit OfflineLM(const OfflineLMConfig &config) : config_(config) { | ||
| 21 | + if (!config_.lodr_fst.empty()) { | ||
| 22 | + try { | ||
| 23 | + lodr_fst_ = std::make_unique<LodrFst>(LodrFst(config_.lodr_fst, | ||
| 24 | + config_.lodr_backoff_id)); | ||
| 25 | + } catch (const std::exception& e) { | ||
| 26 | + throw std::runtime_error("Failed to load LODR FST from: " + | ||
| 27 | + config_.lodr_fst + ". Error: " + e.what()); | ||
| 28 | + } | ||
| 29 | + } | ||
| 30 | + } | ||
| 19 | virtual ~OfflineLM() = default; | 31 | virtual ~OfflineLM() = default; |
| 20 | 32 | ||
| 21 | static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config); | 33 | static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config); |
| @@ -43,6 +55,11 @@ class OfflineLM { | @@ -43,6 +55,11 @@ class OfflineLM { | ||
| 43 | // @param hyps It is changed in-place. | 55 | // @param hyps It is changed in-place. |
| 44 | void ComputeLMScore(float scale, int32_t context_size, | 56 | void ComputeLMScore(float scale, int32_t context_size, |
| 45 | std::vector<Hypotheses> *hyps); | 57 | std::vector<Hypotheses> *hyps); |
| 58 | + | ||
| 59 | + private: | ||
| 60 | + std::unique_ptr<LodrFst> lodr_fst_; | ||
| 61 | + float lodr_scale_; | ||
| 62 | + OfflineLMConfig config_; | ||
| 46 | }; | 63 | }; |
| 47 | 64 | ||
| 48 | } // namespace sherpa_onnx | 65 | } // namespace sherpa_onnx |
| @@ -83,11 +83,11 @@ class OfflineRnnLM::Impl { | @@ -83,11 +83,11 @@ class OfflineRnnLM::Impl { | ||
| 83 | }; | 83 | }; |
| 84 | 84 | ||
| 85 | OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) | 85 | OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) |
| 86 | - : impl_(std::make_unique<Impl>(config)) {} | 86 | + : impl_(std::make_unique<Impl>(config)), OfflineLM(config) {} |
| 87 | 87 | ||
| 88 | template <typename Manager> | 88 | template <typename Manager> |
| 89 | OfflineRnnLM::OfflineRnnLM(Manager *mgr, const OfflineLMConfig &config) | 89 | OfflineRnnLM::OfflineRnnLM(Manager *mgr, const OfflineLMConfig &config) |
| 90 | - : impl_(std::make_unique<Impl>(mgr, config)) {} | 90 | + : impl_(std::make_unique<Impl>(mgr, config)), OfflineLM(config) {} |
| 91 | 91 | ||
| 92 | OfflineRnnLM::~OfflineRnnLM() = default; | 92 | OfflineRnnLM::~OfflineRnnLM() = default; |
| 93 | 93 |
| @@ -20,6 +20,10 @@ void OnlineLMConfig::Register(ParseOptions *po) { | @@ -20,6 +20,10 @@ void OnlineLMConfig::Register(ParseOptions *po) { | ||
| 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); | 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); |
| 21 | po->Register("lm-shallow-fusion", &shallow_fusion, | 21 | po->Register("lm-shallow-fusion", &shallow_fusion, |
| 22 | "Boolean whether to use shallow fusion or rescore."); | 22 | "Boolean whether to use shallow fusion or rescore."); |
| 23 | + po->Register("lodr-fst", &lodr_fst, "Path to LODR FST model."); | ||
| 24 | + po->Register("lodr-scale", &lodr_scale, "LODR scale."); | ||
| 25 | + po->Register("lodr-backoff-id", &lodr_backoff_id, | ||
| 26 | + "ID of the backoff in the LODR FST. -1 means autodetect"); | ||
| 23 | } | 27 | } |
| 24 | 28 | ||
| 25 | bool OnlineLMConfig::Validate() const { | 29 | bool OnlineLMConfig::Validate() const { |
| @@ -28,6 +32,11 @@ bool OnlineLMConfig::Validate() const { | @@ -28,6 +32,11 @@ bool OnlineLMConfig::Validate() const { | ||
| 28 | return false; | 32 | return false; |
| 29 | } | 33 | } |
| 30 | 34 | ||
| 35 | + if (!lodr_fst.empty() && !FileExists(lodr_fst)) { | ||
| 36 | + SHERPA_ONNX_LOGE("'%s' does not exist", lodr_fst.c_str()); | ||
| 37 | + return false; | ||
| 38 | + } | ||
| 39 | + | ||
| 31 | return true; | 40 | return true; |
| 32 | } | 41 | } |
| 33 | 42 | ||
| @@ -37,6 +46,9 @@ std::string OnlineLMConfig::ToString() const { | @@ -37,6 +46,9 @@ std::string OnlineLMConfig::ToString() const { | ||
| 37 | os << "OnlineLMConfig("; | 46 | os << "OnlineLMConfig("; |
| 38 | os << "model=\"" << model << "\", "; | 47 | os << "model=\"" << model << "\", "; |
| 39 | os << "scale=" << scale << ", "; | 48 | os << "scale=" << scale << ", "; |
| 49 | + os << "lodr_scale=" << lodr_scale << ", "; | ||
| 50 | + os << "lodr_fst=\"" << lodr_fst << "\", "; | ||
| 51 | + os << "lodr_backoff_id=" << lodr_backoff_id << ", "; | ||
| 40 | os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")"; | 52 | os << "shallow_fusion=" << (shallow_fusion ? "True" : "False") << ")"; |
| 41 | 53 | ||
| 42 | return os.str(); | 54 | return os.str(); |
| @@ -18,18 +18,26 @@ struct OnlineLMConfig { | @@ -18,18 +18,26 @@ struct OnlineLMConfig { | ||
| 18 | float scale = 0.5; | 18 | float scale = 0.5; |
| 19 | int32_t lm_num_threads = 1; | 19 | int32_t lm_num_threads = 1; |
| 20 | std::string lm_provider = "cpu"; | 20 | std::string lm_provider = "cpu"; |
| 21 | + std::string lodr_fst; | ||
| 22 | + float lodr_scale = 0.01; | ||
| 23 | + int32_t lodr_backoff_id = -1; // -1 means not set | ||
| 21 | // enable shallow fusion | 24 | // enable shallow fusion |
| 22 | bool shallow_fusion = true; | 25 | bool shallow_fusion = true; |
| 23 | 26 | ||
| 24 | OnlineLMConfig() = default; | 27 | OnlineLMConfig() = default; |
| 25 | 28 | ||
| 26 | OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, | 29 | OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, |
| 27 | - const std::string &lm_provider, bool shallow_fusion) | 30 | + const std::string &lm_provider, bool shallow_fusion, |
| 31 | + const std::string &lodr_fst, float lodr_scale, | ||
| 32 | + int32_t lodr_backoff_id) | ||
| 28 | : model(model), | 33 | : model(model), |
| 29 | scale(scale), | 34 | scale(scale), |
| 30 | lm_num_threads(lm_num_threads), | 35 | lm_num_threads(lm_num_threads), |
| 31 | lm_provider(lm_provider), | 36 | lm_provider(lm_provider), |
| 32 | - shallow_fusion(shallow_fusion) {} | 37 | + shallow_fusion(shallow_fusion), |
| 38 | + lodr_fst(lodr_fst), | ||
| 39 | + lodr_scale(lodr_scale), | ||
| 40 | + lodr_backoff_id(lodr_backoff_id) {} | ||
| 33 | 41 | ||
| 34 | void Register(ParseOptions *po); | 42 | void Register(ParseOptions *po); |
| 35 | bool Validate() const; | 43 | bool Validate() const; |
| @@ -12,6 +12,7 @@ | @@ -12,6 +12,7 @@ | ||
| 12 | 12 | ||
| 13 | #include "onnxruntime_cxx_api.h" // NOLINT | 13 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 14 | #include "sherpa-onnx/csrc/file-utils.h" | 14 | #include "sherpa-onnx/csrc/file-utils.h" |
| 15 | +#include "sherpa-onnx/csrc/lodr-fst.h" | ||
| 15 | #include "sherpa-onnx/csrc/macros.h" | 16 | #include "sherpa-onnx/csrc/macros.h" |
| 16 | #include "sherpa-onnx/csrc/onnx-utils.h" | 17 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 17 | #include "sherpa-onnx/csrc/session.h" | 18 | #include "sherpa-onnx/csrc/session.h" |
| @@ -35,12 +36,27 @@ class OnlineRnnLM::Impl { | @@ -35,12 +36,27 @@ class OnlineRnnLM::Impl { | ||
| 35 | auto init_states = GetInitStatesSF(); | 36 | auto init_states = GetInitStatesSF(); |
| 36 | hyp->nn_lm_scores.value = std::move(init_states.first); | 37 | hyp->nn_lm_scores.value = std::move(init_states.first); |
| 37 | hyp->nn_lm_states = Convert(std::move(init_states.second)); | 38 | hyp->nn_lm_states = Convert(std::move(init_states.second)); |
| 39 | + // if LODR enabled, we need to initialize the LODR state | ||
| 40 | + if (lodr_fst_ != nullptr) { | ||
| 41 | + hyp->lodr_state = std::make_unique<LodrStateCost>(lodr_fst_.get()); | ||
| 42 | + } | ||
| 38 | } | 43 | } |
| 39 | 44 | ||
| 40 | // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob | 45 | // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob |
| 41 | const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>(); | 46 | const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>(); |
| 42 | hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale; | 47 | hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale; |
| 43 | 48 | ||
| 49 | + // if LODR enabled, we need to update the LODR state | ||
| 50 | + if (lodr_fst_ != nullptr) { | ||
| 51 | + auto next_lodr_state = std::make_unique<LodrStateCost>( | ||
| 52 | + hyp->lodr_state->ForwardOneStep(hyp->ys.back())); | ||
| 53 | + // calculate the score of the latest token | ||
| 54 | + auto score = next_lodr_state->Score() - hyp->lodr_state->Score(); | ||
| 55 | + hyp->lodr_state = std::move(next_lodr_state); | ||
| 56 | + // apply LODR to hyp score | ||
| 57 | + hyp->lm_log_prob += score * config_.lodr_scale; | ||
| 58 | + } | ||
| 59 | + | ||
| 44 | // get lm scores for next tokens given the hyp->ys[:] and save to | 60 | // get lm scores for next tokens given the hyp->ys[:] and save to |
| 45 | // nn_lm_scores | 61 | // nn_lm_scores |
| 46 | std::array<int64_t, 2> x_shape{1, 1}; | 62 | std::array<int64_t, 2> x_shape{1, 1}; |
| @@ -89,6 +105,12 @@ class OnlineRnnLM::Impl { | @@ -89,6 +105,12 @@ class OnlineRnnLM::Impl { | ||
| 89 | const float *p_nll = out.first.GetTensorData<float>(); | 105 | const float *p_nll = out.first.GetTensorData<float>(); |
| 90 | h.lm_log_prob = -scale * (*p_nll); | 106 | h.lm_log_prob = -scale * (*p_nll); |
| 91 | 107 | ||
| 108 | + // apply LODR to hyp score | ||
| 109 | + if (lodr_fst_ != nullptr) { | ||
| 110 | + // We scale LODR scale with LM scale to replicate Icefall code | ||
| 111 | + lodr_fst_->ComputeScore(config_.lodr_scale*scale, &h, context_size); | ||
| 112 | + } | ||
| 113 | + | ||
| 92 | // update NN LM states in hyp | 114 | // update NN LM states in hyp |
| 93 | h.nn_lm_states = Convert(std::move(out.second)); | 115 | h.nn_lm_states = Convert(std::move(out.second)); |
| 94 | 116 | ||
| @@ -154,6 +176,11 @@ class OnlineRnnLM::Impl { | @@ -154,6 +176,11 @@ class OnlineRnnLM::Impl { | ||
| 154 | SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id"); | 176 | SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id"); |
| 155 | 177 | ||
| 156 | ComputeInitStates(); | 178 | ComputeInitStates(); |
| 179 | + | ||
| 180 | + if (!config_.lodr_fst.empty()) { | ||
| 181 | + lodr_fst_ = std::make_unique<LodrFst>(LodrFst(config_.lodr_fst, | ||
| 182 | + config_.lodr_backoff_id)); | ||
| 183 | + } | ||
| 157 | } | 184 | } |
| 158 | 185 | ||
| 159 | void ComputeInitStates() { | 186 | void ComputeInitStates() { |
| @@ -203,6 +230,8 @@ class OnlineRnnLM::Impl { | @@ -203,6 +230,8 @@ class OnlineRnnLM::Impl { | ||
| 203 | int32_t rnn_num_layers_ = 2; | 230 | int32_t rnn_num_layers_ = 2; |
| 204 | int32_t rnn_hidden_size_ = 512; | 231 | int32_t rnn_hidden_size_ = 512; |
| 205 | int32_t sos_id_ = 1; | 232 | int32_t sos_id_ = 1; |
| 233 | + | ||
| 234 | + std::unique_ptr<LodrFst> lodr_fst_; | ||
| 206 | }; | 235 | }; |
| 207 | 236 | ||
| 208 | OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) | 237 | OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) |
| @@ -13,13 +13,19 @@ namespace sherpa_onnx { | @@ -13,13 +13,19 @@ namespace sherpa_onnx { | ||
| 13 | void PybindOfflineLMConfig(py::module *m) { | 13 | void PybindOfflineLMConfig(py::module *m) { |
| 14 | using PyClass = OfflineLMConfig; | 14 | using PyClass = OfflineLMConfig; |
| 15 | py::class_<PyClass>(*m, "OfflineLMConfig") | 15 | py::class_<PyClass>(*m, "OfflineLMConfig") |
| 16 | - .def(py::init<const std::string &, float, int32_t, const std::string &>(), | 16 | + .def(py::init<const std::string &, float, int32_t, const std::string &, |
| 17 | + const std::string &, float, int32_t>(), | ||
| 17 | py::arg("model"), py::arg("scale") = 0.5f, | 18 | py::arg("model"), py::arg("scale") = 0.5f, |
| 18 | - py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu") | 19 | + py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu", |
| 20 | + py::arg("lodr_fst") = "", py::arg("lodr_scale") = 0.0f, | ||
| 21 | + py::arg("lodr_backoff_id") = -1) | ||
| 19 | .def_readwrite("model", &PyClass::model) | 22 | .def_readwrite("model", &PyClass::model) |
| 20 | .def_readwrite("scale", &PyClass::scale) | 23 | .def_readwrite("scale", &PyClass::scale) |
| 21 | .def_readwrite("lm_provider", &PyClass::lm_provider) | 24 | .def_readwrite("lm_provider", &PyClass::lm_provider) |
| 22 | .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) | 25 | .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) |
| 26 | + .def_readwrite("lodr_fst", &PyClass::lodr_fst) | ||
| 27 | + .def_readwrite("lodr_scale", &PyClass::lodr_scale) | ||
| 28 | + .def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id) | ||
| 23 | .def("__str__", &PyClass::ToString); | 29 | .def("__str__", &PyClass::ToString); |
| 24 | } | 30 | } |
| 25 | 31 |
| @@ -14,15 +14,21 @@ void PybindOnlineLMConfig(py::module *m) { | @@ -14,15 +14,21 @@ void PybindOnlineLMConfig(py::module *m) { | ||
| 14 | using PyClass = OnlineLMConfig; | 14 | using PyClass = OnlineLMConfig; |
| 15 | py::class_<PyClass>(*m, "OnlineLMConfig") | 15 | py::class_<PyClass>(*m, "OnlineLMConfig") |
| 16 | .def(py::init<const std::string &, float, int32_t, | 16 | .def(py::init<const std::string &, float, int32_t, |
| 17 | - const std::string &, bool>(), | 17 | + const std::string &, bool, const std::string &, |
| 18 | + float, int>(), | ||
| 18 | py::arg("model") = "", py::arg("scale") = 0.5f, | 19 | py::arg("model") = "", py::arg("scale") = 0.5f, |
| 19 | py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu", | 20 | py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu", |
| 20 | - py::arg("shallow_fusion") = true) | 21 | + py::arg("shallow_fusion") = true, py::arg("lodr_fst") = "", |
| 22 | + py::arg("lodr_scale") = 0.0f, py::arg("lodr_backoff_id") = -1) | ||
| 21 | .def_readwrite("model", &PyClass::model) | 23 | .def_readwrite("model", &PyClass::model) |
| 22 | .def_readwrite("scale", &PyClass::scale) | 24 | .def_readwrite("scale", &PyClass::scale) |
| 23 | .def_readwrite("lm_provider", &PyClass::lm_provider) | 25 | .def_readwrite("lm_provider", &PyClass::lm_provider) |
| 24 | .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) | 26 | .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) |
| 25 | .def_readwrite("shallow_fusion", &PyClass::shallow_fusion) | 27 | .def_readwrite("shallow_fusion", &PyClass::shallow_fusion) |
| 28 | + .def_readwrite("lodr_fst", &PyClass::lodr_fst) | ||
| 29 | + .def_readwrite("lodr_scale", &PyClass::lodr_scale) | ||
| 30 | + .def_readwrite("lodr_backoff_id", &PyClass::lodr_backoff_id) | ||
| 31 | + | ||
| 26 | .def("__str__", &PyClass::ToString); | 32 | .def("__str__", &PyClass::ToString); |
| 27 | } | 33 | } |
| 28 | 34 |
| @@ -69,6 +69,8 @@ class OfflineRecognizer(object): | @@ -69,6 +69,8 @@ class OfflineRecognizer(object): | ||
| 69 | hr_dict_dir: str = "", | 69 | hr_dict_dir: str = "", |
| 70 | hr_rule_fsts: str = "", | 70 | hr_rule_fsts: str = "", |
| 71 | hr_lexicon: str = "", | 71 | hr_lexicon: str = "", |
| 72 | + lodr_fst: str = "", | ||
| 73 | + lodr_scale: float = 0.0, | ||
| 72 | ): | 74 | ): |
| 73 | """ | 75 | """ |
| 74 | Please refer to | 76 | Please refer to |
| @@ -133,6 +135,10 @@ class OfflineRecognizer(object): | @@ -133,6 +135,10 @@ class OfflineRecognizer(object): | ||
| 133 | rule_fars: | 135 | rule_fars: |
| 134 | If not empty, it specifies fst archives for inverse text normalization. | 136 | If not empty, it specifies fst archives for inverse text normalization. |
| 135 | If there are multiple archives, they are separated by a comma. | 137 | If there are multiple archives, they are separated by a comma. |
| 138 | + lodr_fst: | ||
| 139 | + Path to the LODR FST file in binary format. If empty, LODR is disabled. | ||
| 140 | + lodr_scale: | ||
| 141 | + Scale factor for LODR rescoring. Only used when lodr_fst is provided. | ||
| 136 | """ | 142 | """ |
| 137 | self = cls.__new__(cls) | 143 | self = cls.__new__(cls) |
| 138 | model_config = OfflineModelConfig( | 144 | model_config = OfflineModelConfig( |
| @@ -173,6 +179,8 @@ class OfflineRecognizer(object): | @@ -173,6 +179,8 @@ class OfflineRecognizer(object): | ||
| 173 | scale=lm_scale, | 179 | scale=lm_scale, |
| 174 | lm_num_threads=num_threads, | 180 | lm_num_threads=num_threads, |
| 175 | lm_provider=provider, | 181 | lm_provider=provider, |
| 182 | + lodr_fst=lodr_fst, | ||
| 183 | + lodr_scale=lodr_scale, | ||
| 176 | ) | 184 | ) |
| 177 | 185 | ||
| 178 | recognizer_config = OfflineRecognizerConfig( | 186 | recognizer_config = OfflineRecognizerConfig( |
| @@ -89,6 +89,8 @@ class OnlineRecognizer(object): | @@ -89,6 +89,8 @@ class OnlineRecognizer(object): | ||
| 89 | hr_dict_dir: str = "", | 89 | hr_dict_dir: str = "", |
| 90 | hr_rule_fsts: str = "", | 90 | hr_rule_fsts: str = "", |
| 91 | hr_lexicon: str = "", | 91 | hr_lexicon: str = "", |
| 92 | + lodr_fst: str = "", | ||
| 93 | + lodr_scale: float = 0.0, | ||
| 92 | ): | 94 | ): |
| 93 | """ | 95 | """ |
| 94 | Please refer to | 96 | Please refer to |
| @@ -216,6 +218,10 @@ class OnlineRecognizer(object): | @@ -216,6 +218,10 @@ class OnlineRecognizer(object): | ||
| 216 | "Set path for storing timing cache." TensorRT EP | 218 | "Set path for storing timing cache." TensorRT EP |
| 217 | trt_dump_subgraphs: bool = False, | 219 | trt_dump_subgraphs: bool = False, |
| 218 | "Dump optimized subgraphs for debugging." TensorRT EP | 220 | "Dump optimized subgraphs for debugging." TensorRT EP |
| 221 | + lodr_fst: | ||
| 222 | + Path to the LODR FST file in binary format. If empty, LODR is disabled. | ||
| 223 | + lodr_scale: | ||
| 224 | + Scale factor for LODR rescoring. Only used when lodr_fst is provided. | ||
| 219 | """ | 225 | """ |
| 220 | self = cls.__new__(cls) | 226 | self = cls.__new__(cls) |
| 221 | _assert_file_exists(tokens) | 227 | _assert_file_exists(tokens) |
| @@ -298,6 +304,8 @@ class OnlineRecognizer(object): | @@ -298,6 +304,8 @@ class OnlineRecognizer(object): | ||
| 298 | model=lm, | 304 | model=lm, |
| 299 | scale=lm_scale, | 305 | scale=lm_scale, |
| 300 | shallow_fusion=lm_shallow_fusion, | 306 | shallow_fusion=lm_shallow_fusion, |
| 307 | + lodr_fst=lodr_fst, | ||
| 308 | + lodr_scale=lodr_scale, | ||
| 301 | ) | 309 | ) |
| 302 | 310 | ||
| 303 | recognizer_config = OnlineRecognizerConfig( | 311 | recognizer_config = OnlineRecognizerConfig( |
-
请 注册 或 登录 后发表评论