Fangjun Kuang
Committed by GitHub

Add inverse text normalization for online ASR (#1020)

@@ -256,7 +256,18 @@ if [[ x$OS != x'windows-latest' ]]; then @@ -256,7 +256,18 @@ if [[ x$OS != x'windows-latest' ]]; then
256 $repo/test_wavs/3.wav \ 256 $repo/test_wavs/3.wav \
257 $repo/test_wavs/8k.wav 257 $repo/test_wavs/8k.wav
258 258
  259 + ln -s $repo $PWD/
  260 +
  261 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
  262 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
  263 +
  264 + python3 ./python-api-examples/inverse-text-normalization-online-asr.py
  265 +
259 python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose 266 python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose
  267 +
  268 + rm -rfv sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
  269 +
  270 + rm -rf $repo
260 fi 271 fi
261 272
262 log "Test non-streaming transducer models" 273 log "Test non-streaming transducer models"
  1 +#!/usr/bin/env python3
  2 +#
  3 +# Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +"""
  6 +This script shows how to use inverse text normalization with streaming ASR.
  7 +
  8 +Usage:
  9 +
  10 +(1) Download the test model
  11 +
  12 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2
  13 +
  14 +(2) Download rule fst
  15 +
  16 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst
  17 +
  18 +Please refer to
  19 +https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb
  20 +for how itn_zh_number.fst is generated.
  21 +
  22 +(3) Download test wave
  23 +
  24 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav
  25 +
  26 +(4) Run this script
  27 +
  28 +python3 ./python-api-examples/inverse-text-normalization-online-asr.py
  29 +"""
  30 +from pathlib import Path
  31 +
  32 +import sherpa_onnx
  33 +import soundfile as sf
  34 +
  35 +
  36 +def create_recognizer():
  37 + encoder = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx"
  38 + decoder = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx"
  39 + joiner = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx"
  40 + tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"
  41 + rule_fsts = "./itn_zh_number.fst"
  42 +
  43 + if (
  44 + not Path(encoder).is_file()
  45 + or not Path(decoder).is_file()
  46 + or not Path(joiner).is_file()
  47 + or not Path(tokens).is_file()
  48 + or not Path(rule_fsts).is_file()
  49 + ):
  50 + raise ValueError(
  51 + """Please download model files from
  52 + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  53 + """
  54 + )
  55 + return sherpa_onnx.OnlineRecognizer.from_transducer(
  56 + encoder=encoder,
  57 + decoder=decoder,
  58 + joiner=joiner,
  59 + tokens=tokens,
  60 + debug=True,
  61 + rule_fsts=rule_fsts,
  62 + )
  63 +
  64 +
  65 +def main():
  66 + recognizer = create_recognizer()
  67 + wave_filename = "./itn-zh-number.wav"
  68 + if not Path(wave_filename).is_file():
  69 + raise ValueError(
  70 + """Please download model files from
  71 + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
  72 + """
  73 + )
  74 + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
  75 + audio = audio[:, 0] # only use the first channel
  76 +
  77 + stream = recognizer.create_stream()
  78 + stream.accept_waveform(sample_rate, audio)
  79 +
  80 + tail_padding = [0] * int(0.3 * sample_rate)
  81 + stream.accept_waveform(sample_rate, tail_padding)
  82 +
  83 + while recognizer.is_ready(stream):
  84 + recognizer.decode_stream(stream)
  85 +
  86 + print(wave_filename)
  87 + print(recognizer.get_result_all(stream))
  88 +
  89 +
  90 +if __name__ == "__main__":
  91 + main()
@@ -68,7 +68,8 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, @@ -68,7 +68,8 @@ static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
68 class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { 68 class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
69 public: 69 public:
70 explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config) 70 explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config)
71 - : config_(config), 71 + : OnlineRecognizerImpl(config),
  72 + config_(config),
72 model_(OnlineCtcModel::Create(config.model_config)), 73 model_(OnlineCtcModel::Create(config.model_config)),
73 sym_(config.model_config.tokens), 74 sym_(config.model_config.tokens),
74 endpoint_(config_.endpoint_config) { 75 endpoint_(config_.endpoint_config) {
@@ -84,7 +85,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -84,7 +85,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
84 #if __ANDROID_API__ >= 9 85 #if __ANDROID_API__ >= 9
85 explicit OnlineRecognizerCtcImpl(AAssetManager *mgr, 86 explicit OnlineRecognizerCtcImpl(AAssetManager *mgr,
86 const OnlineRecognizerConfig &config) 87 const OnlineRecognizerConfig &config)
87 - : config_(config), 88 + : OnlineRecognizerImpl(mgr, config),
  89 + config_(config),
88 model_(OnlineCtcModel::Create(mgr, config.model_config)), 90 model_(OnlineCtcModel::Create(mgr, config.model_config)),
89 sym_(mgr, config.model_config.tokens), 91 sym_(mgr, config.model_config.tokens),
90 endpoint_(config_.endpoint_config) { 92 endpoint_(config_.endpoint_config) {
@@ -182,8 +184,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -182,8 +184,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
182 // TODO(fangjun): Remember to change these constants if needed 184 // TODO(fangjun): Remember to change these constants if needed
183 int32_t frame_shift_ms = 10; 185 int32_t frame_shift_ms = 10;
184 int32_t subsampling_factor = 4; 186 int32_t subsampling_factor = 4;
185 - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, 187 + auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
186 s->GetCurrentSegment(), s->GetNumFramesSinceStart()); 188 s->GetCurrentSegment(), s->GetNumFramesSinceStart());
  189 + r.text = ApplyInverseTextNormalization(r.text);
  190 + return r;
187 } 191 }
188 192
189 bool IsEndpoint(OnlineStream *s) const override { 193 bool IsEndpoint(OnlineStream *s) const override {
@@ -4,11 +4,22 @@ @@ -4,11 +4,22 @@
4 4
5 #include "sherpa-onnx/csrc/online-recognizer-impl.h" 5 #include "sherpa-onnx/csrc/online-recognizer-impl.h"
6 6
  7 +#if __ANDROID_API__ >= 9
  8 +#include <strstream>
  9 +
  10 +#include "android/asset_manager.h"
  11 +#include "android/asset_manager_jni.h"
  12 +#endif
  13 +
  14 +#include "fst/extensions/far/far.h"
  15 +#include "kaldifst/csrc/kaldi-fst-io.h"
  16 +#include "sherpa-onnx/csrc/macros.h"
7 #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h" 17 #include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
8 #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" 18 #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
9 #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" 19 #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
10 #include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h" 20 #include "sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h"
11 #include "sherpa-onnx/csrc/onnx-utils.h" 21 #include "sherpa-onnx/csrc/onnx-utils.h"
  22 +#include "sherpa-onnx/csrc/text-utils.h"
12 23
13 namespace sherpa_onnx { 24 namespace sherpa_onnx {
14 25
@@ -78,4 +89,110 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( @@ -78,4 +89,110 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
78 } 89 }
79 #endif 90 #endif
80 91
  92 +OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
  93 + : config_(config) {
  94 + if (!config.rule_fsts.empty()) {
  95 + std::vector<std::string> files;
  96 + SplitStringToVector(config.rule_fsts, ",", false, &files);
  97 + itn_list_.reserve(files.size());
  98 + for (const auto &f : files) {
  99 + if (config.model_config.debug) {
  100 + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
  101 + }
  102 + itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
  103 + }
  104 + }
  105 +
  106 + if (!config.rule_fars.empty()) {
  107 + if (config.model_config.debug) {
  108 + SHERPA_ONNX_LOGE("Loading FST archives");
  109 + }
  110 + std::vector<std::string> files;
  111 + SplitStringToVector(config.rule_fars, ",", false, &files);
  112 +
  113 + itn_list_.reserve(files.size() + itn_list_.size());
  114 +
  115 + for (const auto &f : files) {
  116 + if (config.model_config.debug) {
  117 + SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
  118 + }
  119 + std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
  120 + fst::FarReader<fst::StdArc>::Open(f));
  121 + for (; !reader->Done(); reader->Next()) {
  122 + std::unique_ptr<fst::StdConstFst> r(
  123 + fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
  124 +
  125 + itn_list_.push_back(
  126 + std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
  127 + }
  128 + }
  129 +
  130 + if (config.model_config.debug) {
  131 + SHERPA_ONNX_LOGE("FST archives loaded!");
  132 + }
  133 + }
  134 +}
  135 +
  136 +#if __ANDROID_API__ >= 9
  137 +OnlineRecognizerImpl::OnlineRecognizerImpl(AAssetManager *mgr,
  138 + const OnlineRecognizerConfig &config)
  139 + : config_(config) {
  140 + if (!config.rule_fsts.empty()) {
  141 + std::vector<std::string> files;
  142 + SplitStringToVector(config.rule_fsts, ",", false, &files);
  143 + itn_list_.reserve(files.size());
  144 + for (const auto &f : files) {
  145 + if (config.model_config.debug) {
  146 + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
  147 + }
  148 + auto buf = ReadFile(mgr, f);
  149 + std::istrstream is(buf.data(), buf.size());
  150 + itn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
  151 + }
  152 + }
  153 +
  154 + if (!config.rule_fars.empty()) {
  155 + std::vector<std::string> files;
  156 + SplitStringToVector(config.rule_fars, ",", false, &files);
  157 + itn_list_.reserve(files.size() + itn_list_.size());
  158 +
  159 + for (const auto &f : files) {
  160 + if (config.model_config.debug) {
  161 + SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
  162 + }
  163 +
  164 + auto buf = ReadFile(mgr, f);
  165 +
  166 + std::unique_ptr<std::istream> s(
  167 + new std::istrstream(buf.data(), buf.size()));
  168 +
  169 + std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
  170 + fst::FarReader<fst::StdArc>::Open(std::move(s)));
  171 +
  172 + for (; !reader->Done(); reader->Next()) {
  173 + std::unique_ptr<fst::StdConstFst> r(
  174 + fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
  175 +
  176 + itn_list_.push_back(
  177 + std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
  178 + } // for (; !reader->Done(); reader->Next())
  179 + } // for (const auto &f : files)
  180 + } // if (!config.rule_fars.empty())
  181 +}
  182 +#endif
  183 +
  184 +std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
  185 + std::string text) const {
  186 + if (!itn_list_.empty()) {
  187 + for (const auto &tn : itn_list_) {
  188 + text = tn->Normalize(text);
  189 + if (config_.model_config.debug) {
  190 + SHERPA_ONNX_LOGE("After inverse text normalization: %s", text.c_str());
  191 + }
  192 + }
  193 + }
  194 +
  195 + return text;
  196 +}
  197 +
81 } // namespace sherpa_onnx 198 } // namespace sherpa_onnx
@@ -9,6 +9,12 @@ @@ -9,6 +9,12 @@
9 #include <string> 9 #include <string>
10 #include <vector> 10 #include <vector>
11 11
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
  17 +#include "kaldifst/csrc/text-normalizer.h"
12 #include "sherpa-onnx/csrc/macros.h" 18 #include "sherpa-onnx/csrc/macros.h"
13 #include "sherpa-onnx/csrc/online-recognizer.h" 19 #include "sherpa-onnx/csrc/online-recognizer.h"
14 #include "sherpa-onnx/csrc/online-stream.h" 20 #include "sherpa-onnx/csrc/online-stream.h"
@@ -17,10 +23,15 @@ namespace sherpa_onnx { @@ -17,10 +23,15 @@ namespace sherpa_onnx {
17 23
18 class OnlineRecognizerImpl { 24 class OnlineRecognizerImpl {
19 public: 25 public:
  26 + explicit OnlineRecognizerImpl(const OnlineRecognizerConfig &config);
  27 +
20 static std::unique_ptr<OnlineRecognizerImpl> Create( 28 static std::unique_ptr<OnlineRecognizerImpl> Create(
21 const OnlineRecognizerConfig &config); 29 const OnlineRecognizerConfig &config);
22 30
23 #if __ANDROID_API__ >= 9 31 #if __ANDROID_API__ >= 9
  32 + OnlineRecognizerImpl(AAssetManager *mgr,
  33 + const OnlineRecognizerConfig &config);
  34 +
24 static std::unique_ptr<OnlineRecognizerImpl> Create( 35 static std::unique_ptr<OnlineRecognizerImpl> Create(
25 AAssetManager *mgr, const OnlineRecognizerConfig &config); 36 AAssetManager *mgr, const OnlineRecognizerConfig &config);
26 #endif 37 #endif
@@ -50,6 +61,15 @@ class OnlineRecognizerImpl { @@ -50,6 +61,15 @@ class OnlineRecognizerImpl {
50 virtual bool IsEndpoint(OnlineStream *s) const = 0; 61 virtual bool IsEndpoint(OnlineStream *s) const = 0;
51 62
52 virtual void Reset(OnlineStream *s) const = 0; 63 virtual void Reset(OnlineStream *s) const = 0;
  64 +
  65 + std::string ApplyInverseTextNormalization(std::string text) const;
  66 +
  67 + private:
  68 + OnlineRecognizerConfig config_;
  69 + // for inverse text normalization. Used only if
  70 + // config.rule_fsts is not empty or
  71 + // config.rule_fars is not empty
  72 + std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
53 }; 73 };
54 74
55 } // namespace sherpa_onnx 75 } // namespace sherpa_onnx
@@ -96,7 +96,8 @@ static void Scale(const float *x, int32_t n, float scale, float *y) { @@ -96,7 +96,8 @@ static void Scale(const float *x, int32_t n, float scale, float *y) {
96 class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { 96 class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
97 public: 97 public:
98 explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config) 98 explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config)
99 - : config_(config), 99 + : OnlineRecognizerImpl(config),
  100 + config_(config),
100 model_(config.model_config), 101 model_(config.model_config),
101 sym_(config.model_config.tokens), 102 sym_(config.model_config.tokens),
102 endpoint_(config_.endpoint_config) { 103 endpoint_(config_.endpoint_config) {
@@ -116,7 +117,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { @@ -116,7 +117,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
116 #if __ANDROID_API__ >= 9 117 #if __ANDROID_API__ >= 9
117 explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr, 118 explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr,
118 const OnlineRecognizerConfig &config) 119 const OnlineRecognizerConfig &config)
119 - : config_(config), 120 + : OnlineRecognizerImpl(mgr, config),
  121 + config_(config),
120 model_(mgr, config.model_config), 122 model_(mgr, config.model_config),
121 sym_(mgr, config.model_config.tokens), 123 sym_(mgr, config.model_config.tokens),
122 endpoint_(config_.endpoint_config) { 124 endpoint_(config_.endpoint_config) {
@@ -160,7 +162,9 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { @@ -160,7 +162,9 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
160 OnlineRecognizerResult GetResult(OnlineStream *s) const override { 162 OnlineRecognizerResult GetResult(OnlineStream *s) const override {
161 auto decoder_result = s->GetParaformerResult(); 163 auto decoder_result = s->GetParaformerResult();
162 164
163 - return Convert(decoder_result, sym_); 165 + auto r = Convert(decoder_result, sym_);
  166 + r.text = ApplyInverseTextNormalization(r.text);
  167 + return r;
164 } 168 }
165 169
166 bool IsEndpoint(OnlineStream *s) const override { 170 bool IsEndpoint(OnlineStream *s) const override {
@@ -80,7 +80,8 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, @@ -80,7 +80,8 @@ OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
80 class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { 80 class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
81 public: 81 public:
82 explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config) 82 explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config)
83 - : config_(config), 83 + : OnlineRecognizerImpl(config),
  84 + config_(config),
84 model_(OnlineTransducerModel::Create(config.model_config)), 85 model_(OnlineTransducerModel::Create(config.model_config)),
85 sym_(config.model_config.tokens), 86 sym_(config.model_config.tokens),
86 endpoint_(config_.endpoint_config) { 87 endpoint_(config_.endpoint_config) {
@@ -124,7 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -124,7 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
124 #if __ANDROID_API__ >= 9 125 #if __ANDROID_API__ >= 9
125 explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr, 126 explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr,
126 const OnlineRecognizerConfig &config) 127 const OnlineRecognizerConfig &config)
127 - : config_(config), 128 + : OnlineRecognizerImpl(mgr, config),
  129 + config_(config),
128 model_(OnlineTransducerModel::Create(mgr, config.model_config)), 130 model_(OnlineTransducerModel::Create(mgr, config.model_config)),
129 sym_(mgr, config.model_config.tokens), 131 sym_(mgr, config.model_config.tokens),
130 endpoint_(config_.endpoint_config) { 132 endpoint_(config_.endpoint_config) {
@@ -332,8 +334,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -332,8 +334,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
332 // TODO(fangjun): Remember to change these constants if needed 334 // TODO(fangjun): Remember to change these constants if needed
333 int32_t frame_shift_ms = 10; 335 int32_t frame_shift_ms = 10;
334 int32_t subsampling_factor = 4; 336 int32_t subsampling_factor = 4;
335 - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, 337 + auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
336 s->GetCurrentSegment(), s->GetNumFramesSinceStart()); 338 s->GetCurrentSegment(), s->GetNumFramesSinceStart());
  339 + r.text = ApplyInverseTextNormalization(std::move(r.text));
  340 + return r;
337 } 341 }
338 342
339 bool IsEndpoint(OnlineStream *s) const override { 343 bool IsEndpoint(OnlineStream *s) const override {
@@ -42,7 +42,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -42,7 +42,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
42 public: 42 public:
43 explicit OnlineRecognizerTransducerNeMoImpl( 43 explicit OnlineRecognizerTransducerNeMoImpl(
44 const OnlineRecognizerConfig &config) 44 const OnlineRecognizerConfig &config)
45 - : config_(config), 45 + : OnlineRecognizerImpl(config),
  46 + config_(config),
46 symbol_table_(config.model_config.tokens), 47 symbol_table_(config.model_config.tokens),
47 endpoint_(config_.endpoint_config), 48 endpoint_(config_.endpoint_config),
48 model_( 49 model_(
@@ -61,7 +62,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -61,7 +62,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
61 #if __ANDROID_API__ >= 9 62 #if __ANDROID_API__ >= 9
62 explicit OnlineRecognizerTransducerNeMoImpl( 63 explicit OnlineRecognizerTransducerNeMoImpl(
63 AAssetManager *mgr, const OnlineRecognizerConfig &config) 64 AAssetManager *mgr, const OnlineRecognizerConfig &config)
64 - : config_(config), 65 + : OnlineRecognizerImpl(mgr, config),
  66 + config_(config),
65 symbol_table_(mgr, config.model_config.tokens), 67 symbol_table_(mgr, config.model_config.tokens),
66 endpoint_(config_.endpoint_config), 68 endpoint_(config_.endpoint_config),
67 model_(std::make_unique<OnlineTransducerNeMoModel>( 69 model_(std::make_unique<OnlineTransducerNeMoModel>(
@@ -94,9 +96,11 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { @@ -94,9 +96,11 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
94 // TODO(fangjun): Remember to change these constants if needed 96 // TODO(fangjun): Remember to change these constants if needed
95 int32_t frame_shift_ms = 10; 97 int32_t frame_shift_ms = 10;
96 int32_t subsampling_factor = model_->SubsamplingFactor(); 98 int32_t subsampling_factor = model_->SubsamplingFactor();
97 - return Convert(s->GetResult(), symbol_table_, frame_shift_ms, 99 + auto r = Convert(s->GetResult(), symbol_table_, frame_shift_ms,
98 subsampling_factor, s->GetCurrentSegment(), 100 subsampling_factor, s->GetCurrentSegment(),
99 s->GetNumFramesSinceStart()); 101 s->GetNumFramesSinceStart());
  102 + r.text = ApplyInverseTextNormalization(std::move(r.text));
  103 + return r;
100 } 104 }
101 105
102 bool IsEndpoint(OnlineStream *s) const override { 106 bool IsEndpoint(OnlineStream *s) const override {
@@ -14,7 +14,9 @@ @@ -14,7 +14,9 @@
14 #include <utility> 14 #include <utility>
15 #include <vector> 15 #include <vector>
16 16
  17 +#include "sherpa-onnx/csrc/file-utils.h"
17 #include "sherpa-onnx/csrc/online-recognizer-impl.h" 18 #include "sherpa-onnx/csrc/online-recognizer-impl.h"
  19 +#include "sherpa-onnx/csrc/text-utils.h"
18 20
19 namespace sherpa_onnx { 21 namespace sherpa_onnx {
20 22
@@ -100,6 +102,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -100,6 +102,15 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
100 "now support greedy_search and modified_beam_search."); 102 "now support greedy_search and modified_beam_search.");
101 po->Register("temperature-scale", &temperature_scale, 103 po->Register("temperature-scale", &temperature_scale,
102 "Temperature scale for confidence computation in decoding."); 104 "Temperature scale for confidence computation in decoding.");
  105 + po->Register(
  106 + "rule-fsts", &rule_fsts,
  107 + "If not empty, it specifies fsts for inverse text normalization. "
  108 + "If there are multiple fsts, they are separated by a comma.");
  109 +
  110 + po->Register(
  111 + "rule-fars", &rule_fars,
  112 + "If not empty, it specifies fst archives for inverse text normalization. "
  113 + "If there are multiple archives, they are separated by a comma.");
103 } 114 }
104 115
105 bool OnlineRecognizerConfig::Validate() const { 116 bool OnlineRecognizerConfig::Validate() const {
@@ -129,6 +140,34 @@ bool OnlineRecognizerConfig::Validate() const { @@ -129,6 +140,34 @@ bool OnlineRecognizerConfig::Validate() const {
129 return false; 140 return false;
130 } 141 }
131 142
  143 + if (!hotwords_file.empty() && !FileExists(hotwords_file)) {
  144 + SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist",
  145 + hotwords_file.c_str());
  146 + return false;
  147 + }
  148 +
  149 + if (!rule_fsts.empty()) {
  150 + std::vector<std::string> files;
  151 + SplitStringToVector(rule_fsts, ",", false, &files);
  152 + for (const auto &f : files) {
  153 + if (!FileExists(f)) {
  154 + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
  155 + return false;
  156 + }
  157 + }
  158 + }
  159 +
  160 + if (!rule_fars.empty()) {
  161 + std::vector<std::string> files;
  162 + SplitStringToVector(rule_fars, ",", false, &files);
  163 + for (const auto &f : files) {
  164 + if (!FileExists(f)) {
  165 + SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str());
  166 + return false;
  167 + }
  168 + }
  169 + }
  170 +
132 return model_config.Validate(); 171 return model_config.Validate();
133 } 172 }
134 173
@@ -147,7 +186,9 @@ std::string OnlineRecognizerConfig::ToString() const { @@ -147,7 +186,9 @@ std::string OnlineRecognizerConfig::ToString() const {
147 os << "hotwords_file=\"" << hotwords_file << "\", "; 186 os << "hotwords_file=\"" << hotwords_file << "\", ";
148 os << "decoding_method=\"" << decoding_method << "\", "; 187 os << "decoding_method=\"" << decoding_method << "\", ";
149 os << "blank_penalty=" << blank_penalty << ", "; 188 os << "blank_penalty=" << blank_penalty << ", ";
150 - os << "temperature_scale=" << temperature_scale << ")"; 189 + os << "temperature_scale=" << temperature_scale << ", ";
  190 + os << "rule_fsts=\"" << rule_fsts << "\", ";
  191 + os << "rule_fars=\"" << rule_fars << "\")";
151 192
152 return os.str(); 193 return os.str();
153 } 194 }
@@ -100,6 +100,12 @@ struct OnlineRecognizerConfig { @@ -100,6 +100,12 @@ struct OnlineRecognizerConfig {
100 100
101 float temperature_scale = 2.0; 101 float temperature_scale = 2.0;
102 102
  103 + // If there are multiple rules, they are applied from left to right.
  104 + std::string rule_fsts;
  105 +
  106 + // If there are multiple FST archives, they are applied from left to right.
  107 + std::string rule_fars;
  108 +
103 OnlineRecognizerConfig() = default; 109 OnlineRecognizerConfig() = default;
104 110
105 OnlineRecognizerConfig( 111 OnlineRecognizerConfig(
@@ -109,7 +115,8 @@ struct OnlineRecognizerConfig { @@ -109,7 +115,8 @@ struct OnlineRecognizerConfig {
109 const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, 115 const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
110 bool enable_endpoint, const std::string &decoding_method, 116 bool enable_endpoint, const std::string &decoding_method,
111 int32_t max_active_paths, const std::string &hotwords_file, 117 int32_t max_active_paths, const std::string &hotwords_file,
112 - float hotwords_score, float blank_penalty, float temperature_scale) 118 + float hotwords_score, float blank_penalty, float temperature_scale,
  119 + const std::string &rule_fsts, const std::string &rule_fars)
113 : feat_config(feat_config), 120 : feat_config(feat_config),
114 model_config(model_config), 121 model_config(model_config),
115 lm_config(lm_config), 122 lm_config(lm_config),
@@ -121,7 +128,9 @@ struct OnlineRecognizerConfig { @@ -121,7 +128,9 @@ struct OnlineRecognizerConfig {
121 hotwords_file(hotwords_file), 128 hotwords_file(hotwords_file),
122 hotwords_score(hotwords_score), 129 hotwords_score(hotwords_score),
123 blank_penalty(blank_penalty), 130 blank_penalty(blank_penalty),
124 - temperature_scale(temperature_scale) {} 131 + temperature_scale(temperature_scale),
  132 + rule_fsts(rule_fsts),
  133 + rule_fars(rule_fars) {}
125 134
126 void Register(ParseOptions *po); 135 void Register(ParseOptions *po);
127 bool Validate() const; 136 bool Validate() const;
@@ -54,11 +54,11 @@ static void PybindOnlineRecognizerResult(py::module *m) { @@ -54,11 +54,11 @@ static void PybindOnlineRecognizerResult(py::module *m) {
54 static void PybindOnlineRecognizerConfig(py::module *m) { 54 static void PybindOnlineRecognizerConfig(py::module *m) {
55 using PyClass = OnlineRecognizerConfig; 55 using PyClass = OnlineRecognizerConfig;
56 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 56 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
57 - .def(  
58 - py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, 57 + .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
59 const OnlineLMConfig &, const EndpointConfig &, 58 const OnlineLMConfig &, const EndpointConfig &,
60 - const OnlineCtcFstDecoderConfig &, bool, const std::string &,  
61 - int32_t, const std::string &, float, float, float>(), 59 + const OnlineCtcFstDecoderConfig &, bool,
  60 + const std::string &, int32_t, const std::string &, float,
  61 + float, float, const std::string &, const std::string &>(),
62 py::arg("feat_config"), py::arg("model_config"), 62 py::arg("feat_config"), py::arg("model_config"),
63 py::arg("lm_config") = OnlineLMConfig(), 63 py::arg("lm_config") = OnlineLMConfig(),
64 py::arg("endpoint_config") = EndpointConfig(), 64 py::arg("endpoint_config") = EndpointConfig(),
@@ -66,7 +66,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -66,7 +66,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
66 py::arg("enable_endpoint"), py::arg("decoding_method"), 66 py::arg("enable_endpoint"), py::arg("decoding_method"),
67 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", 67 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
68 py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, 68 py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
69 - py::arg("temperature_scale") = 2.0) 69 + py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
  70 + py::arg("rule_fars") = "")
70 .def_readwrite("feat_config", &PyClass::feat_config) 71 .def_readwrite("feat_config", &PyClass::feat_config)
71 .def_readwrite("model_config", &PyClass::model_config) 72 .def_readwrite("model_config", &PyClass::model_config)
72 .def_readwrite("lm_config", &PyClass::lm_config) 73 .def_readwrite("lm_config", &PyClass::lm_config)
@@ -79,6 +80,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -79,6 +80,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
79 .def_readwrite("hotwords_score", &PyClass::hotwords_score) 80 .def_readwrite("hotwords_score", &PyClass::hotwords_score)
80 .def_readwrite("blank_penalty", &PyClass::blank_penalty) 81 .def_readwrite("blank_penalty", &PyClass::blank_penalty)
81 .def_readwrite("temperature_scale", &PyClass::temperature_scale) 82 .def_readwrite("temperature_scale", &PyClass::temperature_scale)
  83 + .def_readwrite("rule_fsts", &PyClass::rule_fsts)
  84 + .def_readwrite("rule_fars", &PyClass::rule_fars)
82 .def("__str__", &PyClass::ToString); 85 .def("__str__", &PyClass::ToString);
83 } 86 }
84 87
@@ -64,6 +64,8 @@ class OnlineRecognizer(object): @@ -64,6 +64,8 @@ class OnlineRecognizer(object):
64 lm_scale: float = 0.1, 64 lm_scale: float = 0.1,
65 temperature_scale: float = 2.0, 65 temperature_scale: float = 2.0,
66 debug: bool = False, 66 debug: bool = False,
  67 + rule_fsts: str = "",
  68 + rule_fars: str = "",
67 ): 69 ):
68 """ 70 """
69 Please refer to 71 Please refer to
@@ -148,6 +150,12 @@ class OnlineRecognizer(object): @@ -148,6 +150,12 @@ class OnlineRecognizer(object):
148 the log probability, you can get it from the directory where 150 the log probability, you can get it from the directory where
149 your bpe model is generated. Only used when hotwords provided 151 your bpe model is generated. Only used when hotwords provided
150 and the modeling unit is bpe or cjkchar+bpe. 152 and the modeling unit is bpe or cjkchar+bpe.
  153 + rule_fsts:
  154 + If not empty, it specifies fsts for inverse text normalization.
  155 + If there are multiple fsts, they are separated by a comma.
  156 + rule_fars:
  157 + If not empty, it specifies fst archives for inverse text normalization.
  158 + If there are multiple archives, they are separated by a comma.
151 """ 159 """
152 self = cls.__new__(cls) 160 self = cls.__new__(cls)
153 _assert_file_exists(tokens) 161 _assert_file_exists(tokens)
@@ -217,6 +225,8 @@ class OnlineRecognizer(object): @@ -217,6 +225,8 @@ class OnlineRecognizer(object):
217 hotwords_file=hotwords_file, 225 hotwords_file=hotwords_file,
218 blank_penalty=blank_penalty, 226 blank_penalty=blank_penalty,
219 temperature_scale=temperature_scale, 227 temperature_scale=temperature_scale,
  228 + rule_fsts=rule_fsts,
  229 + rule_fars=rule_fars,
220 ) 230 )
221 231
222 self.recognizer = _Recognizer(recognizer_config) 232 self.recognizer = _Recognizer(recognizer_config)
@@ -239,6 +249,8 @@ class OnlineRecognizer(object): @@ -239,6 +249,8 @@ class OnlineRecognizer(object):
239 decoding_method: str = "greedy_search", 249 decoding_method: str = "greedy_search",
240 provider: str = "cpu", 250 provider: str = "cpu",
241 debug: bool = False, 251 debug: bool = False,
  252 + rule_fsts: str = "",
  253 + rule_fars: str = "",
242 ): 254 ):
243 """ 255 """
244 Please refer to 256 Please refer to
@@ -283,6 +295,12 @@ class OnlineRecognizer(object): @@ -283,6 +295,12 @@ class OnlineRecognizer(object):
283 The only valid value is greedy_search. 295 The only valid value is greedy_search.
284 provider: 296 provider:
285 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 297 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  298 + rule_fsts:
  299 + If not empty, it specifies fsts for inverse text normalization.
  300 + If there are multiple fsts, they are separated by a comma.
  301 + rule_fars:
  302 + If not empty, it specifies fst archives for inverse text normalization.
  303 + If there are multiple archives, they are separated by a comma.
286 """ 304 """
287 self = cls.__new__(cls) 305 self = cls.__new__(cls)
288 _assert_file_exists(tokens) 306 _assert_file_exists(tokens)
@@ -322,6 +340,8 @@ class OnlineRecognizer(object): @@ -322,6 +340,8 @@ class OnlineRecognizer(object):
322 endpoint_config=endpoint_config, 340 endpoint_config=endpoint_config,
323 enable_endpoint=enable_endpoint_detection, 341 enable_endpoint=enable_endpoint_detection,
324 decoding_method=decoding_method, 342 decoding_method=decoding_method,
  343 + rule_fsts=rule_fsts,
  344 + rule_fars=rule_fars,
325 ) 345 )
326 346
327 self.recognizer = _Recognizer(recognizer_config) 347 self.recognizer = _Recognizer(recognizer_config)
@@ -345,6 +365,8 @@ class OnlineRecognizer(object): @@ -345,6 +365,8 @@ class OnlineRecognizer(object):
345 ctc_max_active: int = 3000, 365 ctc_max_active: int = 3000,
346 provider: str = "cpu", 366 provider: str = "cpu",
347 debug: bool = False, 367 debug: bool = False,
  368 + rule_fsts: str = "",
  369 + rule_fars: str = "",
348 ): 370 ):
349 """ 371 """
350 Please refer to 372 Please refer to
@@ -393,6 +415,12 @@ class OnlineRecognizer(object): @@ -393,6 +415,12 @@ class OnlineRecognizer(object):
393 active paths at a time. 415 active paths at a time.
394 provider: 416 provider:
395 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 417 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  418 + rule_fsts:
  419 + If not empty, it specifies fsts for inverse text normalization.
  420 + If there are multiple fsts, they are separated by a comma.
  421 + rule_fars:
  422 + If not empty, it specifies fst archives for inverse text normalization.
  423 + If there are multiple archives, they are separated by a comma.
396 """ 424 """
397 self = cls.__new__(cls) 425 self = cls.__new__(cls)
398 _assert_file_exists(tokens) 426 _assert_file_exists(tokens)
@@ -433,6 +461,8 @@ class OnlineRecognizer(object): @@ -433,6 +461,8 @@ class OnlineRecognizer(object):
433 ctc_fst_decoder_config=ctc_fst_decoder_config, 461 ctc_fst_decoder_config=ctc_fst_decoder_config,
434 enable_endpoint=enable_endpoint_detection, 462 enable_endpoint=enable_endpoint_detection,
435 decoding_method=decoding_method, 463 decoding_method=decoding_method,
  464 + rule_fsts=rule_fsts,
  465 + rule_fars=rule_fars,
436 ) 466 )
437 467
438 self.recognizer = _Recognizer(recognizer_config) 468 self.recognizer = _Recognizer(recognizer_config)
@@ -454,6 +484,8 @@ class OnlineRecognizer(object): @@ -454,6 +484,8 @@ class OnlineRecognizer(object):
454 decoding_method: str = "greedy_search", 484 decoding_method: str = "greedy_search",
455 provider: str = "cpu", 485 provider: str = "cpu",
456 debug: bool = False, 486 debug: bool = False,
  487 + rule_fsts: str = "",
  488 + rule_fars: str = "",
457 ): 489 ):
458 """ 490 """
459 Please refer to 491 Please refer to
@@ -497,6 +529,12 @@ class OnlineRecognizer(object): @@ -497,6 +529,12 @@ class OnlineRecognizer(object):
497 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 529 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
498 debug: 530 debug:
499 True to show meta data in the model. 531 True to show meta data in the model.
  532 + rule_fsts:
  533 + If not empty, it specifies fsts for inverse text normalization.
  534 + If there are multiple fsts, they are separated by a comma.
  535 + rule_fars:
  536 + If not empty, it specifies fst archives for inverse text normalization.
  537 + If there are multiple archives, they are separated by a comma.
500 """ 538 """
501 self = cls.__new__(cls) 539 self = cls.__new__(cls)
502 _assert_file_exists(tokens) 540 _assert_file_exists(tokens)
@@ -533,6 +571,8 @@ class OnlineRecognizer(object): @@ -533,6 +571,8 @@ class OnlineRecognizer(object):
533 endpoint_config=endpoint_config, 571 endpoint_config=endpoint_config,
534 enable_endpoint=enable_endpoint_detection, 572 enable_endpoint=enable_endpoint_detection,
535 decoding_method=decoding_method, 573 decoding_method=decoding_method,
  574 + rule_fsts=rule_fsts,
  575 + rule_fars=rule_fars,
536 ) 576 )
537 577
538 self.recognizer = _Recognizer(recognizer_config) 578 self.recognizer = _Recognizer(recognizer_config)
@@ -556,6 +596,8 @@ class OnlineRecognizer(object): @@ -556,6 +596,8 @@ class OnlineRecognizer(object):
556 decoding_method: str = "greedy_search", 596 decoding_method: str = "greedy_search",
557 provider: str = "cpu", 597 provider: str = "cpu",
558 debug: bool = False, 598 debug: bool = False,
  599 + rule_fsts: str = "",
  600 + rule_fars: str = "",
559 ): 601 ):
560 """ 602 """
561 Please refer to 603 Please refer to
@@ -602,6 +644,12 @@ class OnlineRecognizer(object): @@ -602,6 +644,12 @@ class OnlineRecognizer(object):
602 The only valid value is greedy_search. 644 The only valid value is greedy_search.
603 provider: 645 provider:
604 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 646 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
  647 + rule_fsts:
  648 + If not empty, it specifies fsts for inverse text normalization.
  649 + If there are multiple fsts, they are separated by a comma.
  650 + rule_fars:
  651 + If not empty, it specifies fst archives for inverse text normalization.
  652 + If there are multiple archives, they are separated by a comma.
605 """ 653 """
606 self = cls.__new__(cls) 654 self = cls.__new__(cls)
607 _assert_file_exists(tokens) 655 _assert_file_exists(tokens)
@@ -640,6 +688,8 @@ class OnlineRecognizer(object): @@ -640,6 +688,8 @@ class OnlineRecognizer(object):
640 endpoint_config=endpoint_config, 688 endpoint_config=endpoint_config,
641 enable_endpoint=enable_endpoint_detection, 689 enable_endpoint=enable_endpoint_detection,
642 decoding_method=decoding_method, 690 decoding_method=decoding_method,
  691 + rule_fsts=rule_fsts,
  692 + rule_fars=rule_fars,
643 ) 693 )
644 694
645 self.recognizer = _Recognizer(recognizer_config) 695 self.recognizer = _Recognizer(recognizer_config)