Committed by GitHub

Add inverse text normalization for non-streaming ASR (#1017)

Showing 13 changed files with 380 additions and 19 deletions.

@@ -248,7 +248,7 @@ if [[ x$OS != x'windows-latest' ]]; then
  python3 ./python-api-examples/online-decode-files.py \
    --tokens=$repo/tokens.txt \
    --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
-   --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
+   --decoder=$repo/decoder-epoch-99-avg-1.onnx \
    --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
    $repo/test_wavs/0.wav \
    $repo/test_wavs/1.wav \
@@ -286,7 +286,7 @@ python3 ./python-api-examples/offline-decode-files.py \
python3 ./python-api-examples/offline-decode-files.py \
  --tokens=$repo/tokens.txt \
  --encoder=$repo/encoder-epoch-99-avg-1.int8.onnx \
- --decoder=$repo/decoder-epoch-99-avg-1.int8.onnx \
+ --decoder=$repo/decoder-epoch-99-avg-1.onnx \
  --joiner=$repo/joiner-epoch-99-avg-1.int8.onnx \
  $repo/test_wavs/0.wav \
  $repo/test_wavs/1.wav \
@@ -330,6 +330,15 @@ if [[ x$OS != x'windows-latest' ]]; then

  python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose

+ ln -s $repo $PWD/
+
+ curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
+ curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
+
+ python3 ./python-api-examples/inverse-text-normalization-offline-asr.py
+
+ rm -rfv sherpa-onnx-paraformer-zh-2023-03-28
+
  rm -rf $repo
fi

New file: python-api-examples/inverse-text-normalization-offline-asr.py

#!/usr/bin/env python3
#
# Copyright (c) 2024 Xiaomi Corporation

"""
This script shows how to use inverse text normalization with non-streaming ASR.

Usage:

(1) Download the test model

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2
rm sherpa-onnx-paraformer-zh-2023-03-28.tar.bz2

(2) Download rule fst

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst

Please refer to
https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb
for how itn_zh_number.fst is generated.

(3) Download test wave

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav

(4) Run this script

python3 ./python-api-examples/inverse-text-normalization-offline-asr.py
"""
from pathlib import Path

import sherpa_onnx
import soundfile as sf


def create_recognizer():
    model = "./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx"
    tokens = "./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt"
    rule_fsts = "./itn_zh_number.fst"

    if (
        not Path(model).is_file()
        or not Path(tokens).is_file()
        or not Path(rule_fsts).is_file()
    ):
        raise ValueError(
            """Please download model files from
            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
            """
        )
    return sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=model,
        tokens=tokens,
        debug=True,
        rule_fsts=rule_fsts,
    )


def main():
    recognizer = create_recognizer()
    wave_filename = "./itn-zh-number.wav"
    if not Path(wave_filename).is_file():
        raise ValueError(
            """Please download model files from
            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
            """
        )
    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, audio)
    recognizer.decode_stream(stream)
    print(wave_filename)
    print(stream.result)


if __name__ == "__main__":
    main()

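The example above exercises a single rule FST. As a minimal sketch of how more than one rule could be chained through the new rule_fsts argument (the config fields added later in this diff note that multiple comma-separated rules are applied from left to right), one could pass a comma-separated list; the second FST path below is a hypothetical placeholder and is not part of this commit:

import sherpa_onnx

# Minimal sketch, not from the commit: two ITN rule FSTs applied left to right.
# "./extra-itn-rule.fst" is a hypothetical second rule used only for illustration.
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
    paraformer="./sherpa-onnx-paraformer-zh-2023-03-28/model.int8.onnx",
    tokens="./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt",
    rule_fsts="./itn_zh_number.fst,./extra-itn-rule.fst",
)
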
@@ -73,7 +73,8 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(config),
+       config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(OfflineCtcModel::Create(config_.model_config)) {
    Init();
@@ -82,7 +83,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
  OfflineRecognizerCtcImpl(AAssetManager *mgr,
                           const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(mgr, config),
+       config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(OfflineCtcModel::Create(mgr, config_.model_config)) {
    Init();
@@ -205,6 +207,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
    for (int32_t i = 0; i != n; ++i) {
      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
                       model_->SubsamplingFactor());
+     r.text = ApplyInverseTextNormalization(std::move(r.text));
      ss[i]->SetResult(r);
    }
  }
@@ -238,6 +241,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {

    auto r = Convert(results[0], symbol_table_, frame_shift_ms,
                     model_->SubsamplingFactor());
+   r.text = ApplyInverseTextNormalization(std::move(r.text));
    s->SetResult(r);
  }

@@ -5,7 +5,18 @@
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"

#include <string>
+#include <utility>
+#include <vector>

+#if __ANDROID_API__ >= 9
+#include <strstream>
+
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "fst/extensions/far/far.h"
+#include "kaldifst/csrc/kaldi-fst-io.h"
#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
@@ -316,4 +327,111 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
}
#endif

+OfflineRecognizerImpl::OfflineRecognizerImpl(
+    const OfflineRecognizerConfig &config)
+    : config_(config) {
+  if (!config.rule_fsts.empty()) {
+    std::vector<std::string> files;
+    SplitStringToVector(config.rule_fsts, ",", false, &files);
+    itn_list_.reserve(files.size());
+    for (const auto &f : files) {
+      if (config.model_config.debug) {
+        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
+      }
+      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
+    }
+  }
+
+  if (!config.rule_fars.empty()) {
+    if (config.model_config.debug) {
+      SHERPA_ONNX_LOGE("Loading FST archives");
+    }
+    std::vector<std::string> files;
+    SplitStringToVector(config.rule_fars, ",", false, &files);
+
+    itn_list_.reserve(files.size() + itn_list_.size());
+
+    for (const auto &f : files) {
+      if (config.model_config.debug) {
+        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
+      }
+      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
+          fst::FarReader<fst::StdArc>::Open(f));
+      for (; !reader->Done(); reader->Next()) {
+        std::unique_ptr<fst::StdConstFst> r(
+            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
+
+        itn_list_.push_back(
+            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
+      }
+    }
+
+    if (config.model_config.debug) {
+      SHERPA_ONNX_LOGE("FST archives loaded!");
+    }
+  }
+}
+
+#if __ANDROID_API__ >= 9
+OfflineRecognizerImpl::OfflineRecognizerImpl(
+    AAssetManager *mgr, const OfflineRecognizerConfig &config)
+    : config_(config) {
+  if (!config.rule_fsts.empty()) {
+    std::vector<std::string> files;
+    SplitStringToVector(config.rule_fsts, ",", false, &files);
+    itn_list_.reserve(files.size());
+    for (const auto &f : files) {
+      if (config.model_config.debug) {
+        SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
+      }
+      auto buf = ReadFile(mgr, f);
+      std::istrstream is(buf.data(), buf.size());
+      itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
+    }
+  }
+
+  if (!config.rule_fars.empty()) {
+    std::vector<std::string> files;
+    SplitStringToVector(config.rule_fars, ",", false, &files);
+    itn_list_.reserve(files.size() + itn_list_.size());
+
+    for (const auto &f : files) {
+      if (config.model_config.debug) {
+        SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
+      }
+
+      auto buf = ReadFile(mgr, f);
+
+      std::unique_ptr<std::istream> s(
+          new std::istrstream(buf.data(), buf.size()));
+
+      std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
+          fst::FarReader<fst::StdArc>::Open(std::move(s)));
+
+      for (; !reader->Done(); reader->Next()) {
+        std::unique_ptr<fst::StdConstFst> r(
+            fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
+
+        itn_list_.push_back(
+            std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
+      }  // for (; !reader->Done(); reader->Next())
+    }  // for (const auto &f : files)
+  }  // if (!config.rule_fars.empty())
+}
+#endif
+
+std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
+    std::string text) const {
+  if (!itn_list_.empty()) {
+    for (const auto &tn : itn_list_) {
+      text = tn->Normalize(text);
+      if (config_.model_config.debug) {
+        SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str());
+      }
+    }
+  }
+
+  return text;
+}
+
}  // namespace sherpa_onnx

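In short, the new base-class constructor turns every file listed in rule_fsts, plus every FST contained in each archive listed in rule_fars, into a kaldifst::TextNormalizer, and ApplyInverseTextNormalization() runs the recognized text through them in order. An illustrative Python paraphrase of that post-processing step (not the sherpa-onnx API; the normalize() method below is hypothetical):

def apply_inverse_text_normalization(text, normalizers):
    # normalizers holds one entry per rule FST (those from rule_fsts first,
    # then the FSTs extracted from rule_fars), applied left to right,
    # mirroring the C++ loop above.
    for tn in normalizers:
        text = tn.normalize(text)  # hypothetical normalizer interface
    return text
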
@@ -14,6 +14,7 @@
#include "android/asset_manager_jni.h"
#endif

+#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
@@ -22,10 +23,15 @@ namespace sherpa_onnx {

class OfflineRecognizerImpl {
 public:
+  explicit OfflineRecognizerImpl(const OfflineRecognizerConfig &config);
+
  static std::unique_ptr<OfflineRecognizerImpl> Create(
      const OfflineRecognizerConfig &config);

#if __ANDROID_API__ >= 9
+  OfflineRecognizerImpl(AAssetManager *mgr,
+                        const OfflineRecognizerConfig &config);
+
  static std::unique_ptr<OfflineRecognizerImpl> Create(
      AAssetManager *mgr, const OfflineRecognizerConfig &config);
#endif
@@ -41,6 +47,15 @@ class OfflineRecognizerImpl {
  virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;

  virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
+
+  std::string ApplyInverseTextNormalization(std::string text) const;
+
+ private:
+  OfflineRecognizerConfig config_;
+  // for inverse text normalization. Used only if
+  // config.rule_fsts is not empty or
+  // config.rule_fars is not empty
+  std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
};

}  // namespace sherpa_onnx

@@ -89,7 +89,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerParaformerImpl(
      const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(config),
+       config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineParaformerModel>(config.model_config)) {
    if (config.decoding_method == "greedy_search") {
@@ -109,7 +110,8 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
  OfflineRecognizerParaformerImpl(AAssetManager *mgr,
                                  const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(mgr, config),
+       config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(std::make_unique<OfflineParaformerModel>(mgr,
                                                        config.model_config)) {
@@ -204,6 +206,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {

    for (int32_t i = 0; i != n; ++i) {
      auto r = Convert(results[i], symbol_table_);
+     r.text = ApplyInverseTextNormalization(std::move(r.text));
      ss[i]->SetResult(r);
    }
  }

@@ -74,7 +74,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerTransducerImpl(
      const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(config),
+       config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
    if (config_.decoding_method == "greedy_search") {
@@ -107,7 +108,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
  explicit OfflineRecognizerTransducerImpl(
      AAssetManager *mgr, const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(mgr, config),
+       config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(std::make_unique<OfflineTransducerModel>(mgr,
                                                        config_.model_config)) {
@@ -230,6 +232,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
    for (int32_t i = 0; i != n; ++i) {
      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
                       model_->SubsamplingFactor());
+     r.text = ApplyInverseTextNormalization(std::move(r.text));

      ss[i]->SetResult(r);
    }

@@ -41,7 +41,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerTransducerNeMoImpl(
      const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(config),
+       config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineTransducerNeMoModel>(
            config_.model_config)) {
@@ -59,7 +60,8 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
  explicit OfflineRecognizerTransducerNeMoImpl(
      AAssetManager *mgr, const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(mgr, config),
+       config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(std::make_unique<OfflineTransducerNeMoModel>(
            mgr, config_.model_config)) {
@@ -131,6 +133,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
    for (int32_t i = 0; i != n; ++i) {
      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
                       model_->SubsamplingFactor());
+     r.text = ApplyInverseTextNormalization(std::move(r.text));

      ss[i]->SetResult(r);
    }

@@ -52,7 +52,8 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(config),
+       config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineWhisperModel>(config.model_config)) {
    Init();
@@ -61,7 +62,8 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
#if __ANDROID_API__ >= 9
  OfflineRecognizerWhisperImpl(AAssetManager *mgr,
                               const OfflineRecognizerConfig &config)
-     : config_(config),
+     : OfflineRecognizerImpl(mgr, config),
+       config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(
            std::make_unique<OfflineWhisperModel>(mgr, config.model_config)) {
@@ -150,6 +152,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
                                      std::move(cross_kv.second));

      auto r = Convert(results[0], symbol_table_);
+     r.text = ApplyInverseTextNormalization(std::move(r.text));
      s->SetResult(r);
    } catch (const Ort::Exception &ex) {
      SHERPA_ONNX_LOGE(

@@ -10,7 +10,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
-
+#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {

void OfflineRecognizerConfig::Register(ParseOptions *po) {
@@ -44,6 +44,16 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
  po->Register("hotwords-score", &hotwords_score,
               "The bonus score for each token in context word/phrase. "
               "Used only when decoding_method is modified_beam_search");
+
+  po->Register(
+      "rule-fsts", &rule_fsts,
+      "If not empty, it specifies fsts for inverse text normalization. "
+      "If there are multiple fsts, they are separated by a comma.");
+
+  po->Register(
+      "rule-fars", &rule_fars,
+      "If not empty, it specifies fst archives for inverse text normalization. "
+      "If there are multiple archives, they are separated by a comma.");
}
@@ -61,7 +71,7 @@ bool OfflineRecognizerConfig::Validate() const {
  if (!hotwords_file.empty() && decoding_method != "modified_beam_search") {
    SHERPA_ONNX_LOGE(
        "Please use --decoding-method=modified_beam_search if you"
-       " provide --hotwords-file. Given --decoding-method=%s",
+       " provide --hotwords-file. Given --decoding-method='%s'",
        decoding_method.c_str());
    return false;
  }
@@ -72,6 +82,34 @@ bool OfflineRecognizerConfig::Validate() const {
    return false;
  }

+  if (!hotwords_file.empty() && !FileExists(hotwords_file)) {
+    SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist",
+                     hotwords_file.c_str());
+    return false;
+  }
+
+  if (!rule_fsts.empty()) {
+    std::vector<std::string> files;
+    SplitStringToVector(rule_fsts, ",", false, &files);
+    for (const auto &f : files) {
+      if (!FileExists(f)) {
+        SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
+        return false;
+      }
+    }
+  }
+
+  if (!rule_fars.empty()) {
+    std::vector<std::string> files;
+    SplitStringToVector(rule_fars, ",", false, &files);
+    for (const auto &f : files) {
+      if (!FileExists(f)) {
+        SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str());
+        return false;
+      }
+    }
+  }
+
  return model_config.Validate();
}
@@ -87,7 +125,9 @@ std::string OfflineRecognizerConfig::ToString() const {
  os << "max_active_paths=" << max_active_paths << ", ";
  os << "hotwords_file=\"" << hotwords_file << "\", ";
  os << "hotwords_score=" << hotwords_score << ", ";
- os << "blank_penalty=" << blank_penalty << ")";
+ os << "blank_penalty=" << blank_penalty << ", ";
+ os << "rule_fsts=\"" << rule_fsts << "\", ";
+ os << "rule_fars=\"" << rule_fars << "\")";

  return os.str();
}

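The Validate() additions above encode the convention shared by both new options: --rule-fsts and --rule-fars each take a comma-separated list of files, and every listed file must exist. A rough Python equivalent of that check, for illustration only (not part of the commit):

from pathlib import Path

def validate_rule_files(rule_fsts: str, rule_fars: str) -> bool:
    # Mirrors the C++ Validate() logic: both options are comma-separated lists
    # and every entry must point to an existing file.
    for option in (rule_fsts, rule_fars):
        if not option:
            continue
        for f in option.split(","):
            if not Path(f).is_file():
                print(f"'{f}' does not exist")
                return False
    return True
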
@@ -40,6 +40,12 @@ struct OfflineRecognizerConfig {

  float blank_penalty = 0.0;

+  // If there are multiple rules, they are applied from left to right.
+  std::string rule_fsts;
+
+  // If there are multiple FST archives, they are applied from left to right.
+  std::string rule_fars;
+
  // only greedy_search is implemented
  // TODO(fangjun): Implement modified_beam_search

@@ -50,7 +56,8 @@ struct OfflineRecognizerConfig {
      const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
      const std::string &decoding_method, int32_t max_active_paths,
      const std::string &hotwords_file, float hotwords_score,
-     float blank_penalty)
+     float blank_penalty, const std::string &rule_fsts,
+     const std::string &rule_fars)
      : feat_config(feat_config),
        model_config(model_config),
        lm_config(lm_config),
@@ -59,7 +66,9 @@ struct OfflineRecognizerConfig {
        max_active_paths(max_active_paths),
        hotwords_file(hotwords_file),
        hotwords_score(hotwords_score),
-       blank_penalty(blank_penalty) {}
+       blank_penalty(blank_penalty),
+       rule_fsts(rule_fsts),
+       rule_fars(rule_fars) {}

  void Register(ParseOptions *po);
  bool Validate() const;

@@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
      .def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
                    const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
                    const std::string &, int32_t, const std::string &, float,
-                   float>(),
+                   float, const std::string &, const std::string &>(),
           py::arg("feat_config"), py::arg("model_config"),
           py::arg("lm_config") = OfflineLMConfig(),
           py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
           py::arg("decoding_method") = "greedy_search",
           py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
-          py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0)
+          py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
+          py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
      .def_readwrite("feat_config", &PyClass::feat_config)
      .def_readwrite("model_config", &PyClass::model_config)
      .def_readwrite("lm_config", &PyClass::lm_config)
@@ -33,6 +34,8 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
      .def_readwrite("hotwords_file", &PyClass::hotwords_file)
      .def_readwrite("hotwords_score", &PyClass::hotwords_score)
      .def_readwrite("blank_penalty", &PyClass::blank_penalty)
+     .def_readwrite("rule_fsts", &PyClass::rule_fsts)
+     .def_readwrite("rule_fars", &PyClass::rule_fars)
      .def("__str__", &PyClass::ToString);
}

@@ -54,6 +54,8 @@ class OfflineRecognizer(object):
        debug: bool = False,
        provider: str = "cpu",
        model_type: str = "transducer",
+       rule_fsts: str = "",
+       rule_fars: str = "",
    ):
        """
        Please refer to
@@ -107,6 +109,12 @@ class OfflineRecognizer(object):
            True to show debug messages.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+         rule_fsts:
+           If not empty, it specifies fsts for inverse text normalization.
+           If there are multiple fsts, they are separated by a comma.
+         rule_fars:
+           If not empty, it specifies fst archives for inverse text normalization.
+           If there are multiple archives, they are separated by a comma.
        """
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
@@ -143,6 +151,8 @@ class OfflineRecognizer(object):
            hotwords_file=hotwords_file,
            hotwords_score=hotwords_score,
            blank_penalty=blank_penalty,
+           rule_fsts=rule_fsts,
+           rule_fars=rule_fars,
        )
        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
@@ -159,6 +169,8 @@ class OfflineRecognizer(object):
        decoding_method: str = "greedy_search",
        debug: bool = False,
        provider: str = "cpu",
+       rule_fsts: str = "",
+       rule_fars: str = "",
    ):
        """
        Please refer to
@@ -186,6 +198,12 @@ class OfflineRecognizer(object):
            True to show debug messages.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+         rule_fsts:
+           If not empty, it specifies fsts for inverse text normalization.
+           If there are multiple fsts, they are separated by a comma.
+         rule_fars:
+           If not empty, it specifies fst archives for inverse text normalization.
+           If there are multiple archives, they are separated by a comma.
        """
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
@@ -206,6 +224,8 @@ class OfflineRecognizer(object):
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
+           rule_fsts=rule_fsts,
+           rule_fars=rule_fars,
        )
        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
@@ -222,6 +242,8 @@ class OfflineRecognizer(object):
        decoding_method: str = "greedy_search",
        debug: bool = False,
        provider: str = "cpu",
+       rule_fsts: str = "",
+       rule_fars: str = "",
    ):
        """
        Please refer to
@@ -251,6 +273,12 @@ class OfflineRecognizer(object):
            True to show debug messages.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+         rule_fsts:
+           If not empty, it specifies fsts for inverse text normalization.
+           If there are multiple fsts, they are separated by a comma.
+         rule_fars:
+           If not empty, it specifies fst archives for inverse text normalization.
+           If there are multiple archives, they are separated by a comma.
        """
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
@@ -271,6 +299,8 @@ class OfflineRecognizer(object):
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
+           rule_fsts=rule_fsts,
+           rule_fars=rule_fars,
        )
        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
@@ -287,6 +317,8 @@ class OfflineRecognizer(object):
        decoding_method: str = "greedy_search",
        debug: bool = False,
        provider: str = "cpu",
+       rule_fsts: str = "",
+       rule_fars: str = "",
    ):
        """
        Please refer to
@@ -315,6 +347,12 @@ class OfflineRecognizer(object):
            True to show debug messages.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+         rule_fsts:
+           If not empty, it specifies fsts for inverse text normalization.
+           If there are multiple fsts, they are separated by a comma.
+         rule_fars:
+           If not empty, it specifies fst archives for inverse text normalization.
+           If there are multiple archives, they are separated by a comma.
        """
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
@@ -335,6 +373,8 @@ class OfflineRecognizer(object):
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
+           rule_fsts=rule_fsts,
+           rule_fars=rule_fars,
        )
        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
@@ -353,6 +393,8 @@ class OfflineRecognizer(object):
        debug: bool = False,
        provider: str = "cpu",
        tail_paddings: int = -1,
+       rule_fsts: str = "",
+       rule_fars: str = "",
    ):
        """
        Please refer to
@@ -389,6 +431,12 @@ class OfflineRecognizer(object):
            True to show debug messages.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+         rule_fsts:
+           If not empty, it specifies fsts for inverse text normalization.
+           If there are multiple fsts, they are separated by a comma.
+         rule_fars:
+           If not empty, it specifies fst archives for inverse text normalization.
+           If there are multiple archives, they are separated by a comma.
        """
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
@@ -415,6 +463,8 @@ class OfflineRecognizer(object):
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
+           rule_fsts=rule_fsts,
+           rule_fars=rule_fars,
        )
        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
@@ -431,6 +481,8 @@ class OfflineRecognizer(object):
        decoding_method: str = "greedy_search",
        debug: bool = False,
        provider: str = "cpu",
+       rule_fsts: str = "",
+       rule_fars: str = "",
    ):
        """
        Please refer to
@@ -458,6 +510,12 @@ class OfflineRecognizer(object):
            True to show debug messages.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+         rule_fsts:
+           If not empty, it specifies fsts for inverse text normalization.
+           If there are multiple fsts, they are separated by a comma.
+         rule_fars:
+           If not empty, it specifies fst archives for inverse text normalization.
+           If there are multiple archives, they are separated by a comma.
        """
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
@@ -478,6 +536,8 @@ class OfflineRecognizer(object):
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
+           rule_fsts=rule_fsts,
+           rule_fars=rule_fars,
        )
        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
@@ -494,6 +554,8 @@ class OfflineRecognizer(object):
        decoding_method: str = "greedy_search",
        debug: bool = False,
        provider: str = "cpu",
+       rule_fsts: str = "",
+       rule_fars: str = "",
    ):
        """
        Please refer to
@@ -522,6 +584,12 @@ class OfflineRecognizer(object):
            True to show debug messages.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+         rule_fsts:
+           If not empty, it specifies fsts for inverse text normalization.
+           If there are multiple fsts, they are separated by a comma.
+         rule_fars:
+           If not empty, it specifies fst archives for inverse text normalization.
+           If there are multiple archives, they are separated by a comma.
        """
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
@@ -542,6 +610,8 @@ class OfflineRecognizer(object):
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
+           rule_fsts=rule_fsts,
+           rule_fars=rule_fars,
        )
        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config