Fangjun Kuang
Committed by GitHub

Add inverse text normalization for non-streaming ASR (#1017)

@@ -248,7 +248,7 @@ if [[ x$OS != x'windows-latest' ]]; then @@ -248,7 +248,7 @@ if [[ x$OS != x'windows-latest' ]]; then
248 python3 ./python-api-examples/online-decode-files.py \ 248 python3 ./python-api-examples/online-decode-files.py \
249 --tokens=$repo/tokens.txt \ 249 --tokens=$repo/tokens.txt \
250 --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ 250 --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
251 - --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ 251 + --decoder=$repo/decoder-epoch-99-avg-1.onnx \
252 --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ 252 --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
253 $repo/test_wavs/0.wav \ 253 $repo/test_wavs/0.wav \
254 $repo/test_wavs/1.wav \ 254 $repo/test_wavs/1.wav \
@@ -286,7 +286,7 @@ python3 ./python-api-examples/offline-decode-files.py \ @@ -286,7 +286,7 @@ python3 ./python-api-examples/offline-decode-files.py \
286 python3 ./python-api-examples/offline-decode-files.py \ 286 python3 ./python-api-examples/offline-decode-files.py \
287 --tokens=$repo/tokens.txt \ 287 --tokens=$repo/tokens.txt \
288 --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \ 288 --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
289 - --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \ 289 + --decoder=$repo/decoder-epoch-99-avg-1.onnx \
290 --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \ 290 --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
291 $repo/test_wavs/0.wav \ 291 $repo/test_wavs/0.wav \
292 $repo/test_wavs/1.wav \ 292 $repo/test_wavs/1.wav \
@@ -330,6 +330,15 @@ if [[ x$OS != x'windows-latest' ]]; then @@ -330,6 +330,15 @@ if [[ x$OS != x'windows-latest' ]]; then
330 330
331 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose 331 python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
332 332
  333 + ln -s $repo $PWD/
  334 +
  335 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
  336 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
  337 +
  338 + python3 ./python-api-examples/inverse-text-normalization-offline-asr.py
  339 +
  340 + rm -rfv sherpa-onnx-paraformer-zh-2023-03-28
  341 +
333 rm -rf $repo 342 rm -rf $repo
334 fi 343 fi
335 344
  1 +#!/usr/bin/env python3
  2 +#
  3 +# Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +"""
  6 +This script shows how to use inverse text normalization with non-streaming ASR.
  7 +
  8 +Usage:
  9 +
  10 +(1) Download the test model
  11 +
  12 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
  13 +tar xvf sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
  14 +rm sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
  15 +
  16 +(2) Download rule fst
  17 +
  18 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
  19 +
  20 +Please refer to
  21 +https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb
  22 +for how itn_zh_number.fst is generated.
  23 +
  24 +(3) Download test wave
  25 +
  26 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
  27 +
  28 +(4) Run this script
  29 +
  30 +python3 ./python-api-examples/inverse-text-normalization-offline-asr.py
  31 +"""
  32 +from pathlib import Path
  33 +
  34 +import sherpa_onnx
  35 +import soundfile as sf
  36 +
  37 +
def create_recognizer():
    """Create a non-streaming paraformer recognizer that uses a rule fst.

    The model, the tokens file and the rule fst are looked up at fixed paths
    relative to the current working directory; see the usage notes at the top
    of this script for how to download them.

    Returns:
      The recognizer built by ``sherpa_onnx.OfflineRecognizer.from_paraformer``
      with ``rule_fsts`` pointing to the inverse text normalization rule fst
      and ``debug`` enabled.

    Raises:
      ValueError: If any required file is missing. The message lists the
        missing paths.
    """
    model = "./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx"
    tokens = "./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt"
    rule_fsts = "./itn_zh_number.fst"

    # The rule fst is a separate download from the model, so tell the user
    # exactly which files are absent instead of a generic hint.
    missing = [f for f in (model, tokens, rule_fsts) if not Path(f).is_file()]
    if missing:
        missing_str = ", ".join(missing)
        raise ValueError(
            f"Missing file(s): {missing_str}. "
            "Please download model files from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models"
        )

    return sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=model,
        tokens=tokens,
        debug=True,
        rule_fsts=rule_fsts,
    )
  59 +
  60 +
def main():
    """Decode ./itn-zh-number.wav and print the file name and the result.

    Raises:
      ValueError: If the test wave does not exist (or if create_recognizer()
        cannot find the files it needs).
    """
    recognizer = create_recognizer()
    wave_filename = "./itn-zh-number.wav"
    if not Path(wave_filename).is_file():
        # The wave is downloaded separately from the model files, so name it
        # explicitly rather than asking the user to download "model files".
        raise ValueError(
            f"Test wave {wave_filename} does not exist. Please download it from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav"
        )
    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)
    recognizer.decode_stream(stream)
    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
    main()
@@ -73,7 +73,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, @@ -73,7 +73,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
73 class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { 73 class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
74 public: 74 public:
75 explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config) 75 explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config)
76 - : config_(config), 76 + : OfflineRecognizerImpl(config),
  77 + config_(config),
77 symbol_table_(config_.model_config.tokens), 78 symbol_table_(config_.model_config.tokens),
78 model_(OfflineCtcModel::Create(config_.model_config)) { 79 model_(OfflineCtcModel::Create(config_.model_config)) {
79 Init(); 80 Init();
@@ -82,7 +83,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -82,7 +83,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
82 #if __ANDROID_API__ >= 9 83 #if __ANDROID_API__ >= 9
83 OfflineRecognizerCtcImpl(AAssetManager *mgr, 84 OfflineRecognizerCtcImpl(AAssetManager *mgr,
84 const OfflineRecognizerConfig &config) 85 const OfflineRecognizerConfig &config)
85 - : config_(config), 86 + : OfflineRecognizerImpl(mgr, config),
  87 + config_(config),
86 symbol_table_(mgr, config_.model_config.tokens), 88 symbol_table_(mgr, config_.model_config.tokens),
87 model_(OfflineCtcModel::Create(mgr, config_.model_config)) { 89 model_(OfflineCtcModel::Create(mgr, config_.model_config)) {
88 Init(); 90 Init();
@@ -205,6 +207,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -205,6 +207,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
205 for (int32_t i = 0; i != n; ++i) { 207 for (int32_t i = 0; i != n; ++i) {
206 auto r = Convert(results[i], symbol_table_, frame_shift_ms, 208 auto r = Convert(results[i], symbol_table_, frame_shift_ms,
207 model_->SubsamplingFactor()); 209 model_->SubsamplingFactor());
  210 + r.text = ApplyInverseTextNormalization(std::move(r.text));
208 ss[i]->SetResult(r); 211 ss[i]->SetResult(r);
209 } 212 }
210 } 213 }
@@ -238,6 +241,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -238,6 +241,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
238 241
239 auto r = Convert(results[0], symbol_table_, frame_shift_ms, 242 auto r = Convert(results[0], symbol_table_, frame_shift_ms,
240 model_->SubsamplingFactor()); 243 model_->SubsamplingFactor());
  244 + r.text = ApplyInverseTextNormalization(std::move(r.text));
241 s->SetResult(r); 245 s->SetResult(r);
242 } 246 }
243 247
@@ -5,7 +5,18 @@ @@ -5,7 +5,18 @@
5 #include "sherpa-onnx/csrc/offline-recognizer-impl.h" 5 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
6 6
7 #include <string> 7 #include <string>
  8 +#include <utility>
  9 +#include <vector>
8 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include <strstream>
  13 +
  14 +#include "android/asset_manager.h"
  15 +#include "android/asset_manager_jni.h"
  16 +#endif
  17 +
  18 +#include "fst/extensions/far/far.h"
  19 +#include "kaldifst/csrc/kaldi-fst-io.h"
9 #include "onnxruntime_cxx_api.h" // NOLINT 20 #include "onnxruntime_cxx_api.h" // NOLINT
10 #include "sherpa-onnx/csrc/macros.h" 21 #include "sherpa-onnx/csrc/macros.h"
11 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" 22 #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
@@ -316,4 +327,111 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( @@ -316,4 +327,111 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
316 } 327 }
317 #endif 328 #endif
318 329
  330 +OfflineRecognizerImpl::OfflineRecognizerImpl(
  331 + const OfflineRecognizerConfig &config)
  332 + : config_(config) {
  333 + if (!config.rule_fsts.empty()) {
  334 + std::vector<std::string> files;
  335 + SplitStringToVector(config.rule_fsts, ",", false, &files);
  336 + itn_list_.reserve(files.size());
  337 + for (const auto &f : files) {
  338 + if (config.model_config.debug) {
  339 + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
  340 + }
  341 + itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
  342 + }
  343 + }
  344 +
  345 + if (!config.rule_fars.empty()) {
  346 + if (config.model_config.debug) {
  347 + SHERPA_ONNX_LOGE("Loading FST archives");
  348 + }
  349 + std::vector<std::string> files;
  350 + SplitStringToVector(config.rule_fars, ",", false, &files);
  351 +
  352 + itn_list_.reserve(files.size() + itn_list_.size());
  353 +
  354 + for (const auto &f : files) {
  355 + if (config.model_config.debug) {
  356 + SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
  357 + }
  358 + std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
  359 + fst::FarReader<fst::StdArc>::Open(f));
  360 + for (; !reader->Done(); reader->Next()) {
  361 + std::unique_ptr<fst::StdConstFst> r(
  362 + fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
  363 +
  364 + itn_list_.push_back(
  365 + std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
  366 + }
  367 + }
  368 +
  369 + if (config.model_config.debug) {
  370 + SHERPA_ONNX_LOGE("FST archives loaded!");
  371 + }
  372 + }
  373 +}
  374 +
  375 +#if __ANDROID_API__ >= 9
  376 +OfflineRecognizerImpl::OfflineRecognizerImpl(
  377 + AAssetManager *mgr, const OfflineRecognizerConfig &config)
  378 + : config_(config) {
  379 + if (!config.rule_fsts.empty()) {
  380 + std::vector<std::string> files;
  381 + SplitStringToVector(config.rule_fsts, ",", false, &files);
  382 + itn_list_.reserve(files.size());
  383 + for (const auto &f : files) {
  384 + if (config.model_config.debug) {
  385 + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
  386 + }
  387 + auto buf = ReadFile(mgr, f);
  388 + std::istrstream is(buf.data(), buf.size());
  389 + itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
  390 + }
  391 + }
  392 +
  393 + if (!config.rule_fars.empty()) {
  394 + std::vector<std::string> files;
  395 + SplitStringToVector(config.rule_fars, ",", false, &files);
  396 + itn_list_.reserve(files.size() + itn_list_.size());
  397 +
  398 + for (const auto &f : files) {
  399 + if (config.model_config.debug) {
  400 + SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
  401 + }
  402 +
  403 + auto buf = ReadFile(mgr, f);
  404 +
  405 + std::unique_ptr<std::istream> s(
  406 + new std::istrstream(buf.data(), buf.size()));
  407 +
  408 + std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
  409 + fst::FarReader<fst::StdArc>::Open(std::move(s)));
  410 +
  411 + for (; !reader->Done(); reader->Next()) {
  412 + std::unique_ptr<fst::StdConstFst> r(
  413 + fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
  414 +
  415 + itn_list_.push_back(
  416 + std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
  417 + } // for (; !reader->Done(); reader->Next())
  418 + } // for (const auto &f : files)
  419 + } // if (!config.rule_fars.empty())
  420 +}
  421 +#endif
  422 +
  423 +std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
  424 + std::string text) const {
  425 + if (!itn_list_.empty()) {
  426 + for (const auto &tn : itn_list_) {
  427 + text = tn->Normalize(text);
  428 + if (config_.model_config.debug) {
  429 + SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str());
  430 + }
  431 + }
  432 + }
  433 +
  434 + return text;
  435 +}
  436 +
319 } // namespace sherpa_onnx 437 } // namespace sherpa_onnx
@@ -14,6 +14,7 @@ @@ -14,6 +14,7 @@
14 #include "android/asset_manager_jni.h" 14 #include "android/asset_manager_jni.h"
15 #endif 15 #endif
16 16
  17 +#include "kaldifst/csrc/text-normalizer.h"
17 #include "sherpa-onnx/csrc/macros.h" 18 #include "sherpa-onnx/csrc/macros.h"
18 #include "sherpa-onnx/csrc/offline-recognizer.h" 19 #include "sherpa-onnx/csrc/offline-recognizer.h"
19 #include "sherpa-onnx/csrc/offline-stream.h" 20 #include "sherpa-onnx/csrc/offline-stream.h"
@@ -22,10 +23,15 @@ namespace sherpa_onnx { @@ -22,10 +23,15 @@ namespace sherpa_onnx {
22 23
23 class OfflineRecognizerImpl { 24 class OfflineRecognizerImpl {
24 public: 25 public:
  26 + explicit OfflineRecognizerImpl(const OfflineRecognizerConfig &config);
  27 +
25 static std::unique_ptr<OfflineRecognizerImpl> Create( 28 static std::unique_ptr<OfflineRecognizerImpl> Create(
26 const OfflineRecognizerConfig &config); 29 const OfflineRecognizerConfig &config);
27 30
28 #if __ANDROID_API__ >= 9 31 #if __ANDROID_API__ >= 9
  32 + OfflineRecognizerImpl(AAssetManager *mgr,
  33 + const OfflineRecognizerConfig &config);
  34 +
29 static std::unique_ptr<OfflineRecognizerImpl> Create( 35 static std::unique_ptr<OfflineRecognizerImpl> Create(
30 AAssetManager *mgr, const OfflineRecognizerConfig &config); 36 AAssetManager *mgr, const OfflineRecognizerConfig &config);
31 #endif 37 #endif
@@ -41,6 +47,15 @@ class OfflineRecognizerImpl { @@ -41,6 +47,15 @@ class OfflineRecognizerImpl {
41 virtual std::unique_ptr<OfflineStream> CreateStream() const = 0; 47 virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
42 48
43 virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; 49 virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
  50 +
  51 + std::string ApplyInverseTextNormalization(std::string text) const;
  52 +
  53 + private:
  54 + OfflineRecognizerConfig config_;
  55 + // for inverse text normalization. Used only if
  56 + // config.rule_fsts is not empty or
  57 + // config.rule_fars is not empty
  58 + std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
44 }; 59 };
45 60
46 } // namespace sherpa_onnx 61 } // namespace sherpa_onnx
@@ -89,7 +89,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { @@ -89,7 +89,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
89 public: 89 public:
90 explicit OfflineRecognizerParaformerImpl( 90 explicit OfflineRecognizerParaformerImpl(
91 const OfflineRecognizerConfig &config) 91 const OfflineRecognizerConfig &config)
92 - : config_(config), 92 + : OfflineRecognizerImpl(config),
  93 + config_(config),
93 symbol_table_(config_.model_config.tokens), 94 symbol_table_(config_.model_config.tokens),
94 model_(std::make_unique<OfflineParaformerModel>(config.model_config)) { 95 model_(std::make_unique<OfflineParaformerModel>(config.model_config)) {
95 if (config.decoding_method == "greedy_search") { 96 if (config.decoding_method == "greedy_search") {
@@ -109,7 +110,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { @@ -109,7 +110,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
109 #if __ANDROID_API__ >= 9 110 #if __ANDROID_API__ >= 9
110 OfflineRecognizerParaformerImpl(AAssetManager *mgr, 111 OfflineRecognizerParaformerImpl(AAssetManager *mgr,
111 const OfflineRecognizerConfig &config) 112 const OfflineRecognizerConfig &config)
112 - : config_(config), 113 + : OfflineRecognizerImpl(mgr, config),
  114 + config_(config),
113 symbol_table_(mgr, config_.model_config.tokens), 115 symbol_table_(mgr, config_.model_config.tokens),
114 model_(std::make_unique<OfflineParaformerModel>(mgr, 116 model_(std::make_unique<OfflineParaformerModel>(mgr,
115 config.model_config)) { 117 config.model_config)) {
@@ -204,6 +206,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { @@ -204,6 +206,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
204 206
205 for (int32_t i = 0; i != n; ++i) { 207 for (int32_t i = 0; i != n; ++i) {
206 auto r = Convert(results[i], symbol_table_); 208 auto r = Convert(results[i], symbol_table_);
  209 + r.text = ApplyInverseTextNormalization(std::move(r.text));
207 ss[i]->SetResult(r); 210 ss[i]->SetResult(r);
208 } 211 }
209 } 212 }
@@ -74,7 +74,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -74,7 +74,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
74 public: 74 public:
75 explicit OfflineRecognizerTransducerImpl( 75 explicit OfflineRecognizerTransducerImpl(
76 const OfflineRecognizerConfig &config) 76 const OfflineRecognizerConfig &config)
77 - : config_(config), 77 + : OfflineRecognizerImpl(config),
  78 + config_(config),
78 symbol_table_(config_.model_config.tokens), 79 symbol_table_(config_.model_config.tokens),
79 model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { 80 model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
80 if (config_.decoding_method == "greedy_search") { 81 if (config_.decoding_method == "greedy_search") {
@@ -107,7 +108,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -107,7 +108,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
107 #if __ANDROID_API__ >= 9 108 #if __ANDROID_API__ >= 9
108 explicit OfflineRecognizerTransducerImpl( 109 explicit OfflineRecognizerTransducerImpl(
109 AAssetManager *mgr, const OfflineRecognizerConfig &config) 110 AAssetManager *mgr, const OfflineRecognizerConfig &config)
110 - : config_(config), 111 + : OfflineRecognizerImpl(mgr, config),
  112 + config_(config),
111 symbol_table_(mgr, config_.model_config.tokens), 113 symbol_table_(mgr, config_.model_config.tokens),
112 model_(std::make_unique<OfflineTransducerModel>(mgr, 114 model_(std::make_unique<OfflineTransducerModel>(mgr,
113 config_.model_config)) { 115 config_.model_config)) {
@@ -230,6 +232,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -230,6 +232,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
230 for (int32_t i = 0; i != n; ++i) { 232 for (int32_t i = 0; i != n; ++i) {
231 auto r = Convert(results[i], symbol_table_, frame_shift_ms, 233 auto r = Convert(results[i], symbol_table_, frame_shift_ms,
232 model_->SubsamplingFactor()); 234 model_->SubsamplingFactor());
  235 + r.text = ApplyInverseTextNormalization(std::move(r.text));
233 236
234 ss[i]->SetResult(r); 237 ss[i]->SetResult(r);
235 } 238 }
@@ -41,7 +41,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { @@ -41,7 +41,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
41 public: 41 public:
42 explicit OfflineRecognizerTransducerNeMoImpl( 42 explicit OfflineRecognizerTransducerNeMoImpl(
43 const OfflineRecognizerConfig &config) 43 const OfflineRecognizerConfig &config)
44 - : config_(config), 44 + : OfflineRecognizerImpl(config),
  45 + config_(config),
45 symbol_table_(config_.model_config.tokens), 46 symbol_table_(config_.model_config.tokens),
46 model_(std::make_unique<OfflineTransducerNeMoModel>( 47 model_(std::make_unique<OfflineTransducerNeMoModel>(
47 config_.model_config)) { 48 config_.model_config)) {
@@ -59,7 +60,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { @@ -59,7 +60,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
59 #if __ANDROID_API__ >= 9 60 #if __ANDROID_API__ >= 9
60 explicit OfflineRecognizerTransducerNeMoImpl( 61 explicit OfflineRecognizerTransducerNeMoImpl(
61 AAssetManager *mgr, const OfflineRecognizerConfig &config) 62 AAssetManager *mgr, const OfflineRecognizerConfig &config)
62 - : config_(config), 63 + : OfflineRecognizerImpl(mgr, config),
  64 + config_(config),
63 symbol_table_(mgr, config_.model_config.tokens), 65 symbol_table_(mgr, config_.model_config.tokens),
64 model_(std::make_unique<OfflineTransducerNeMoModel>( 66 model_(std::make_unique<OfflineTransducerNeMoModel>(
65 mgr, config_.model_config)) { 67 mgr, config_.model_config)) {
@@ -131,6 +133,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { @@ -131,6 +133,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
131 for (int32_t i = 0; i != n; ++i) { 133 for (int32_t i = 0; i != n; ++i) {
132 auto r = Convert(results[i], symbol_table_, frame_shift_ms, 134 auto r = Convert(results[i], symbol_table_, frame_shift_ms,
133 model_->SubsamplingFactor()); 135 model_->SubsamplingFactor());
  136 + r.text = ApplyInverseTextNormalization(std::move(r.text));
134 137
135 ss[i]->SetResult(r); 138 ss[i]->SetResult(r);
136 } 139 }
@@ -52,7 +52,8 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, @@ -52,7 +52,8 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
52 class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { 52 class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
53 public: 53 public:
54 explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) 54 explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
55 - : config_(config), 55 + : OfflineRecognizerImpl(config),
  56 + config_(config),
56 symbol_table_(config_.model_config.tokens), 57 symbol_table_(config_.model_config.tokens),
57 model_(std::make_unique<OfflineWhisperModel>(config.model_config)) { 58 model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
58 Init(); 59 Init();
@@ -61,7 +62,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -61,7 +62,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
61 #if __ANDROID_API__ >= 9 62 #if __ANDROID_API__ >= 9
62 OfflineRecognizerWhisperImpl(AAssetManager *mgr, 63 OfflineRecognizerWhisperImpl(AAssetManager *mgr,
63 const OfflineRecognizerConfig &config) 64 const OfflineRecognizerConfig &config)
64 - : config_(config), 65 + : OfflineRecognizerImpl(mgr, config),
  66 + config_(config),
65 symbol_table_(mgr, config_.model_config.tokens), 67 symbol_table_(mgr, config_.model_config.tokens),
66 model_( 68 model_(
67 std::make_unique<OfflineWhisperModel>(mgr, config.model_config)) { 69 std::make_unique<OfflineWhisperModel>(mgr, config.model_config)) {
@@ -150,6 +152,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -150,6 +152,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
150 std::move(cross_kv.second)); 152 std::move(cross_kv.second));
151 153
152 auto r = Convert(results[0], symbol_table_); 154 auto r = Convert(results[0], symbol_table_);
  155 + r.text = ApplyInverseTextNormalization(std::move(r.text));
153 s->SetResult(r); 156 s->SetResult(r);
154 } catch (const Ort::Exception &ex) { 157 } catch (const Ort::Exception &ex) {
155 SHERPA_ONNX_LOGE( 158 SHERPA_ONNX_LOGE(
@@ -10,7 +10,7 @@ @@ -10,7 +10,7 @@
10 #include "sherpa-onnx/csrc/macros.h" 10 #include "sherpa-onnx/csrc/macros.h"
11 #include "sherpa-onnx/csrc/offline-lm-config.h" 11 #include "sherpa-onnx/csrc/offline-lm-config.h"
12 #include "sherpa-onnx/csrc/offline-recognizer-impl.h" 12 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
13 - 13 +#include "sherpa-onnx/csrc/text-utils.h"
14 namespace sherpa_onnx { 14 namespace sherpa_onnx {
15 15
16 void OfflineRecognizerConfig::Register(ParseOptions *po) { 16 void OfflineRecognizerConfig::Register(ParseOptions *po) {
@@ -44,6 +44,16 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { @@ -44,6 +44,16 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
44 po->Register("hotwords-score", &hotwords_score, 44 po->Register("hotwords-score", &hotwords_score,
45 "The bonus score for each token in context word/phrase. " 45 "The bonus score for each token in context word/phrase. "
46 "Used only when decoding_method is modified_beam_search"); 46 "Used only when decoding_method is modified_beam_search");
  47 +
  48 + po->Register(
  49 + "rule-fsts", &rule_fsts,
  50 + "If not empty, it specifies fsts for inverse text normalization. "
  51 + "If there are multiple fsts, they are separated by a comma.");
  52 +
  53 + po->Register(
  54 + "rule-fars", &rule_fars,
  55 + "If not empty, it specifies fst archives for inverse text normalization. "
  56 + "If there are multiple archives, they are separated by a comma.");
47 } 57 }
48 58
49 bool OfflineRecognizerConfig::Validate() const { 59 bool OfflineRecognizerConfig::Validate() const {
@@ -61,7 +71,7 @@ bool OfflineRecognizerConfig::Validate() const { @@ -61,7 +71,7 @@ bool OfflineRecognizerConfig::Validate() const {
61 if (!hotwords_file.empty() && decoding_method != "modified_beam_search") { 71 if (!hotwords_file.empty() && decoding_method != "modified_beam_search") {
62 SHERPA_ONNX_LOGE( 72 SHERPA_ONNX_LOGE(
63 "Please use --decoding-method=modified_beam_search if you" 73 "Please use --decoding-method=modified_beam_search if you"
64 - " provide --hotwords-file. Given --decoding-method=%s", 74 + " provide --hotwords-file. Given --decoding-method='%s'",
65 decoding_method.c_str()); 75 decoding_method.c_str());
66 return false; 76 return false;
67 } 77 }
@@ -72,6 +82,34 @@ bool OfflineRecognizerConfig::Validate() const { @@ -72,6 +82,34 @@ bool OfflineRecognizerConfig::Validate() const {
72 return false; 82 return false;
73 } 83 }
74 84
  85 + if (!hotwords_file.empty() && !FileExists(hotwords_file)) {
  86 + SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist",
  87 + hotwords_file.c_str());
  88 + return false;
  89 + }
  90 +
  91 + if (!rule_fsts.empty()) {
  92 + std::vector<std::string> files;
  93 + SplitStringToVector(rule_fsts, ",", false, &files);
  94 + for (const auto &f : files) {
  95 + if (!FileExists(f)) {
  96 + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
  97 + return false;
  98 + }
  99 + }
  100 + }
  101 +
  102 + if (!rule_fars.empty()) {
  103 + std::vector<std::string> files;
  104 + SplitStringToVector(rule_fars, ",", false, &files);
  105 + for (const auto &f : files) {
  106 + if (!FileExists(f)) {
  107 + SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str());
  108 + return false;
  109 + }
  110 + }
  111 + }
  112 +
75 return model_config.Validate(); 113 return model_config.Validate();
76 } 114 }
77 115
@@ -87,7 +125,9 @@ std::string OfflineRecognizerConfig::ToString() const { @@ -87,7 +125,9 @@ std::string OfflineRecognizerConfig::ToString() const {
87 os << "max_active_paths=" << max_active_paths << ", "; 125 os << "max_active_paths=" << max_active_paths << ", ";
88 os << "hotwords_file=\"" << hotwords_file << "\", "; 126 os << "hotwords_file=\"" << hotwords_file << "\", ";
89 os << "hotwords_score=" << hotwords_score << ", "; 127 os << "hotwords_score=" << hotwords_score << ", ";
90 - os << "blank_penalty=" << blank_penalty << ")"; 128 + os << "blank_penalty=" << blank_penalty << ", ";
  129 + os << "rule_fsts=\"" << rule_fsts << "\", ";
  130 + os << "rule_fars=\"" << rule_fars << "\")";
91 131
92 return os.str(); 132 return os.str();
93 } 133 }
@@ -40,6 +40,12 @@ struct OfflineRecognizerConfig { @@ -40,6 +40,12 @@ struct OfflineRecognizerConfig {
40 40
41 float blank_penalty = 0.0; 41 float blank_penalty = 0.0;
42 42
  43 + // If there are multiple rules, they are applied from left to right.
  44 + std::string rule_fsts;
  45 +
  46 + // If there are multiple FST archives, they are applied from left to right.
  47 + std::string rule_fars;
  48 +
43 // only greedy_search is implemented 49 // only greedy_search is implemented
44 // TODO(fangjun): Implement modified_beam_search 50 // TODO(fangjun): Implement modified_beam_search
45 51
@@ -50,7 +56,8 @@ struct OfflineRecognizerConfig { @@ -50,7 +56,8 @@ struct OfflineRecognizerConfig {
50 const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, 56 const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
51 const std::string &decoding_method, int32_t max_active_paths, 57 const std::string &decoding_method, int32_t max_active_paths,
52 const std::string &hotwords_file, float hotwords_score, 58 const std::string &hotwords_file, float hotwords_score,
53 - float blank_penalty) 59 + float blank_penalty, const std::string &rule_fsts,
  60 + const std::string &rule_fars)
54 : feat_config(feat_config), 61 : feat_config(feat_config),
55 model_config(model_config), 62 model_config(model_config),
56 lm_config(lm_config), 63 lm_config(lm_config),
@@ -59,7 +66,9 @@ struct OfflineRecognizerConfig { @@ -59,7 +66,9 @@ struct OfflineRecognizerConfig {
59 max_active_paths(max_active_paths), 66 max_active_paths(max_active_paths),
60 hotwords_file(hotwords_file), 67 hotwords_file(hotwords_file),
61 hotwords_score(hotwords_score), 68 hotwords_score(hotwords_score),
62 - blank_penalty(blank_penalty) {} 69 + blank_penalty(blank_penalty),
  70 + rule_fsts(rule_fsts),
  71 + rule_fars(rule_fars) {}
63 72
64 void Register(ParseOptions *po); 73 void Register(ParseOptions *po);
65 bool Validate() const; 74 bool Validate() const;
@@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) { @@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
17 .def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &, 17 .def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
18 const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &, 18 const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
19 const std::string &, int32_t, const std::string &, float, 19 const std::string &, int32_t, const std::string &, float,
20 - float>(), 20 + float, const std::string &, const std::string &>(),
21 py::arg("feat_config"), py::arg("model_config"), 21 py::arg("feat_config"), py::arg("model_config"),
22 py::arg("lm_config") = OfflineLMConfig(), 22 py::arg("lm_config") = OfflineLMConfig(),
23 py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), 23 py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
24 py::arg("decoding_method") = "greedy_search", 24 py::arg("decoding_method") = "greedy_search",
25 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", 25 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
26 - py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0) 26 + py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
  27 + py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
27 .def_readwrite("feat_config", &PyClass::feat_config) 28 .def_readwrite("feat_config", &PyClass::feat_config)
28 .def_readwrite("model_config", &PyClass::model_config) 29 .def_readwrite("model_config", &PyClass::model_config)
29 .def_readwrite("lm_config", &PyClass::lm_config) 30 .def_readwrite("lm_config", &PyClass::lm_config)
@@ -33,6 +34,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) { @@ -33,6 +34,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
33 .def_readwrite("hotwords_file", &PyClass::hotwords_file) 34 .def_readwrite("hotwords_file", &PyClass::hotwords_file)
34 .def_readwrite("hotwords_score", &PyClass::hotwords_score) 35 .def_readwrite("hotwords_score", &PyClass::hotwords_score)
35 .def_readwrite("blank_penalty", &PyClass::blank_penalty) 36 .def_readwrite("blank_penalty", &PyClass::blank_penalty)
  37 + .def_readwrite("rule_fsts", &PyClass::rule_fsts)
  38 + .def_readwrite("rule_fars", &PyClass::rule_fars)
36 .def("__str__", &PyClass::ToString); 39 .def("__str__", &PyClass::ToString);
37 } 40 }
38 41
@@ -54,6 +54,8 @@ class OfflineRecognizer(object): @@ -54,6 +54,8 @@ class OfflineRecognizer(object):
54 debug: bool = False, 54 debug: bool = False,
55 provider: str = "cpu", 55 provider: str = "cpu",
56 model_type: str = "transducer", 56 model_type: str = "transducer",
  57 + rule_fsts: str = "",
  58 + rule_fars: str = "",
57 ): 59 ):
58 """ 60 """
59 Please refer to 61 Please refer to
@@ -107,6 +109,12 @@ class OfflineRecognizer(object): @@ -107,6 +109,12 @@ class OfflineRecognizer(object):
107 True to show debug messages. 109 True to show debug messages.
108 provider: 110 provider:
109 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 111 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  112 + rule_fsts:
  113 + If not empty, it specifies fsts for inverse text normalization.
  114 + If there are multiple fsts, they are separated by a comma.
  115 + rule_fars:
  116 + If not empty, it specifies fst archives for inverse text normalization.
  117 + If there are multiple archives, they are separated by a comma.
110 """ 118 """
111 self = cls.__new__(cls) 119 self = cls.__new__(cls)
112 model_config = OfflineModelConfig( 120 model_config = OfflineModelConfig(
@@ -143,6 +151,8 @@ class OfflineRecognizer(object): @@ -143,6 +151,8 @@ class OfflineRecognizer(object):
143 hotwords_file=hotwords_file, 151 hotwords_file=hotwords_file,
144 hotwords_score=hotwords_score, 152 hotwords_score=hotwords_score,
145 blank_penalty=blank_penalty, 153 blank_penalty=blank_penalty,
  154 + rule_fsts=rule_fsts,
  155 + rule_fars=rule_fars,
146 ) 156 )
147 self.recognizer = _Recognizer(recognizer_config) 157 self.recognizer = _Recognizer(recognizer_config)
148 self.config = recognizer_config 158 self.config = recognizer_config
@@ -159,6 +169,8 @@ class OfflineRecognizer(object): @@ -159,6 +169,8 @@ class OfflineRecognizer(object):
159 decoding_method: str = "greedy_search", 169 decoding_method: str = "greedy_search",
160 debug: bool = False, 170 debug: bool = False,
161 provider: str = "cpu", 171 provider: str = "cpu",
  172 + rule_fsts: str = "",
  173 + rule_fars: str = "",
162 ): 174 ):
163 """ 175 """
164 Please refer to 176 Please refer to
@@ -186,6 +198,12 @@ class OfflineRecognizer(object): @@ -186,6 +198,12 @@ class OfflineRecognizer(object):
186 True to show debug messages. 198 True to show debug messages.
187 provider: 199 provider:
188 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 200 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  201 + rule_fsts:
  202 + If not empty, it specifies fsts for inverse text normalization.
  203 + If there are multiple fsts, they are separated by a comma.
  204 + rule_fars:
  205 + If not empty, it specifies fst archives for inverse text normalization.
  206 + If there are multiple archives, they are separated by a comma.
189 """ 207 """
190 self = cls.__new__(cls) 208 self = cls.__new__(cls)
191 model_config = OfflineModelConfig( 209 model_config = OfflineModelConfig(
@@ -206,6 +224,8 @@ class OfflineRecognizer(object): @@ -206,6 +224,8 @@ class OfflineRecognizer(object):
206 feat_config=feat_config, 224 feat_config=feat_config,
207 model_config=model_config, 225 model_config=model_config,
208 decoding_method=decoding_method, 226 decoding_method=decoding_method,
  227 + rule_fsts=rule_fsts,
  228 + rule_fars=rule_fars,
209 ) 229 )
210 self.recognizer = _Recognizer(recognizer_config) 230 self.recognizer = _Recognizer(recognizer_config)
211 self.config = recognizer_config 231 self.config = recognizer_config
@@ -222,6 +242,8 @@ class OfflineRecognizer(object): @@ -222,6 +242,8 @@ class OfflineRecognizer(object):
222 decoding_method: str = "greedy_search", 242 decoding_method: str = "greedy_search",
223 debug: bool = False, 243 debug: bool = False,
224 provider: str = "cpu", 244 provider: str = "cpu",
  245 + rule_fsts: str = "",
  246 + rule_fars: str = "",
225 ): 247 ):
226 """ 248 """
227 Please refer to 249 Please refer to
@@ -251,6 +273,12 @@ class OfflineRecognizer(object): @@ -251,6 +273,12 @@ class OfflineRecognizer(object):
251 True to show debug messages. 273 True to show debug messages.
252 provider: 274 provider:
253 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 275 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  276 + rule_fsts:
  277 + If not empty, it specifies fsts for inverse text normalization.
  278 + If there are multiple fsts, they are separated by a comma.
  279 + rule_fars:
  280 + If not empty, it specifies fst archives for inverse text normalization.
  281 + If there are multiple archives, they are separated by a comma.
254 """ 282 """
255 self = cls.__new__(cls) 283 self = cls.__new__(cls)
256 model_config = OfflineModelConfig( 284 model_config = OfflineModelConfig(
@@ -271,6 +299,8 @@ class OfflineRecognizer(object): @@ -271,6 +299,8 @@ class OfflineRecognizer(object):
271 feat_config=feat_config, 299 feat_config=feat_config,
272 model_config=model_config, 300 model_config=model_config,
273 decoding_method=decoding_method, 301 decoding_method=decoding_method,
  302 + rule_fsts=rule_fsts,
  303 + rule_fars=rule_fars,
274 ) 304 )
275 self.recognizer = _Recognizer(recognizer_config) 305 self.recognizer = _Recognizer(recognizer_config)
276 self.config = recognizer_config 306 self.config = recognizer_config
@@ -287,6 +317,8 @@ class OfflineRecognizer(object): @@ -287,6 +317,8 @@ class OfflineRecognizer(object):
287 decoding_method: str = "greedy_search", 317 decoding_method: str = "greedy_search",
288 debug: bool = False, 318 debug: bool = False,
289 provider: str = "cpu", 319 provider: str = "cpu",
  320 + rule_fsts: str = "",
  321 + rule_fars: str = "",
290 ): 322 ):
291 """ 323 """
292 Please refer to 324 Please refer to
@@ -315,6 +347,12 @@ class OfflineRecognizer(object): @@ -315,6 +347,12 @@ class OfflineRecognizer(object):
315 True to show debug messages. 347 True to show debug messages.
316 provider: 348 provider:
317 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 349 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  350 + rule_fsts:
  351 + If not empty, it specifies fsts for inverse text normalization.
  352 + If there are multiple fsts, they are separated by a comma.
  353 + rule_fars:
  354 + If not empty, it specifies fst archives for inverse text normalization.
  355 + If there are multiple archives, they are separated by a comma.
318 """ 356 """
319 self = cls.__new__(cls) 357 self = cls.__new__(cls)
320 model_config = OfflineModelConfig( 358 model_config = OfflineModelConfig(
@@ -335,6 +373,8 @@ class OfflineRecognizer(object): @@ -335,6 +373,8 @@ class OfflineRecognizer(object):
335 feat_config=feat_config, 373 feat_config=feat_config,
336 model_config=model_config, 374 model_config=model_config,
337 decoding_method=decoding_method, 375 decoding_method=decoding_method,
  376 + rule_fsts=rule_fsts,
  377 + rule_fars=rule_fars,
338 ) 378 )
339 self.recognizer = _Recognizer(recognizer_config) 379 self.recognizer = _Recognizer(recognizer_config)
340 self.config = recognizer_config 380 self.config = recognizer_config
@@ -353,6 +393,8 @@ class OfflineRecognizer(object): @@ -353,6 +393,8 @@ class OfflineRecognizer(object):
353 debug: bool = False, 393 debug: bool = False,
354 provider: str = "cpu", 394 provider: str = "cpu",
355 tail_paddings: int = -1, 395 tail_paddings: int = -1,
  396 + rule_fsts: str = "",
  397 + rule_fars: str = "",
356 ): 398 ):
357 """ 399 """
358 Please refer to 400 Please refer to
@@ -389,6 +431,12 @@ class OfflineRecognizer(object): @@ -389,6 +431,12 @@ class OfflineRecognizer(object):
389 True to show debug messages. 431 True to show debug messages.
390 provider: 432 provider:
391 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 433 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  434 + rule_fsts:
  435 + If not empty, it specifies fsts for inverse text normalization.
  436 + If there are multiple fsts, they are separated by a comma.
  437 + rule_fars:
  438 + If not empty, it specifies fst archives for inverse text normalization.
  439 + If there are multiple archives, they are separated by a comma.
392 """ 440 """
393 self = cls.__new__(cls) 441 self = cls.__new__(cls)
394 model_config = OfflineModelConfig( 442 model_config = OfflineModelConfig(
@@ -415,6 +463,8 @@ class OfflineRecognizer(object): @@ -415,6 +463,8 @@ class OfflineRecognizer(object):
415 feat_config=feat_config, 463 feat_config=feat_config,
416 model_config=model_config, 464 model_config=model_config,
417 decoding_method=decoding_method, 465 decoding_method=decoding_method,
  466 + rule_fsts=rule_fsts,
  467 + rule_fars=rule_fars,
418 ) 468 )
419 self.recognizer = _Recognizer(recognizer_config) 469 self.recognizer = _Recognizer(recognizer_config)
420 self.config = recognizer_config 470 self.config = recognizer_config
@@ -431,6 +481,8 @@ class OfflineRecognizer(object): @@ -431,6 +481,8 @@ class OfflineRecognizer(object):
431 decoding_method: str = "greedy_search", 481 decoding_method: str = "greedy_search",
432 debug: bool = False, 482 debug: bool = False,
433 provider: str = "cpu", 483 provider: str = "cpu",
  484 + rule_fsts: str = "",
  485 + rule_fars: str = "",
434 ): 486 ):
435 """ 487 """
436 Please refer to 488 Please refer to
@@ -458,6 +510,12 @@ class OfflineRecognizer(object): @@ -458,6 +510,12 @@ class OfflineRecognizer(object):
458 True to show debug messages. 510 True to show debug messages.
459 provider: 511 provider:
460 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 512 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  513 + rule_fsts:
  514 + If not empty, it specifies fsts for inverse text normalization.
  515 + If there are multiple fsts, they are separated by a comma.
  516 + rule_fars:
  517 + If not empty, it specifies fst archives for inverse text normalization.
  518 + If there are multiple archives, they are separated by a comma.
461 """ 519 """
462 self = cls.__new__(cls) 520 self = cls.__new__(cls)
463 model_config = OfflineModelConfig( 521 model_config = OfflineModelConfig(
@@ -478,6 +536,8 @@ class OfflineRecognizer(object): @@ -478,6 +536,8 @@ class OfflineRecognizer(object):
478 feat_config=feat_config, 536 feat_config=feat_config,
479 model_config=model_config, 537 model_config=model_config,
480 decoding_method=decoding_method, 538 decoding_method=decoding_method,
  539 + rule_fsts=rule_fsts,
  540 + rule_fars=rule_fars,
481 ) 541 )
482 self.recognizer = _Recognizer(recognizer_config) 542 self.recognizer = _Recognizer(recognizer_config)
483 self.config = recognizer_config 543 self.config = recognizer_config
@@ -494,6 +554,8 @@ class OfflineRecognizer(object): @@ -494,6 +554,8 @@ class OfflineRecognizer(object):
494 decoding_method: str = "greedy_search", 554 decoding_method: str = "greedy_search",
495 debug: bool = False, 555 debug: bool = False,
496 provider: str = "cpu", 556 provider: str = "cpu",
  557 + rule_fsts: str = "",
  558 + rule_fars: str = "",
497 ): 559 ):
498 """ 560 """
499 Please refer to 561 Please refer to
@@ -522,6 +584,12 @@ class OfflineRecognizer(object): @@ -522,6 +584,12 @@ class OfflineRecognizer(object):
522 True to show debug messages. 584 True to show debug messages.
523 provider: 585 provider:
524 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 586 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  587 + rule_fsts:
  588 + If not empty, it specifies fsts for inverse text normalization.
  589 + If there are multiple fsts, they are separated by a comma.
  590 + rule_fars:
  591 + If not empty, it specifies fst archives for inverse text normalization.
  592 + If there are multiple archives, they are separated by a comma.
525 """ 593 """
526 self = cls.__new__(cls) 594 self = cls.__new__(cls)
527 model_config = OfflineModelConfig( 595 model_config = OfflineModelConfig(
@@ -542,6 +610,8 @@ class OfflineRecognizer(object): @@ -542,6 +610,8 @@ class OfflineRecognizer(object):
542 feat_config=feat_config, 610 feat_config=feat_config,
543 model_config=model_config, 611 model_config=model_config,
544 decoding_method=decoding_method, 612 decoding_method=decoding_method,
  613 + rule_fsts=rule_fsts,
  614 + rule_fars=rule_fars,
545 ) 615 )
546 self.recognizer = _Recognizer(recognizer_config) 616 self.recognizer = _Recognizer(recognizer_config)
547 self.config = recognizer_config 617 self.config = recognizer_config