Committed by
GitHub
Add inverse text normalization for online ASR (#1020)
Showing 12 changed files with 378 additions and 20 deletions.
| @@ -256,7 +256,18 @@ if [[ x$OS != x'windows-latest' ]]; then | @@ -256,7 +256,18 @@ if [[ x$OS != x'windows-latest' ]]; then | ||
| 256 | $repo/test_wavs/3.wav \ | 256 | $repo/test_wavs/3.wav \ |
| 257 | $repo/test_wavs/8k.wav | 257 | $repo/test_wavs/8k.wav |
| 258 | 258 | ||
| 259 | + ln -s $repo $PWD/ | ||
| 260 | + | ||
| 261 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst | ||
| 262 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav | ||
| 263 | + | ||
| 264 | + python3 ./python-api-examples/inverse-text-normalization-online-asr.py | ||
| 265 | + | ||
| 259 | python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose | 266 | python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose |
| 267 | + | ||
| 268 | + rm -rfv sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 | ||
| 269 | + | ||
| 270 | + rm -rf $repo | ||
| 260 | fi | 271 | fi |
| 261 | 272 | ||
| 262 | log "Test non-streaming transducer models" | 273 | log "Test non-streaming transducer models" |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# | ||
| 3 | +# Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +""" | ||
| 6 | +This script shows how to use inverse text normalization with streaming ASR. | ||
| 7 | + | ||
| 8 | +Usage: | ||
| 9 | + | ||
| 10 | +(1) Download the test model | ||
| 11 | + | ||
| 12 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 | ||
| 13 | + | ||
| 14 | +(2) Download rule fst | ||
| 15 | + | ||
| 16 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst | ||
| 17 | + | ||
| 18 | +Please refer to | ||
| 19 | +https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb | ||
| 20 | +for how itn_zh_number.fst is generated. | ||
| 21 | + | ||
| 22 | +(3) Download test wave | ||
| 23 | + | ||
| 24 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav | ||
| 25 | + | ||
| 26 | +(4) Run this script | ||
| 27 | + | ||
| 28 | +python3 ./python-api-examples/inverse-text-normalization-online-asr.py | ||
| 29 | +""" | ||
| 30 | +from pathlib import Path | ||
| 31 | + | ||
| 32 | +import sherpa_onnx | ||
| 33 | +import soundfile as sf | ||
| 34 | + | ||
| 35 | + | ||
def create_recognizer():
    """Build a streaming transducer recognizer with an ITN rule fst attached.

    Returns:
      A ``sherpa_onnx.OnlineRecognizer`` configured with the bilingual
      zh-en zipformer model and ``itn_zh_number.fst`` as the rule fst.

    Raises:
      ValueError: if any required model file or the rule fst is missing.
    """
    model_dir = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
    encoder = f"{model_dir}/encoder-epoch-99-avg-1.int8.onnx"
    decoder = f"{model_dir}/decoder-epoch-99-avg-1.onnx"
    joiner = f"{model_dir}/joiner-epoch-99-avg-1.int8.onnx"
    tokens = f"{model_dir}/tokens.txt"
    rule_fsts = "./itn_zh_number.fst"

    # All five files must exist before we hand them to the recognizer.
    required = (encoder, decoder, joiner, tokens, rule_fsts)
    if not all(Path(p).is_file() for p in required):
        raise ValueError(
            """Please download model files from
            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
            """
        )

    return sherpa_onnx.OnlineRecognizer.from_transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        tokens=tokens,
        debug=True,
        rule_fsts=rule_fsts,
    )
| 64 | + | ||
def main():
    """Decode one test wave and print the ITN-normalized recognition result."""
    recognizer = create_recognizer()

    wave_filename = "./itn-zh-number.wav"
    if not Path(wave_filename).is_file():
        raise ValueError(
            """Please download model files from
            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
            """
        )

    samples, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    samples = samples[:, 0]  # only use the first channel

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    # 0.3 s of trailing silence so the model flushes its internal buffers.
    stream.accept_waveform(sample_rate, [0] * int(0.3 * sample_rate))

    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    print(wave_filename)
    print(recognizer.get_result_all(stream))


if __name__ == "__main__":
    main()
| @@ -68,7 +68,8 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | @@ -68,7 +68,8 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | ||
| 68 | class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | 68 | class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { |
| 69 | public: | 69 | public: |
| 70 | explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config) | 70 | explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config) |
| 71 | - : config_(config), | 71 | + : OnlineRecognizerImpl(config), |
| 72 | + config_(config), | ||
| 72 | model_(OnlineCtcModel::Create(config.model_config)), | 73 | model_(OnlineCtcModel::Create(config.model_config)), |
| 73 | sym_(config.model_config.tokens), | 74 | sym_(config.model_config.tokens), |
| 74 | endpoint_(config_.endpoint_config) { | 75 | endpoint_(config_.endpoint_config) { |
| @@ -84,7 +85,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -84,7 +85,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 84 | #if __ANDROID_API__ >= 9 | 85 | #if __ANDROID_API__ >= 9 |
| 85 | explicit OnlineRecognizerCtcImpl(AAssetManager *mgr, | 86 | explicit OnlineRecognizerCtcImpl(AAssetManager *mgr, |
| 86 | const OnlineRecognizerConfig &config) | 87 | const OnlineRecognizerConfig &config) |
| 87 | - : config_(config), | 88 | + : OnlineRecognizerImpl(mgr, config), |
| 89 | + config_(config), | ||
| 88 | model_(OnlineCtcModel::Create(mgr, config.model_config)), | 90 | model_(OnlineCtcModel::Create(mgr, config.model_config)), |
| 89 | sym_(mgr, config.model_config.tokens), | 91 | sym_(mgr, config.model_config.tokens), |
| 90 | endpoint_(config_.endpoint_config) { | 92 | endpoint_(config_.endpoint_config) { |
| @@ -182,8 +184,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -182,8 +184,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 182 | // TODO(fangjun): Remember to change these constants if needed | 184 | // TODO(fangjun): Remember to change these constants if needed |
| 183 | int32_t frame_shift_ms = 10; | 185 | int32_t frame_shift_ms = 10; |
| 184 | int32_t subsampling_factor = 4; | 186 | int32_t subsampling_factor = 4; |
| 185 | - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, | 187 | + auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, |
| 186 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | 188 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); |
| 189 | + r.text = ApplyInverseTextNormalization(r.text); | ||
| 190 | + return r; | ||
| 187 | } | 191 | } |
| 188 | 192 | ||
| 189 | bool IsEndpoint(OnlineStream *s) const override { | 193 | bool IsEndpoint(OnlineStream *s) const override { |
| @@ -4,11 +4,22 @@ | @@ -4,11 +4,22 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" | 5 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" |
| 6 | 6 | ||
| 7 | +#if __ANDROID_API__ >= 9 | ||
| 8 | +#include <strstream> | ||
| 9 | + | ||
| 10 | +#include "android/asset_manager.h" | ||
| 11 | +#include "android/asset_manager_jni.h" | ||
| 12 | +#endif | ||
| 13 | + | ||
| 14 | +#include "fst/extensions/far/far.h" | ||
| 15 | +#include "kaldifst/csrc/kaldi-fst-io.h" | ||
| 16 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 7 | #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h" | 17 | #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h" |
| 8 | #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" | 18 | #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" |
| 9 | #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" | 19 | #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" |
| 10 | #include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h" | 20 | #include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h" |
| 11 | #include "sherpa-onnx/csrc/onnx-utils.h" | 21 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 22 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 12 | 23 | ||
| 13 | namespace sherpa_onnx { | 24 | namespace sherpa_onnx { |
| 14 | 25 | ||
| @@ -78,4 +89,110 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -78,4 +89,110 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 78 | } | 89 | } |
| 79 | #endif | 90 | #endif |
| 80 | 91 | ||
// Loads the optional inverse-text-normalization (ITN) resources named in the
// config:
//   - config.rule_fsts: comma-separated list of rule FST files
//   - config.rule_fars: comma-separated list of FST-archive (far) files
// Every FST found becomes one kaldifst::TextNormalizer appended to
// itn_list_; ApplyInverseTextNormalization() later applies them in this
// load order.
OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
    : config_(config) {
  if (!config.rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fsts, ",", false, &files);
    itn_list_.reserve(files.size());
    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
      }
      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
    }
  }

  if (!config.rule_fars.empty()) {
    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("Loading FST archives");
    }
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fars, ",", false, &files);

    // Lower bound only: each archive may hold several FSTs.
    itn_list_.reserve(files.size() + itn_list_.size());

    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
      }
      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
          fst::FarReader<fst::StdArc>::Open(f));
      // NOTE(review): FarReader::Open() can return nullptr on a bad archive,
      // which would crash on the Done() call below — TODO confirm the file
      // is validated upstream (Validate() only checks existence).
      for (; !reader->Done(); reader->Next()) {
        // Copy() is taken, presumably because GetFst() returns storage owned
        // by the reader — TODO confirm against the OpenFst far API.
        std::unique_ptr<fst::StdConstFst> r(
            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));

        itn_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
      }
    }

    if (config.model_config.debug) {
      SHERPA_ONNX_LOGE("FST archives loaded!");
    }
  }
}
| 135 | + | ||
#if __ANDROID_API__ >= 9
// Android variant of the constructor above: files live inside the APK, so
// each one is read into memory via the AAssetManager and the FSTs are
// constructed from in-memory streams instead of filenames.
OnlineRecognizerImpl::OnlineRecognizerImpl(AAssetManager *mgr,
                                           const OnlineRecognizerConfig &config)
    : config_(config) {
  if (!config.rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fsts, ",", false, &files);
    itn_list_.reserve(files.size());
    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
      }
      // Read the asset into a buffer and build the FST from a stream view.
      auto buf = ReadFile(mgr, f);
      std::istrstream is(buf.data(), buf.size());
      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
    }
  }

  if (!config.rule_fars.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(config.rule_fars, ",", false, &files);
    // Lower bound only: each archive may hold several FSTs.
    itn_list_.reserve(files.size() + itn_list_.size());

    for (const auto &f : files) {
      if (config.model_config.debug) {
        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
      }

      auto buf = ReadFile(mgr, f);

      // NOTE(review): istrstream does not own buf, so buf must outlive the
      // reader; that holds here because both are scoped to this iteration.
      std::unique_ptr<std::istream> s(
          new std::istrstream(buf.data(), buf.size()));

      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
          fst::FarReader<fst::StdArc>::Open(std::move(s)));

      // NOTE(review): Open() can return nullptr on a bad archive — TODO
      // confirm inputs are validated upstream.
      for (; !reader->Done(); reader->Next()) {
        std::unique_ptr<fst::StdConstFst> r(
            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));

        itn_list_.push_back(
            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
      }  // for (; !reader->Done(); reader->Next())
    }  // for (const auto &f : files)
  }  // if (!config.rule_fars.empty())
}
#endif
| 183 | + | ||
| 184 | +std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( | ||
| 185 | + std::string text) const { | ||
| 186 | + if (!itn_list_.empty()) { | ||
| 187 | + for (const auto &tn : itn_list_) { | ||
| 188 | + text = tn->Normalize(text); | ||
| 189 | + if (config_.model_config.debug) { | ||
| 190 | + SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str()); | ||
| 191 | + } | ||
| 192 | + } | ||
| 193 | + } | ||
| 194 | + | ||
| 195 | + return text; | ||
| 196 | +} | ||
| 197 | + | ||
| 81 | } // namespace sherpa_onnx | 198 | } // namespace sherpa_onnx |
| @@ -9,6 +9,12 @@ | @@ -9,6 +9,12 @@ | ||
| 9 | #include <string> | 9 | #include <string> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#if __ANDROID_API__ >= 9 | ||
| 13 | +#include "android/asset_manager.h" | ||
| 14 | +#include "android/asset_manager_jni.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 17 | +#include "kaldifst/csrc/text-normalizer.h" | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 18 | #include "sherpa-onnx/csrc/macros.h" |
| 13 | #include "sherpa-onnx/csrc/online-recognizer.h" | 19 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| 14 | #include "sherpa-onnx/csrc/online-stream.h" | 20 | #include "sherpa-onnx/csrc/online-stream.h" |
| @@ -17,10 +23,15 @@ namespace sherpa_onnx { | @@ -17,10 +23,15 @@ namespace sherpa_onnx { | ||
| 17 | 23 | ||
| 18 | class OnlineRecognizerImpl { | 24 | class OnlineRecognizerImpl { |
| 19 | public: | 25 | public: |
| 26 | + explicit OnlineRecognizerImpl(const OnlineRecognizerConfig &config); | ||
| 27 | + | ||
| 20 | static std::unique_ptr<OnlineRecognizerImpl> Create( | 28 | static std::unique_ptr<OnlineRecognizerImpl> Create( |
| 21 | const OnlineRecognizerConfig &config); | 29 | const OnlineRecognizerConfig &config); |
| 22 | 30 | ||
| 23 | #if __ANDROID_API__ >= 9 | 31 | #if __ANDROID_API__ >= 9 |
| 32 | + OnlineRecognizerImpl(AAssetManager *mgr, | ||
| 33 | + const OnlineRecognizerConfig &config); | ||
| 34 | + | ||
| 24 | static std::unique_ptr<OnlineRecognizerImpl> Create( | 35 | static std::unique_ptr<OnlineRecognizerImpl> Create( |
| 25 | AAssetManager *mgr, const OnlineRecognizerConfig &config); | 36 | AAssetManager *mgr, const OnlineRecognizerConfig &config); |
| 26 | #endif | 37 | #endif |
| @@ -50,6 +61,15 @@ class OnlineRecognizerImpl { | @@ -50,6 +61,15 @@ class OnlineRecognizerImpl { | ||
| 50 | virtual bool IsEndpoint(OnlineStream *s) const = 0; | 61 | virtual bool IsEndpoint(OnlineStream *s) const = 0; |
| 51 | 62 | ||
| 52 | virtual void Reset(OnlineStream *s) const = 0; | 63 | virtual void Reset(OnlineStream *s) const = 0; |
| 64 | + | ||
| 65 | + std::string ApplyInverseTextNormalization(std::string text) const; | ||
| 66 | + | ||
| 67 | + private: | ||
| 68 | + OnlineRecognizerConfig config_; | ||
| 69 | + // for inverse text normalization. Used only if | ||
| 70 | + // config.rule_fsts is not empty or | ||
| 71 | + // config.rule_fars is not empty | ||
| 72 | + std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_; | ||
| 53 | }; | 73 | }; |
| 54 | 74 | ||
| 55 | } // namespace sherpa_onnx | 75 | } // namespace sherpa_onnx |
| @@ -96,7 +96,8 @@ static void Scale(const float *x, int32_t n, float scale, float *y) { | @@ -96,7 +96,8 @@ static void Scale(const float *x, int32_t n, float scale, float *y) { | ||
| 96 | class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | 96 | class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { |
| 97 | public: | 97 | public: |
| 98 | explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config) | 98 | explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config) |
| 99 | - : config_(config), | 99 | + : OnlineRecognizerImpl(config), |
| 100 | + config_(config), | ||
| 100 | model_(config.model_config), | 101 | model_(config.model_config), |
| 101 | sym_(config.model_config.tokens), | 102 | sym_(config.model_config.tokens), |
| 102 | endpoint_(config_.endpoint_config) { | 103 | endpoint_(config_.endpoint_config) { |
| @@ -116,7 +117,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | @@ -116,7 +117,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | ||
| 116 | #if __ANDROID_API__ >= 9 | 117 | #if __ANDROID_API__ >= 9 |
| 117 | explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr, | 118 | explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr, |
| 118 | const OnlineRecognizerConfig &config) | 119 | const OnlineRecognizerConfig &config) |
| 119 | - : config_(config), | 120 | + : OnlineRecognizerImpl(mgr, config), |
| 121 | + config_(config), | ||
| 120 | model_(mgr, config.model_config), | 122 | model_(mgr, config.model_config), |
| 121 | sym_(mgr, config.model_config.tokens), | 123 | sym_(mgr, config.model_config.tokens), |
| 122 | endpoint_(config_.endpoint_config) { | 124 | endpoint_(config_.endpoint_config) { |
| @@ -160,7 +162,9 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | @@ -160,7 +162,9 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | ||
| 160 | OnlineRecognizerResult GetResult(OnlineStream *s) const override { | 162 | OnlineRecognizerResult GetResult(OnlineStream *s) const override { |
| 161 | auto decoder_result = s->GetParaformerResult(); | 163 | auto decoder_result = s->GetParaformerResult(); |
| 162 | 164 | ||
| 163 | - return Convert(decoder_result, sym_); | 165 | + auto r = Convert(decoder_result, sym_); |
| 166 | + r.text = ApplyInverseTextNormalization(r.text); | ||
| 167 | + return r; | ||
| 164 | } | 168 | } |
| 165 | 169 | ||
| 166 | bool IsEndpoint(OnlineStream *s) const override { | 170 | bool IsEndpoint(OnlineStream *s) const override { |
| @@ -80,7 +80,8 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | @@ -80,7 +80,8 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | ||
| 80 | class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | 80 | class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { |
| 81 | public: | 81 | public: |
| 82 | explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config) | 82 | explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config) |
| 83 | - : config_(config), | 83 | + : OnlineRecognizerImpl(config), |
| 84 | + config_(config), | ||
| 84 | model_(OnlineTransducerModel::Create(config.model_config)), | 85 | model_(OnlineTransducerModel::Create(config.model_config)), |
| 85 | sym_(config.model_config.tokens), | 86 | sym_(config.model_config.tokens), |
| 86 | endpoint_(config_.endpoint_config) { | 87 | endpoint_(config_.endpoint_config) { |
| @@ -124,7 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -124,7 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 124 | #if __ANDROID_API__ >= 9 | 125 | #if __ANDROID_API__ >= 9 |
| 125 | explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr, | 126 | explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr, |
| 126 | const OnlineRecognizerConfig &config) | 127 | const OnlineRecognizerConfig &config) |
| 127 | - : config_(config), | 128 | + : OnlineRecognizerImpl(mgr, config), |
| 129 | + config_(config), | ||
| 128 | model_(OnlineTransducerModel::Create(mgr, config.model_config)), | 130 | model_(OnlineTransducerModel::Create(mgr, config.model_config)), |
| 129 | sym_(mgr, config.model_config.tokens), | 131 | sym_(mgr, config.model_config.tokens), |
| 130 | endpoint_(config_.endpoint_config) { | 132 | endpoint_(config_.endpoint_config) { |
| @@ -332,8 +334,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -332,8 +334,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 332 | // TODO(fangjun): Remember to change these constants if needed | 334 | // TODO(fangjun): Remember to change these constants if needed |
| 333 | int32_t frame_shift_ms = 10; | 335 | int32_t frame_shift_ms = 10; |
| 334 | int32_t subsampling_factor = 4; | 336 | int32_t subsampling_factor = 4; |
| 335 | - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, | 337 | + auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, |
| 336 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | 338 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); |
| 339 | + r.text = ApplyInverseTextNormalization(std::move(r.text)); | ||
| 340 | + return r; | ||
| 337 | } | 341 | } |
| 338 | 342 | ||
| 339 | bool IsEndpoint(OnlineStream *s) const override { | 343 | bool IsEndpoint(OnlineStream *s) const override { |
| @@ -42,7 +42,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -42,7 +42,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 42 | public: | 42 | public: |
| 43 | explicit OnlineRecognizerTransducerNeMoImpl( | 43 | explicit OnlineRecognizerTransducerNeMoImpl( |
| 44 | const OnlineRecognizerConfig &config) | 44 | const OnlineRecognizerConfig &config) |
| 45 | - : config_(config), | 45 | + : OnlineRecognizerImpl(config), |
| 46 | + config_(config), | ||
| 46 | symbol_table_(config.model_config.tokens), | 47 | symbol_table_(config.model_config.tokens), |
| 47 | endpoint_(config_.endpoint_config), | 48 | endpoint_(config_.endpoint_config), |
| 48 | model_( | 49 | model_( |
| @@ -61,7 +62,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -61,7 +62,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 61 | #if __ANDROID_API__ >= 9 | 62 | #if __ANDROID_API__ >= 9 |
| 62 | explicit OnlineRecognizerTransducerNeMoImpl( | 63 | explicit OnlineRecognizerTransducerNeMoImpl( |
| 63 | AAssetManager *mgr, const OnlineRecognizerConfig &config) | 64 | AAssetManager *mgr, const OnlineRecognizerConfig &config) |
| 64 | - : config_(config), | 65 | + : OnlineRecognizerImpl(mgr, config), |
| 66 | + config_(config), | ||
| 65 | symbol_table_(mgr, config.model_config.tokens), | 67 | symbol_table_(mgr, config.model_config.tokens), |
| 66 | endpoint_(config_.endpoint_config), | 68 | endpoint_(config_.endpoint_config), |
| 67 | model_(std::make_unique<OnlineTransducerNeMoModel>( | 69 | model_(std::make_unique<OnlineTransducerNeMoModel>( |
| @@ -94,9 +96,11 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -94,9 +96,11 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 94 | // TODO(fangjun): Remember to change these constants if needed | 96 | // TODO(fangjun): Remember to change these constants if needed |
| 95 | int32_t frame_shift_ms = 10; | 97 | int32_t frame_shift_ms = 10; |
| 96 | int32_t subsampling_factor = model_->SubsamplingFactor(); | 98 | int32_t subsampling_factor = model_->SubsamplingFactor(); |
| 97 | - return Convert(s->GetResult(), symbol_table_, frame_shift_ms, | 99 | + auto r = Convert(s->GetResult(), symbol_table_, frame_shift_ms, |
| 98 | subsampling_factor, s->GetCurrentSegment(), | 100 | subsampling_factor, s->GetCurrentSegment(), |
| 99 | s->GetNumFramesSinceStart()); | 101 | s->GetNumFramesSinceStart()); |
| 102 | + r.text = ApplyInverseTextNormalization(std::move(r.text)); | ||
| 103 | + return r; | ||
| 100 | } | 104 | } |
| 101 | 105 | ||
| 102 | bool IsEndpoint(OnlineStream *s) const override { | 106 | bool IsEndpoint(OnlineStream *s) const override { |
| @@ -14,7 +14,9 @@ | @@ -14,7 +14,9 @@ | ||
| 14 | #include <utility> | 14 | #include <utility> |
| 15 | #include <vector> | 15 | #include <vector> |
| 16 | 16 | ||
| 17 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 17 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" | 18 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" |
| 19 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 18 | 20 | ||
| 19 | namespace sherpa_onnx { | 21 | namespace sherpa_onnx { |
| 20 | 22 | ||
| @@ -100,6 +102,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -100,6 +102,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 100 | "now support greedy_search and modified_beam_search."); | 102 | "now support greedy_search and modified_beam_search."); |
| 101 | po->Register("temperature-scale", &temperature_scale, | 103 | po->Register("temperature-scale", &temperature_scale, |
| 102 | "Temperature scale for confidence computation in decoding."); | 104 | "Temperature scale for confidence computation in decoding."); |
| 105 | + po->Register( | ||
| 106 | + "rule-fsts", &rule_fsts, | ||
| 107 | + "If not empty, it specifies fsts for inverse text normalization. " | ||
| 108 | + "If there are multiple fsts, they are separated by a comma."); | ||
| 109 | + | ||
| 110 | + po->Register( | ||
| 111 | + "rule-fars", &rule_fars, | ||
| 112 | + "If not empty, it specifies fst archives for inverse text normalization. " | ||
| 113 | + "If there are multiple archives, they are separated by a comma."); | ||
| 103 | } | 114 | } |
| 104 | 115 | ||
| 105 | bool OnlineRecognizerConfig::Validate() const { | 116 | bool OnlineRecognizerConfig::Validate() const { |
| @@ -129,6 +140,34 @@ bool OnlineRecognizerConfig::Validate() const { | @@ -129,6 +140,34 @@ bool OnlineRecognizerConfig::Validate() const { | ||
| 129 | return false; | 140 | return false; |
| 130 | } | 141 | } |
| 131 | 142 | ||
| 143 | + if (!hotwords_file.empty() && !FileExists(hotwords_file)) { | ||
| 144 | + SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist", | ||
| 145 | + hotwords_file.c_str()); | ||
| 146 | + return false; | ||
| 147 | + } | ||
| 148 | + | ||
| 149 | + if (!rule_fsts.empty()) { | ||
| 150 | + std::vector<std::string> files; | ||
| 151 | + SplitStringToVector(rule_fsts, ",", false, &files); | ||
| 152 | + for (const auto &f : files) { | ||
| 153 | + if (!FileExists(f)) { | ||
| 154 | + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str()); | ||
| 155 | + return false; | ||
| 156 | + } | ||
| 157 | + } | ||
| 158 | + } | ||
| 159 | + | ||
| 160 | + if (!rule_fars.empty()) { | ||
| 161 | + std::vector<std::string> files; | ||
| 162 | + SplitStringToVector(rule_fars, ",", false, &files); | ||
| 163 | + for (const auto &f : files) { | ||
| 164 | + if (!FileExists(f)) { | ||
| 165 | + SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str()); | ||
| 166 | + return false; | ||
| 167 | + } | ||
| 168 | + } | ||
| 169 | + } | ||
| 170 | + | ||
| 132 | return model_config.Validate(); | 171 | return model_config.Validate(); |
| 133 | } | 172 | } |
| 134 | 173 | ||
| @@ -147,7 +186,9 @@ std::string OnlineRecognizerConfig::ToString() const { | @@ -147,7 +186,9 @@ std::string OnlineRecognizerConfig::ToString() const { | ||
| 147 | os << "hotwords_file=\"" << hotwords_file << "\", "; | 186 | os << "hotwords_file=\"" << hotwords_file << "\", "; |
| 148 | os << "decoding_method=\"" << decoding_method << "\", "; | 187 | os << "decoding_method=\"" << decoding_method << "\", "; |
| 149 | os << "blank_penalty=" << blank_penalty << ", "; | 188 | os << "blank_penalty=" << blank_penalty << ", "; |
| 150 | - os << "temperature_scale=" << temperature_scale << ")"; | 189 | + os << "temperature_scale=" << temperature_scale << ", "; |
| 190 | + os << "rule_fsts=\"" << rule_fsts << "\", "; | ||
| 191 | + os << "rule_fars=\"" << rule_fars << "\")"; | ||
| 151 | 192 | ||
| 152 | return os.str(); | 193 | return os.str(); |
| 153 | } | 194 | } |
| @@ -100,6 +100,12 @@ struct OnlineRecognizerConfig { | @@ -100,6 +100,12 @@ struct OnlineRecognizerConfig { | ||
| 100 | 100 | ||
| 101 | float temperature_scale = 2.0; | 101 | float temperature_scale = 2.0; |
| 102 | 102 | ||
| 103 | + // If there are multiple rules, they are applied from left to right. | ||
| 104 | + std::string rule_fsts; | ||
| 105 | + | ||
| 106 | + // If there are multiple FST archives, they are applied from left to right. | ||
| 107 | + std::string rule_fars; | ||
| 108 | + | ||
| 103 | OnlineRecognizerConfig() = default; | 109 | OnlineRecognizerConfig() = default; |
| 104 | 110 | ||
| 105 | OnlineRecognizerConfig( | 111 | OnlineRecognizerConfig( |
| @@ -109,7 +115,8 @@ struct OnlineRecognizerConfig { | @@ -109,7 +115,8 @@ struct OnlineRecognizerConfig { | ||
| 109 | const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, | 115 | const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, |
| 110 | bool enable_endpoint, const std::string &decoding_method, | 116 | bool enable_endpoint, const std::string &decoding_method, |
| 111 | int32_t max_active_paths, const std::string &hotwords_file, | 117 | int32_t max_active_paths, const std::string &hotwords_file, |
| 112 | - float hotwords_score, float blank_penalty, float temperature_scale) | 118 | + float hotwords_score, float blank_penalty, float temperature_scale, |
| 119 | + const std::string &rule_fsts, const std::string &rule_fars) | ||
| 113 | : feat_config(feat_config), | 120 | : feat_config(feat_config), |
| 114 | model_config(model_config), | 121 | model_config(model_config), |
| 115 | lm_config(lm_config), | 122 | lm_config(lm_config), |
| @@ -121,7 +128,9 @@ struct OnlineRecognizerConfig { | @@ -121,7 +128,9 @@ struct OnlineRecognizerConfig { | ||
| 121 | hotwords_file(hotwords_file), | 128 | hotwords_file(hotwords_file), |
| 122 | hotwords_score(hotwords_score), | 129 | hotwords_score(hotwords_score), |
| 123 | blank_penalty(blank_penalty), | 130 | blank_penalty(blank_penalty), |
| 124 | - temperature_scale(temperature_scale) {} | 131 | + temperature_scale(temperature_scale), |
| 132 | + rule_fsts(rule_fsts), | ||
| 133 | + rule_fars(rule_fars) {} | ||
| 125 | 134 | ||
| 126 | void Register(ParseOptions *po); | 135 | void Register(ParseOptions *po); |
| 127 | bool Validate() const; | 136 | bool Validate() const; |
| @@ -54,11 +54,11 @@ static void PybindOnlineRecognizerResult(py::module *m) { | @@ -54,11 +54,11 @@ static void PybindOnlineRecognizerResult(py::module *m) { | ||
| 54 | static void PybindOnlineRecognizerConfig(py::module *m) { | 54 | static void PybindOnlineRecognizerConfig(py::module *m) { |
| 55 | using PyClass = OnlineRecognizerConfig; | 55 | using PyClass = OnlineRecognizerConfig; |
| 56 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") | 56 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") |
| 57 | - .def( | ||
| 58 | - py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, | 57 | + .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, |
| 59 | const OnlineLMConfig &, const EndpointConfig &, | 58 | const OnlineLMConfig &, const EndpointConfig &, |
| 60 | - const OnlineCtcFstDecoderConfig &, bool, const std::string &, | ||
| 61 | - int32_t, const std::string &, float, float, float>(), | 59 | + const OnlineCtcFstDecoderConfig &, bool, |
| 60 | + const std::string &, int32_t, const std::string &, float, | ||
| 61 | + float, float, const std::string &, const std::string &>(), | ||
| 62 | py::arg("feat_config"), py::arg("model_config"), | 62 | py::arg("feat_config"), py::arg("model_config"), |
| 63 | py::arg("lm_config") = OnlineLMConfig(), | 63 | py::arg("lm_config") = OnlineLMConfig(), |
| 64 | py::arg("endpoint_config") = EndpointConfig(), | 64 | py::arg("endpoint_config") = EndpointConfig(), |
| @@ -66,7 +66,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -66,7 +66,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 66 | py::arg("enable_endpoint"), py::arg("decoding_method"), | 66 | py::arg("enable_endpoint"), py::arg("decoding_method"), |
| 67 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | 67 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 68 | py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, | 68 | py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, |
| 69 | - py::arg("temperature_scale") = 2.0) | 69 | + py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", |
| 70 | + py::arg("rule_fars") = "") | ||
| 70 | .def_readwrite("feat_config", &PyClass::feat_config) | 71 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 71 | .def_readwrite("model_config", &PyClass::model_config) | 72 | .def_readwrite("model_config", &PyClass::model_config) |
| 72 | .def_readwrite("lm_config", &PyClass::lm_config) | 73 | .def_readwrite("lm_config", &PyClass::lm_config) |
| @@ -79,6 +80,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -79,6 +80,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 79 | .def_readwrite("hotwords_score", &PyClass::hotwords_score) | 80 | .def_readwrite("hotwords_score", &PyClass::hotwords_score) |
| 80 | .def_readwrite("blank_penalty", &PyClass::blank_penalty) | 81 | .def_readwrite("blank_penalty", &PyClass::blank_penalty) |
| 81 | .def_readwrite("temperature_scale", &PyClass::temperature_scale) | 82 | .def_readwrite("temperature_scale", &PyClass::temperature_scale) |
| 83 | + .def_readwrite("rule_fsts", &PyClass::rule_fsts) | ||
| 84 | + .def_readwrite("rule_fars", &PyClass::rule_fars) | ||
| 82 | .def("__str__", &PyClass::ToString); | 85 | .def("__str__", &PyClass::ToString); |
| 83 | } | 86 | } |
| 84 | 87 |
| @@ -64,6 +64,8 @@ class OnlineRecognizer(object): | @@ -64,6 +64,8 @@ class OnlineRecognizer(object): | ||
| 64 | lm_scale: float = 0.1, | 64 | lm_scale: float = 0.1, |
| 65 | temperature_scale: float = 2.0, | 65 | temperature_scale: float = 2.0, |
| 66 | debug: bool = False, | 66 | debug: bool = False, |
| 67 | + rule_fsts: str = "", | ||
| 68 | + rule_fars: str = "", | ||
| 67 | ): | 69 | ): |
| 68 | """ | 70 | """ |
| 69 | Please refer to | 71 | Please refer to |
| @@ -148,6 +150,12 @@ class OnlineRecognizer(object): | @@ -148,6 +150,12 @@ class OnlineRecognizer(object): | ||
| 148 | the log probability, you can get it from the directory where | 150 | the log probability, you can get it from the directory where |
| 149 | your bpe model is generated. Only used when hotwords provided | 151 | your bpe model is generated. Only used when hotwords provided |
| 150 | and the modeling unit is bpe or cjkchar+bpe. | 152 | and the modeling unit is bpe or cjkchar+bpe. |
| 153 | + rule_fsts: | ||
| 154 | + If not empty, it specifies fsts for inverse text normalization. | ||
| 155 | + If there are multiple fsts, they are separated by a comma. | ||
| 156 | + rule_fars: | ||
| 157 | + If not empty, it specifies fst archives for inverse text normalization. | ||
| 158 | + If there are multiple archives, they are separated by a comma. | ||
| 151 | """ | 159 | """ |
| 152 | self = cls.__new__(cls) | 160 | self = cls.__new__(cls) |
| 153 | _assert_file_exists(tokens) | 161 | _assert_file_exists(tokens) |
| @@ -217,6 +225,8 @@ class OnlineRecognizer(object): | @@ -217,6 +225,8 @@ class OnlineRecognizer(object): | ||
| 217 | hotwords_file=hotwords_file, | 225 | hotwords_file=hotwords_file, |
| 218 | blank_penalty=blank_penalty, | 226 | blank_penalty=blank_penalty, |
| 219 | temperature_scale=temperature_scale, | 227 | temperature_scale=temperature_scale, |
| 228 | + rule_fsts=rule_fsts, | ||
| 229 | + rule_fars=rule_fars, | ||
| 220 | ) | 230 | ) |
| 221 | 231 | ||
| 222 | self.recognizer = _Recognizer(recognizer_config) | 232 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -239,6 +249,8 @@ class OnlineRecognizer(object): | @@ -239,6 +249,8 @@ class OnlineRecognizer(object): | ||
| 239 | decoding_method: str = "greedy_search", | 249 | decoding_method: str = "greedy_search", |
| 240 | provider: str = "cpu", | 250 | provider: str = "cpu", |
| 241 | debug: bool = False, | 251 | debug: bool = False, |
| 252 | + rule_fsts: str = "", | ||
| 253 | + rule_fars: str = "", | ||
| 242 | ): | 254 | ): |
| 243 | """ | 255 | """ |
| 244 | Please refer to | 256 | Please refer to |
| @@ -283,6 +295,12 @@ class OnlineRecognizer(object): | @@ -283,6 +295,12 @@ class OnlineRecognizer(object): | ||
| 283 | The only valid value is greedy_search. | 295 | The only valid value is greedy_search. |
| 284 | provider: | 296 | provider: |
| 285 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | 297 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. |
| 298 | + rule_fsts: | ||
| 299 | + If not empty, it specifies fsts for inverse text normalization. | ||
| 300 | + If there are multiple fsts, they are separated by a comma. | ||
| 301 | + rule_fars: | ||
| 302 | + If not empty, it specifies fst archives for inverse text normalization. | ||
| 303 | + If there are multiple archives, they are separated by a comma. | ||
| 286 | """ | 304 | """ |
| 287 | self = cls.__new__(cls) | 305 | self = cls.__new__(cls) |
| 288 | _assert_file_exists(tokens) | 306 | _assert_file_exists(tokens) |
| @@ -322,6 +340,8 @@ class OnlineRecognizer(object): | @@ -322,6 +340,8 @@ class OnlineRecognizer(object): | ||
| 322 | endpoint_config=endpoint_config, | 340 | endpoint_config=endpoint_config, |
| 323 | enable_endpoint=enable_endpoint_detection, | 341 | enable_endpoint=enable_endpoint_detection, |
| 324 | decoding_method=decoding_method, | 342 | decoding_method=decoding_method, |
| 343 | + rule_fsts=rule_fsts, | ||
| 344 | + rule_fars=rule_fars, | ||
| 325 | ) | 345 | ) |
| 326 | 346 | ||
| 327 | self.recognizer = _Recognizer(recognizer_config) | 347 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -345,6 +365,8 @@ class OnlineRecognizer(object): | @@ -345,6 +365,8 @@ class OnlineRecognizer(object): | ||
| 345 | ctc_max_active: int = 3000, | 365 | ctc_max_active: int = 3000, |
| 346 | provider: str = "cpu", | 366 | provider: str = "cpu", |
| 347 | debug: bool = False, | 367 | debug: bool = False, |
| 368 | + rule_fsts: str = "", | ||
| 369 | + rule_fars: str = "", | ||
| 348 | ): | 370 | ): |
| 349 | """ | 371 | """ |
| 350 | Please refer to | 372 | Please refer to |
| @@ -393,6 +415,12 @@ class OnlineRecognizer(object): | @@ -393,6 +415,12 @@ class OnlineRecognizer(object): | ||
| 393 | active paths at a time. | 415 | active paths at a time. |
| 394 | provider: | 416 | provider: |
| 395 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | 417 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. |
| 418 | + rule_fsts: | ||
| 419 | + If not empty, it specifies fsts for inverse text normalization. | ||
| 420 | + If there are multiple fsts, they are separated by a comma. | ||
| 421 | + rule_fars: | ||
| 422 | + If not empty, it specifies fst archives for inverse text normalization. | ||
| 423 | + If there are multiple archives, they are separated by a comma. | ||
| 396 | """ | 424 | """ |
| 397 | self = cls.__new__(cls) | 425 | self = cls.__new__(cls) |
| 398 | _assert_file_exists(tokens) | 426 | _assert_file_exists(tokens) |
| @@ -433,6 +461,8 @@ class OnlineRecognizer(object): | @@ -433,6 +461,8 @@ class OnlineRecognizer(object): | ||
| 433 | ctc_fst_decoder_config=ctc_fst_decoder_config, | 461 | ctc_fst_decoder_config=ctc_fst_decoder_config, |
| 434 | enable_endpoint=enable_endpoint_detection, | 462 | enable_endpoint=enable_endpoint_detection, |
| 435 | decoding_method=decoding_method, | 463 | decoding_method=decoding_method, |
| 464 | + rule_fsts=rule_fsts, | ||
| 465 | + rule_fars=rule_fars, | ||
| 436 | ) | 466 | ) |
| 437 | 467 | ||
| 438 | self.recognizer = _Recognizer(recognizer_config) | 468 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -454,6 +484,8 @@ class OnlineRecognizer(object): | @@ -454,6 +484,8 @@ class OnlineRecognizer(object): | ||
| 454 | decoding_method: str = "greedy_search", | 484 | decoding_method: str = "greedy_search", |
| 455 | provider: str = "cpu", | 485 | provider: str = "cpu", |
| 456 | debug: bool = False, | 486 | debug: bool = False, |
| 487 | + rule_fsts: str = "", | ||
| 488 | + rule_fars: str = "", | ||
| 457 | ): | 489 | ): |
| 458 | """ | 490 | """ |
| 459 | Please refer to | 491 | Please refer to |
| @@ -497,6 +529,12 @@ class OnlineRecognizer(object): | @@ -497,6 +529,12 @@ class OnlineRecognizer(object): | ||
| 497 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | 529 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. |
| 498 | debug: | 530 | debug: |
| 499 | True to show meta data in the model. | 531 | True to show meta data in the model. |
| 532 | + rule_fsts: | ||
| 533 | + If not empty, it specifies fsts for inverse text normalization. | ||
| 534 | + If there are multiple fsts, they are separated by a comma. | ||
| 535 | + rule_fars: | ||
| 536 | + If not empty, it specifies fst archives for inverse text normalization. | ||
| 537 | + If there are multiple archives, they are separated by a comma. | ||
| 500 | """ | 538 | """ |
| 501 | self = cls.__new__(cls) | 539 | self = cls.__new__(cls) |
| 502 | _assert_file_exists(tokens) | 540 | _assert_file_exists(tokens) |
| @@ -533,6 +571,8 @@ class OnlineRecognizer(object): | @@ -533,6 +571,8 @@ class OnlineRecognizer(object): | ||
| 533 | endpoint_config=endpoint_config, | 571 | endpoint_config=endpoint_config, |
| 534 | enable_endpoint=enable_endpoint_detection, | 572 | enable_endpoint=enable_endpoint_detection, |
| 535 | decoding_method=decoding_method, | 573 | decoding_method=decoding_method, |
| 574 | + rule_fsts=rule_fsts, | ||
| 575 | + rule_fars=rule_fars, | ||
| 536 | ) | 576 | ) |
| 537 | 577 | ||
| 538 | self.recognizer = _Recognizer(recognizer_config) | 578 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -556,6 +596,8 @@ class OnlineRecognizer(object): | @@ -556,6 +596,8 @@ class OnlineRecognizer(object): | ||
| 556 | decoding_method: str = "greedy_search", | 596 | decoding_method: str = "greedy_search", |
| 557 | provider: str = "cpu", | 597 | provider: str = "cpu", |
| 558 | debug: bool = False, | 598 | debug: bool = False, |
| 599 | + rule_fsts: str = "", | ||
| 600 | + rule_fars: str = "", | ||
| 559 | ): | 601 | ): |
| 560 | """ | 602 | """ |
| 561 | Please refer to | 603 | Please refer to |
| @@ -602,6 +644,12 @@ class OnlineRecognizer(object): | @@ -602,6 +644,12 @@ class OnlineRecognizer(object): | ||
| 602 | The only valid value is greedy_search. | 644 | The only valid value is greedy_search. |
| 603 | provider: | 645 | provider: |
| 604 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | 646 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. |
| 647 | + rule_fsts: | ||
| 648 | + If not empty, it specifies fsts for inverse text normalization. | ||
| 649 | + If there are multiple fsts, they are separated by a comma. | ||
| 650 | + rule_fars: | ||
| 651 | + If not empty, it specifies fst archives for inverse text normalization. | ||
| 652 | + If there are multiple archives, they are separated by a comma. | ||
| 605 | """ | 653 | """ |
| 606 | self = cls.__new__(cls) | 654 | self = cls.__new__(cls) |
| 607 | _assert_file_exists(tokens) | 655 | _assert_file_exists(tokens) |
| @@ -640,6 +688,8 @@ class OnlineRecognizer(object): | @@ -640,6 +688,8 @@ class OnlineRecognizer(object): | ||
| 640 | endpoint_config=endpoint_config, | 688 | endpoint_config=endpoint_config, |
| 641 | enable_endpoint=enable_endpoint_detection, | 689 | enable_endpoint=enable_endpoint_detection, |
| 642 | decoding_method=decoding_method, | 690 | decoding_method=decoding_method, |
| 691 | + rule_fsts=rule_fsts, | ||
| 692 | + rule_fars=rule_fars, | ||
| 643 | ) | 693 | ) |
| 644 | 694 | ||
| 645 | self.recognizer = _Recognizer(recognizer_config) | 695 | self.recognizer = _Recognizer(recognizer_config) |
-
请 注册 或 登录 后发表评论