Fangjun Kuang
Committed by GitHub

Support replacing homophonic phrases (#2153)

Showing 42 changed files with 834 additions and 134 deletions
@@ -98,6 +98,29 @@ for m in model.onnx model.int8.onnx; do
   done
 done
 
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
+tar xf dict.tar.bz2
+rm dict.tar.bz2
+
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
+
+for m in model.onnx model.int8.onnx; do
+  for use_itn in 0 1; do
+    echo "$m $w $use_itn"
+    time $EXE \
+      --tokens=$repo/tokens.txt \
+      --sense-voice-model=$repo/$m \
+      --sense-voice-use-itn=$use_itn \
+      --hr-lexicon=./lexicon.txt \
+      --hr-dict-dir=./dict \
+      --hr-rule-fsts=./replace.fst \
+      ./test-hr.wav
+  done
+done
+
+rm -rf dict replace.fst test-hr.wav lexicon.txt
 
 # test wav reader for non-standard wav files
 waves=(
@@ -95,6 +95,18 @@ rm $name
 ls -lh $repo
 python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py
 
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
+tar xf dict.tar.bz2
+rm dict.tar.bz2
+
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
+
+python3 ./python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py
+
+rm -rf dict replace.fst test-hr.wav lexicon.txt
+
 if [[ $(uname) == Linux ]]; then
   # It needs ffmpeg
   log "generate subtitles (Chinese)"
 function(download_kaldifst)
   include(FetchContent)
 
-  set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.11.tar.gz")
-  set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.11.tar.gz")
-  set(kaldifst_HASH "SHA256=b43b3332faa2961edc730e47995a58cd4e22ead21905d55b0c4a41375b4a525f")
+  set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.13.tar.gz")
+  set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.13.tar.gz")
+  set(kaldifst_HASH "SHA256=f8dc15fdaf314d7c9c3551ad8c11ed15da0f34de36446798bbd1b90fa7946eb2")
 
   # If you don't have access to the Internet,
   # please pre-download kaldifst
   set(possible_file_locations
-    $ENV{HOME}/Downloads/kaldifst-1.7.11.tar.gz
-    ${CMAKE_SOURCE_DIR}/kaldifst-1.7.11.tar.gz
-    ${CMAKE_BINARY_DIR}/kaldifst-1.7.11.tar.gz
-    /tmp/kaldifst-1.7.11.tar.gz
-    /star-fj/fangjun/download/github/kaldifst-1.7.11.tar.gz
+    $ENV{HOME}/Downloads/kaldifst-1.7.13.tar.gz
+    ${CMAKE_SOURCE_DIR}/kaldifst-1.7.13.tar.gz
+    ${CMAKE_BINARY_DIR}/kaldifst-1.7.13.tar.gz
+    /tmp/kaldifst-1.7.13.tar.gz
+    /star-fj/fangjun/download/github/kaldifst-1.7.13.tar.gz
   )
 
   foreach(f IN LISTS possible_file_locations)
+#!/usr/bin/env python3
+
+"""
+This file shows how to use a non-streaming SenseVoice CTC model from
+https://github.com/FunAudioLLM/SenseVoice
+to decode files.
+
+Please download model files from
+https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+
+For instance,
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
+tar xf dict.tar.bz2
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
+"""
+
+from pathlib import Path
+
+import sherpa_onnx
+import soundfile as sf
+
+
+def create_recognizer():
+    model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx"
+    tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
+    test_wav = "./test-hr.wav"
+
+    if not Path(model).is_file() or not Path(test_wav).is_file():
+        raise ValueError(
+            """Please download model files from
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
+            and
+            https://github.com/k2-fsa/sherpa-onnx/releases/tag/hr-files
+            """
+        )
+    return (
+        sherpa_onnx.OfflineRecognizer.from_sense_voice(
+            model=model,
+            tokens=tokens,
+            use_itn=True,
+            debug=True,
+            hr_lexicon="./lexicon.txt",
+            hr_dict_dir="./dict",
+            hr_rule_fsts="./replace.fst",
+        ),
+        test_wav,
+    )
+
+
+def main():
+    recognizer, wave_filename = create_recognizer()
+
+    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
+    audio = audio[:, 0]  # only use the first channel
+
+    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
+    # sample_rate does not need to be 16000 Hz
+
+    stream = recognizer.create_stream()
+    stream.accept_waveform(sample_rate, audio)
+    recognizer.decode_stream(stream)
+    print(wave_filename)
+    print(stream.result)
+
+
+if __name__ == "__main__":
+    main()
@@ -20,7 +20,9 @@ set(sources
   features.cc
   file-utils.cc
   fst-utils.cc
+  homophone-replacer.cc
   hypothesis.cc
+  jieba.cc
   keyword-spotter-impl.cc
   keyword-spotter.cc
   offline-ctc-fst-decoder-config.cc
+// sherpa-onnx/csrc/homophone-replacer.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/homophone-replacer.h"
+
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <strstream>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#if __OHOS__
+#include "rawfile/raw_file_manager.h"
+#endif
+
+#include "kaldifst/csrc/text-normalizer.h"
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/jieba.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+void HomophoneReplacerConfig::Register(ParseOptions *po) {
+  po->Register("hr-dict-dir", &dict_dir,
+               "The dict directory for jieba used by HomophoneReplacer");
+
+  po->Register("hr-lexicon", &lexicon,
+               "Path to lexicon.txt used by HomophoneReplacer.");
+
+  po->Register("hr-rule-fsts", &rule_fsts,
+               "Fst files for HomophoneReplacer. If there are multiple, they "
+               "are separated by a comma. E.g., a.fst,b.fst,c.fst");
+}
+
+bool HomophoneReplacerConfig::Validate() const {
+  if (!dict_dir.empty()) {
+    std::vector<std::string> required_files = {
+        "jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8",
+        "idf.utf8",        "stop_words.utf8",
+    };
+
+    for (const auto &f : required_files) {
+      if (!FileExists(dict_dir + "/" + f)) {
+        SHERPA_ONNX_LOGE("'%s/%s' does not exist. Please check --hr-dict-dir",
+                         dict_dir.c_str(), f.c_str());
+        return false;
+      }
+    }
+  }
+
+  if (!lexicon.empty() && !FileExists(lexicon)) {
+    SHERPA_ONNX_LOGE("--hr-lexicon: '%s' does not exist", lexicon.c_str());
+    return false;
+  }
+
+  if (!rule_fsts.empty()) {
+    std::vector<std::string> files;
+    SplitStringToVector(rule_fsts, ",", false, &files);
+
+    if (files.size() > 1) {
+      SHERPA_ONNX_LOGE("Only 1 file is supported now.");
+      SHERPA_ONNX_EXIT(-1);
+    }
+
+    for (const auto &f : files) {
+      if (!FileExists(f)) {
+        SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+std::string HomophoneReplacerConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "HomophoneReplacerConfig(";
+  os << "dict_dir=\"" << dict_dir << "\", ";
+  os << "lexicon=\"" << lexicon << "\", ";
+  os << "rule_fsts=\"" << rule_fsts << "\")";
+
+  return os.str();
+}
+
+class HomophoneReplacer::Impl {
+ public:
+  explicit Impl(const HomophoneReplacerConfig &config) : config_(config) {
+    jieba_ = InitJieba(config.dict_dir);
+
+    {
+      std::ifstream is(config.lexicon);
+      InitLexicon(is);
+    }
+
+    if (!config.rule_fsts.empty()) {
+      std::vector<std::string> files;
+      SplitStringToVector(config.rule_fsts, ",", false, &files);
+      replacer_list_.reserve(files.size());
+      for (const auto &f : files) {
+        if (config.debug) {
+          SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
+        }
+        replacer_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
+      }
+    }
+  }
+
+  template <typename Manager>
+  Impl(Manager *mgr, const HomophoneReplacerConfig &config) : config_(config) {
+    jieba_ = InitJieba(config.dict_dir);
+    {
+      auto buf = ReadFile(mgr, config.lexicon);
+
+      std::istrstream is(buf.data(), buf.size());
+      InitLexicon(is);
+    }
+
+    if (!config.rule_fsts.empty()) {
+      std::vector<std::string> files;
+      SplitStringToVector(config.rule_fsts, ",", false, &files);
+      replacer_list_.reserve(files.size());
+      for (const auto &f : files) {
+        if (config.debug) {
+          SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
+        }
+        auto buf = ReadFile(mgr, f);
+        std::istrstream is(buf.data(), buf.size());
+        replacer_list_.push_back(
+            std::make_unique<kaldifst::TextNormalizer>(is));
+      }
+    }
+  }
+
+  std::string Apply(const std::string &text) const {
+    bool is_hmm = true;
+
+    std::vector<std::string> words;
+    jieba_->Cut(text, words, is_hmm);
+    if (config_.debug) {
+      SHERPA_ONNX_LOGE("Input text: '%s'", text.c_str());
+      std::ostringstream os;
+      os << "After jieba: ";
+      std::string sep;
+      for (const auto &w : words) {
+        os << sep << w;
+        sep = "_";
+      }
+      SHERPA_ONNX_LOGE("%s", os.str().c_str());
+    }
+
+    // convert words to pronunciations
+    std::vector<std::string> pronunciations;
+
+    for (const auto &w : words) {
+      auto p = ConvertWordToPronunciation(w);
+      if (config_.debug) {
+        SHERPA_ONNX_LOGE("%s %s", w.c_str(), p.c_str());
+      }
+      pronunciations.push_back(std::move(p));
+    }
+
+    std::string ans;
+    for (const auto &r : replacer_list_) {
+      ans = r->Normalize(words, pronunciations);
+      // TODO(fangjun): We support only 1 rule fst at present.
+      break;
+    }
+
+    return ans;
+  }
+
+ private:
+  std::string ConvertWordToPronunciation(const std::string &word) const {
+    if (word2pron_.count(word)) {
+      return word2pron_.at(word);
+    }
+
+    if (word.size() <= 3) {
+      // not a Chinese character
+      return word;
+    }
+
+    std::vector<std::string> words = SplitUtf8(word);
+    std::string ans;
+    for (const auto &w : words) {
+      if (word2pron_.count(w)) {
+        ans.append(word2pron_.at(w));
+      } else {
+        ans.append(w);
+      }
+    }
+
+    return ans;
+  }
+
+  void InitLexicon(std::istream &is) {
+    std::string word;
+    std::string pron;
+    std::string p;
+
+    std::string line;
+    int32_t line_num = 0;
+    int32_t num_warn = 0;
+    while (std::getline(is, line)) {
+      ++line_num;
+      std::istringstream iss(line);
+
+      pron.clear();
+      iss >> word;
+      ToLowerCase(&word);
+
+      if (word2pron_.count(word)) {
+        num_warn += 1;
+        if (num_warn < 10) {
+          SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
+                           word.c_str(), line_num, line.c_str());
+        }
+        continue;
+      }
+
+      while (iss >> p) {
+        pron.append(std::move(p));
+      }
+
+      if (pron.empty()) {
+        SHERPA_ONNX_LOGE(
+            "Empty pronunciation for word '%s' at line %d:%s. Ignore it.",
+            word.c_str(), line_num, line.c_str());
+        continue;
+      }
+
+      word2pron_.insert({std::move(word), std::move(pron)});
+    }
+  }
+
+ private:
+  HomophoneReplacerConfig config_;
+  std::unique_ptr<cppjieba::Jieba> jieba_;
+  std::vector<std::unique_ptr<kaldifst::TextNormalizer>> replacer_list_;
+  std::unordered_map<std::string, std::string> word2pron_;
+};
+
+HomophoneReplacer::HomophoneReplacer(const HomophoneReplacerConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+template <typename Manager>
+HomophoneReplacer::HomophoneReplacer(Manager *mgr,
+                                     const HomophoneReplacerConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+
+HomophoneReplacer::~HomophoneReplacer() = default;
+
+std::string HomophoneReplacer::Apply(const std::string &text) const {
+  return impl_->Apply(text);
+}
+
+#if __ANDROID_API__ >= 9
+template HomophoneReplacer::HomophoneReplacer(
+    AAssetManager *mgr, const HomophoneReplacerConfig &config);
+#endif
+
+#if __OHOS__
+template HomophoneReplacer::HomophoneReplacer(
+    NativeResourceManager *mgr, const HomophoneReplacerConfig &config);
+#endif
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/homophone-replacer.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
+#define SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
+
+#include <memory>
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct HomophoneReplacerConfig {
+  std::string dict_dir;
+  std::string lexicon;
+
+  // comma separated fst files, e.g. a.fst,b.fst,c.fst
+  std::string rule_fsts;
+
+  bool debug = false;
+
+  HomophoneReplacerConfig() = default;
+
+  HomophoneReplacerConfig(const std::string &dict_dir,
+                          const std::string &lexicon,
+                          const std::string &rule_fsts, bool debug)
+      : dict_dir(dict_dir),
+        lexicon(lexicon),
+        rule_fsts(rule_fsts),
+        debug(debug) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+class HomophoneReplacer {
+ public:
+  explicit HomophoneReplacer(const HomophoneReplacerConfig &config);
+
+  template <typename Manager>
+  HomophoneReplacer(Manager *mgr, const HomophoneReplacerConfig &config);
+
+  ~HomophoneReplacer();
+
+  std::string Apply(const std::string &text) const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
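
Not part of the diff: a minimal C++ usage sketch for the HomophoneReplacer API declared above. It assumes the dict/, lexicon.txt, and replace.fst files from the hr-files release have been downloaded into the current directory, as the CI script earlier in this change does.

// Minimal sketch (assumption: ./dict, ./lexicon.txt and ./replace.fst exist,
// e.g. downloaded from the hr-files release used elsewhere in this commit).
#include <iostream>
#include <string>

#include "sherpa-onnx/csrc/homophone-replacer.h"

int main() {
  sherpa_onnx::HomophoneReplacerConfig config(
      /*dict_dir=*/"./dict", /*lexicon=*/"./lexicon.txt",
      /*rule_fsts=*/"./replace.fst", /*debug=*/true);

  if (!config.Validate()) {
    std::cerr << "Invalid HomophoneReplacer config\n";
    return 1;
  }

  sherpa_onnx::HomophoneReplacer hr(config);

  // Apply() segments the text with jieba, maps each word to its
  // pronunciation via lexicon.txt, and rewrites it with replace.fst.
  std::string fixed = hr.Apply("recognized text to post-process");
  std::cout << fixed << "\n";
  return 0;
}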
@@ -19,8 +19,8 @@
 #include "rawfile/raw_file_manager.h"
 #endif
 
-#include "cppjieba/Jieba.hpp"
 #include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/jieba.h"
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
@@ -41,20 +41,7 @@ class JiebaLexicon::Impl {
   Impl(const std::string &lexicon, const std::string &tokens,
        const std::string &dict_dir, bool debug)
       : debug_(debug) {
-    std::string dict = dict_dir + "/jieba.dict.utf8";
-    std::string hmm = dict_dir + "/hmm_model.utf8";
-    std::string user_dict = dict_dir + "/user.dict.utf8";
-    std::string idf = dict_dir + "/idf.utf8";
-    std::string stop_word = dict_dir + "/stop_words.utf8";
-
-    AssertFileExists(dict);
-    AssertFileExists(hmm);
-    AssertFileExists(user_dict);
-    AssertFileExists(idf);
-    AssertFileExists(stop_word);
-
-    jieba_ =
-        std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
+    jieba_ = InitJieba(dict_dir);
 
     {
       std::ifstream is(tokens);
@@ -71,20 +58,7 @@ class JiebaLexicon::Impl {
   Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens,
        const std::string &dict_dir, bool debug)
       : debug_(debug) {
-    std::string dict = dict_dir + "/jieba.dict.utf8";
-    std::string hmm = dict_dir + "/hmm_model.utf8";
-    std::string user_dict = dict_dir + "/user.dict.utf8";
-    std::string idf = dict_dir + "/idf.utf8";
-    std::string stop_word = dict_dir + "/stop_words.utf8";
-
-    AssertFileExists(dict);
-    AssertFileExists(hmm);
-    AssertFileExists(user_dict);
-    AssertFileExists(idf);
-    AssertFileExists(stop_word);
-
-    jieba_ =
-        std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
+    jieba_ = InitJieba(dict_dir);
 
     {
       auto buf = ReadFile(mgr, tokens);
+// sherpa-onnx/csrc/jieba.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/jieba.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+
+namespace sherpa_onnx {
+
+std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir) {
+  if (dict_dir.empty()) {
+    return {};
+  }
+
+  std::string dict = dict_dir + "/jieba.dict.utf8";
+  std::string hmm = dict_dir + "/hmm_model.utf8";
+  std::string user_dict = dict_dir + "/user.dict.utf8";
+  std::string idf = dict_dir + "/idf.utf8";
+  std::string stop_word = dict_dir + "/stop_words.utf8";
+
+  AssertFileExists(dict);
+  AssertFileExists(hmm);
+  AssertFileExists(user_dict);
+  AssertFileExists(idf);
+  AssertFileExists(stop_word);
+
+  return std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf,
+                                           stop_word);
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/csrc/jieba.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_JIEBA_H_
+#define SHERPA_ONNX_CSRC_JIEBA_H_
+
+#include <memory>
+#include <string>
+
+#include "cppjieba/Jieba.hpp"
+
+namespace sherpa_onnx {
+
+std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_JIEBA_H_
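
Also not part of the diff: the new InitJieba() helper above centralizes the five-file cppjieba setup that the hunks below remove from JiebaLexicon, KokoroMultiLangLexicon, and MeloTtsLexicon. A small sketch of how it is used, assuming ./dict contains the jieba *.utf8 files from the hr-files release:

// Sketch: build a cppjieba::Jieba from a dict dir and segment a sentence.
#include <iostream>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/jieba.h"

int main() {
  auto jieba = sherpa_onnx::InitJieba("./dict");
  if (!jieba) {
    // InitJieba() returns an empty pointer when dict_dir is empty.
    return 1;
  }

  std::vector<std::string> words;
  // Same call pattern as HomophoneReplacer::Impl::Apply().
  jieba->Cut("这是一个测试", words, /*hmm=*/true);
  for (const auto &w : words) {
    std::cout << w << "\n";
  }
  return 0;
}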
@@ -22,11 +22,11 @@
 
 #include <codecvt>
 
-#include "cppjieba/Jieba.hpp"
 #include "espeak-ng/speak_lib.h"
 #include "phoneme_ids.hpp"
 #include "phonemize.hpp"
 #include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/jieba.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
 #include "sherpa-onnx/csrc/text-utils.h"
@@ -47,7 +47,7 @@ class KokoroMultiLangLexicon::Impl {
 
     InitLexicon(lexicon);
 
-    InitJieba(dict_dir);
+    jieba_ = InitJieba(dict_dir);
 
     InitEspeak(data_dir);  // See ./piper-phonemize-lexicon.cc
   }
@@ -62,7 +62,7 @@ class KokoroMultiLangLexicon::Impl {
     InitLexicon(mgr, lexicon);
 
     // we assume you have copied dict_dir and data_dir from assets to some path
-    InitJieba(dict_dir);
+    jieba_ = InitJieba(dict_dir);
 
     InitEspeak(data_dir);  // See ./piper-phonemize-lexicon.cc
   }
@@ -456,23 +456,6 @@ class KokoroMultiLangLexicon::Impl {
     }
   }
 
-  void InitJieba(const std::string &dict_dir) {
-    std::string dict = dict_dir + "/jieba.dict.utf8";
-    std::string hmm = dict_dir + "/hmm_model.utf8";
-    std::string user_dict = dict_dir + "/user.dict.utf8";
-    std::string idf = dict_dir + "/idf.utf8";
-    std::string stop_word = dict_dir + "/stop_words.utf8";
-
-    AssertFileExists(dict);
-    AssertFileExists(hmm);
-    AssertFileExists(user_dict);
-    AssertFileExists(idf);
-    AssertFileExists(stop_word);
-
-    jieba_ =
-        std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
-  }
-
  private:
   OfflineTtsKokoroModelMetaData meta_data_;
 
@@ -19,8 +19,8 @@
 #include "rawfile/raw_file_manager.h"
 #endif
 
-#include "cppjieba/Jieba.hpp"
 #include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/jieba.h"
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
@@ -34,20 +34,7 @@ class MeloTtsLexicon::Impl {
        const std::string &dict_dir,
        const OfflineTtsVitsModelMetaData &meta_data, bool debug)
       : meta_data_(meta_data), debug_(debug) {
-    std::string dict = dict_dir + "/jieba.dict.utf8";
-    std::string hmm = dict_dir + "/hmm_model.utf8";
-    std::string user_dict = dict_dir + "/user.dict.utf8";
-    std::string idf = dict_dir + "/idf.utf8";
-    std::string stop_word = dict_dir + "/stop_words.utf8";
-
-    AssertFileExists(dict);
-    AssertFileExists(hmm);
-    AssertFileExists(user_dict);
-    AssertFileExists(idf);
-    AssertFileExists(stop_word);
-
-    jieba_ =
-        std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
+    jieba_ = InitJieba(dict_dir);
 
     {
       std::ifstream is(tokens);
@@ -79,20 +66,7 @@ class MeloTtsLexicon::Impl {
        const std::string &dict_dir,
        const OfflineTtsVitsModelMetaData &meta_data, bool debug)
       : meta_data_(meta_data), debug_(debug) {
-    std::string dict = dict_dir + "/jieba.dict.utf8";
-    std::string hmm = dict_dir + "/hmm_model.utf8";
-    std::string user_dict = dict_dir + "/user.dict.utf8";
-    std::string idf = dict_dir + "/idf.utf8";
-    std::string stop_word = dict_dir + "/stop_words.utf8";
-
-    AssertFileExists(dict);
-    AssertFileExists(hmm);
-    AssertFileExists(user_dict);
-    AssertFileExists(idf);
-    AssertFileExists(stop_word);
-
-    jieba_ =
-        std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
+    jieba_ = InitJieba(dict_dir);
 
     {
       auto buf = ReadFile(mgr, tokens);
@@ -239,6 +239,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
       auto r = Convert(results[i], symbol_table_, frame_shift_ms,
                        model_->SubsamplingFactor());
       r.text = ApplyInverseTextNormalization(std::move(r.text));
+      r.text = ApplyHomophoneReplacer(std::move(r.text));
       ss[i]->SetResult(r);
     }
   }
@@ -277,6 +278,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
     auto r = Convert(results[0], symbol_table_, frame_shift_ms,
                      model_->SubsamplingFactor());
     r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
     s->SetResult(r);
   }
 
@@ -125,6 +125,7 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
     auto r = Convert(results[0], symbol_table_);
 
     r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
     s->SetResult(r);
   }
 
@@ -408,6 +408,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
 OfflineRecognizerImpl::OfflineRecognizerImpl(
     const OfflineRecognizerConfig &config)
     : config_(config) {
+  // TODO(fangjun): Refactor this function
+
   if (!config.rule_fsts.empty()) {
     std::vector<std::string> files;
     SplitStringToVector(config.rule_fsts, ",", false, &files);
@@ -448,6 +450,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl(
       SHERPA_ONNX_LOGE("FST archives loaded!");
     }
   }
+
+  if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
+      !config.hr.rule_fsts.empty()) {
+    auto hr_config = config.hr;
+    hr_config.debug = config.model_config.debug;
+    hr_ = std::make_unique<HomophoneReplacer>(hr_config);
+  }
 }
 
 template <typename Manager>
@@ -495,6 +504,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl(
       }  // for (; !reader->Done(); reader->Next())
     }  // for (const auto &f : files)
   }  // if (!config.rule_fars.empty())
+
+  if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
+      !config.hr.rule_fsts.empty()) {
+    auto hr_config = config.hr;
+    hr_config.debug = config.model_config.debug;
+    hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config);
+  }
 }
 
 std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
@@ -510,6 +526,15 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
   return text;
 }
 
+std::string OfflineRecognizerImpl::ApplyHomophoneReplacer(
+    std::string text) const {
+  if (hr_) {
+    text = hr_->Apply(text);
+  }
+
+  return text;
+}
+
 void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
   config_ = config;
 }
@@ -10,6 +10,7 @@
 #include <vector>
 
 #include "kaldifst/csrc/text-normalizer.h"
+#include "sherpa-onnx/csrc/homophone-replacer.h"
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-recognizer.h"
 #include "sherpa-onnx/csrc/offline-stream.h"
@@ -48,12 +49,15 @@ class OfflineRecognizerImpl {
 
   std::string ApplyInverseTextNormalization(std::string text) const;
 
+  std::string ApplyHomophoneReplacer(std::string text) const;
+
 private:
   OfflineRecognizerConfig config_;
   // for inverse text normalization. Used only if
   // config.rule_fsts is not empty or
   // config.rule_fars is not empty
   std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
+  std::unique_ptr<HomophoneReplacer> hr_;
 };
 
 }  // namespace sherpa_onnx
@@ -121,6 +121,7 @@ class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl {
 
       auto r = Convert(results[0], symbol_table_);
       r.text = ApplyInverseTextNormalization(std::move(r.text));
+      r.text = ApplyHomophoneReplacer(std::move(r.text));
       s->SetResult(r);
     } catch (const Ort::Exception &ex) {
       SHERPA_ONNX_LOGE(
@@ -197,6 +197,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
     for (int32_t i = 0; i != n; ++i) {
       auto r = Convert(results[i], symbol_table_);
       r.text = ApplyInverseTextNormalization(std::move(r.text));
+      r.text = ApplyHomophoneReplacer(std::move(r.text));
       ss[i]->SetResult(r);
     }
   }
@@ -222,6 +222,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
       auto r = ConvertSenseVoiceResult(results[i], symbol_table_,
                                        frame_shift_ms, subsampling_factor);
       r.text = ApplyInverseTextNormalization(std::move(r.text));
+      r.text = ApplyHomophoneReplacer(std::move(r.text));
       ss[i]->SetResult(r);
     }
   }
@@ -295,6 +296,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
                                      subsampling_factor);
 
     r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
     s->SetResult(r);
   }
 
@@ -239,6 +239,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
       auto r = Convert(results[i], symbol_table_, frame_shift_ms,
                        model_->SubsamplingFactor());
       r.text = ApplyInverseTextNormalization(std::move(r.text));
+      r.text = ApplyHomophoneReplacer(std::move(r.text));
 
       ss[i]->SetResult(r);
     }
@@ -128,6 +128,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
                       model_->SubsamplingFactor());
      r.text = ApplyInverseTextNormalization(std::move(r.text));
+      r.text = ApplyHomophoneReplacer(std::move(r.text));
 
      ss[i]->SetResult(r);
    }
@@ -160,6 +160,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
 
       std::string s = sym_table[i];
       s = ApplyInverseTextNormalization(s);
+      s = ApplyHomophoneReplacer(std::move(s));
 
       text += s;
       r.tokens.push_back(s);
@@ -28,6 +28,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
   model_config.Register(po);
   lm_config.Register(po);
   ctc_fst_decoder_config.Register(po);
+  hr.Register(po);
 
   po->Register(
       "decoding-method", &decoding_method,
@@ -120,6 +121,11 @@ bool OfflineRecognizerConfig::Validate() const {
     }
   }
 
+  if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() &&
+      !hr.Validate()) {
+    return false;
+  }
+
   return model_config.Validate();
 }
 
@@ -137,7 +143,8 @@ std::string OfflineRecognizerConfig::ToString() const {
   os << "hotwords_score=" << hotwords_score << ", ";
   os << "blank_penalty=" << blank_penalty << ", ";
   os << "rule_fsts=\"" << rule_fsts << "\", ";
-  os << "rule_fars=\"" << rule_fars << "\")";
+  os << "rule_fars=\"" << rule_fars << "\", ";
+  os << "hr=" << hr.ToString() << ")";
 
   return os.str();
 }
@@ -10,6 +10,7 @@
 #include <vector>
 
 #include "sherpa-onnx/csrc/features.h"
+#include "sherpa-onnx/csrc/homophone-replacer.h"
 #include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
 #include "sherpa-onnx/csrc/offline-lm-config.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"
@@ -40,6 +41,7 @@ struct OfflineRecognizerConfig {
 
   // If there are multiple FST archives, they are applied from left to right.
   std::string rule_fars;
+  HomophoneReplacerConfig hr;
 
   // only greedy_search is implemented
   // TODO(fangjun): Implement modified_beam_search
@@ -52,7 +54,7 @@ struct OfflineRecognizerConfig {
       const std::string &decoding_method, int32_t max_active_paths,
       const std::string &hotwords_file, float hotwords_score,
       float blank_penalty, const std::string &rule_fsts,
-      const std::string &rule_fars)
+      const std::string &rule_fars, const HomophoneReplacerConfig &hr)
       : feat_config(feat_config),
         model_config(model_config),
         lm_config(lm_config),
@@ -63,7 +65,8 @@ struct OfflineRecognizerConfig {
         hotwords_score(hotwords_score),
         blank_penalty(blank_penalty),
         rule_fsts(rule_fsts),
-        rule_fars(rule_fars) {}
+        rule_fars(rule_fars),
+        hr(hr) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
@@ -201,7 +201,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
     auto r =
         ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
                    s->GetCurrentSegment(), s->GetNumFramesSinceStart());
-    r.text = ApplyInverseTextNormalization(r.text);
+    r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
     return r;
   }
 
@@ -192,6 +192,13 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
       SHERPA_ONNX_LOGE("FST archives loaded!");
     }
   }
+
+  if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
+      !config.hr.rule_fsts.empty()) {
+    auto hr_config = config.hr;
+    hr_config.debug = config.model_config.debug;
+    hr_ = std::make_unique<HomophoneReplacer>(hr_config);
+  }
 }
 
 template <typename Manager>
@@ -239,6 +246,12 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr,
     }  // for (; !reader->Done(); reader->Next())
   }  // for (const auto &f : files)
 }  // if (!config.rule_fars.empty())
+  if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
+      !config.hr.rule_fsts.empty()) {
+    auto hr_config = config.hr;
+    hr_config.debug = config.model_config.debug;
+    hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config);
+  }
 }
 
 std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
@@ -254,6 +267,15 @@ std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
   return text;
 }
 
+std::string OnlineRecognizerImpl::ApplyHomophoneReplacer(
+    std::string text) const {
+  if (hr_) {
+    text = hr_->Apply(text);
+  }
+
+  return text;
+}
+
 #if __ANDROID_API__ >= 9
 template OnlineRecognizerImpl::OnlineRecognizerImpl(
     AAssetManager *mgr, const OnlineRecognizerConfig &config);
@@ -10,6 +10,7 @@
 #include <vector>
 
 #include "kaldifst/csrc/text-normalizer.h"
+#include "sherpa-onnx/csrc/homophone-replacer.h"
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-recognizer.h"
 #include "sherpa-onnx/csrc/online-stream.h"
@@ -57,6 +58,7 @@ class OnlineRecognizerImpl {
   virtual void Reset(OnlineStream *s) const = 0;
 
   std::string ApplyInverseTextNormalization(std::string text) const;
+  std::string ApplyHomophoneReplacer(std::string text) const;
 
 private:
   OnlineRecognizerConfig config_;
@@ -64,6 +66,7 @@ class OnlineRecognizerImpl {
   // config.rule_fsts is not empty or
   // config.rule_fars is not empty
   std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
+  std::unique_ptr<HomophoneReplacer> hr_;
 };
 
 }  // namespace sherpa_onnx
@@ -169,7 +169,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
     auto decoder_result = s->GetParaformerResult();
 
     auto r = Convert(decoder_result, sym_);
-    r.text = ApplyInverseTextNormalization(r.text);
+    r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
     return r;
   }
 
@@ -349,6 +349,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
     auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
                      s->GetCurrentSegment(), s->GetNumFramesSinceStart());
     r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
     return r;
   }
 
@@ -391,15 +392,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
       // (the encoder state buffers are kept)
       for (const auto &it : last_result.hyps) {
        auto h = it.second;
-        r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
-                                         h.ys.end()),
+        r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, h.ys.end()),
                    h.log_prob});
      }
 
-      r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
-                                       last_result.tokens.end());
+      r.tokens = std::vector<int64_t>(last_result.tokens.end() - context_size,
+                                      last_result.tokens.end());
    } else {
-      if(config_.reset_encoder) {
+      if (config_.reset_encoder) {
        // reset encoder states, use blanks as 'ys' context
        s->SetStates(model_->GetEncoderInitStates());
      }
@@ -100,6 +100,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
                      subsampling_factor, s->GetCurrentSegment(),
                      s->GetNumFramesSinceStart());
     r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
     return r;
   }
 
@@ -88,6 +88,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
   endpoint_config.Register(po);
   lm_config.Register(po);
   ctc_fst_decoder_config.Register(po);
+  hr.Register(po);
 
   po->Register("enable-endpoint", &enable_endpoint,
                "True to enable endpoint detection. False to disable it.");
@@ -182,6 +183,11 @@ bool OnlineRecognizerConfig::Validate() const {
     }
   }
 
+  if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() &&
+      !hr.Validate()) {
+    return false;
+  }
+
   return model_config.Validate();
 }
 
@@ -203,7 +209,8 @@ std::string OnlineRecognizerConfig::ToString() const {
   os << "temperature_scale=" << temperature_scale << ", ";
   os << "rule_fsts=\"" << rule_fsts << "\", ";
   os << "rule_fars=\"" << rule_fars << "\", ";
-  os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")";
+  os << "reset_encoder=" << (reset_encoder ? "True" : "False") << ", ";
+  os << "hr=" << hr.ToString() << ")";
 
   return os.str();
 }
@@ -11,6 +11,7 @@
 
 #include "sherpa-onnx/csrc/endpoint.h"
 #include "sherpa-onnx/csrc/features.h"
+#include "sherpa-onnx/csrc/homophone-replacer.h"
 #include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
 #include "sherpa-onnx/csrc/online-lm-config.h"
 #include "sherpa-onnx/csrc/online-model-config.h"
@@ -107,6 +108,8 @@ struct OnlineRecognizerConfig {
   // currently only in `OnlineRecognizerTransducerImpl`.
   bool reset_encoder = false;
 
+  HomophoneReplacerConfig hr;
+
   /// used only for modified_beam_search, if hotwords_buf is non-empty,
   /// the hotwords will be loaded from the buffered string instead of from the
   /// "hotwords_file"
@@ -123,7 +126,7 @@ struct OnlineRecognizerConfig {
       int32_t max_active_paths, const std::string &hotwords_file,
       float hotwords_score, float blank_penalty, float temperature_scale,
       const std::string &rule_fsts, const std::string &rule_fars,
-      bool reset_encoder)
+      bool reset_encoder, const HomophoneReplacerConfig &hr)
      : feat_config(feat_config),
        model_config(model_config),
        lm_config(lm_config),
@@ -138,7 +141,8 @@ struct OnlineRecognizerConfig {
        temperature_scale(temperature_scale),
        rule_fsts(rule_fsts),
        rule_fars(rule_fars),
-        reset_encoder(reset_encoder) {}
+        reset_encoder(reset_encoder),
+        hr(hr) {}
 
   void Register(ParseOptions *po);
   bool Validate() const;
@@ -89,7 +89,8 @@ class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl {
     auto r =
         ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
                    s->GetCurrentSegment(), s->GetNumFramesSinceStart());
-    r.text = ApplyInverseTextNormalization(r.text);
+    r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
     return r;
   }
 
@@ -177,6 +177,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl {
     auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
                      s->GetCurrentSegment(), s->GetNumFramesSinceStart());
     r.text = ApplyInverseTextNormalization(std::move(r.text));
+    r.text = ApplyHomophoneReplacer(std::move(r.text));
     return r;
   }
 
@@ -7,6 +7,7 @@ set(srcs
   display.cc
   endpoint.cc
   features.cc
+  homophone-replacer.cc
   keyword-spotter.cc
   offline-ctc-fst-decoder-config.cc
   offline-dolphin-model-config.cc
+// sherpa-onnx/python/csrc/homophone-replacer.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa-onnx/python/csrc/homophone-replacer.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/homophone-replacer.h"
+
+namespace sherpa_onnx {
+
+void PybindHomophoneReplacer(py::module *m) {
+  using PyClass = HomophoneReplacerConfig;
+  py::class_<PyClass>(*m, "HomophoneReplacerConfig")
+      .def(py::init<>())
+      .def(py::init<const std::string &, const std::string &,
+                    const std::string &, bool>(),
+           py::arg("dict_dir"), py::arg("lexicon"), py::arg("rule_fsts"),
+           py::arg("debug") = false)
+      .def_readwrite("dict_dir", &PyClass::dict_dir)
+      .def_readwrite("lexicon", &PyClass::lexicon)
+      .def_readwrite("rule_fsts", &PyClass::rule_fsts)
+      .def_readwrite("debug", &PyClass::debug)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/python/csrc/homophone-replacer.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
+#define SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindHomophoneReplacer(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
@@ -17,14 +17,16 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
       .def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
                     const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
                     const std::string &, int32_t, const std::string &, float,
-                    float, const std::string &, const std::string &>(),
+                    float, const std::string &, const std::string &,
+                    const HomophoneReplacerConfig &>(),
           py::arg("feat_config"), py::arg("model_config"),
          py::arg("lm_config") = OfflineLMConfig(),
          py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
          py::arg("decoding_method") = "greedy_search",
          py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
          py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
-          py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
+          py::arg("rule_fsts") = "", py::arg("rule_fars") = "",
+          py::arg("hr") = HomophoneReplacerConfig{})
      .def_readwrite("feat_config", &PyClass::feat_config)
      .def_readwrite("model_config", &PyClass::model_config)
      .def_readwrite("lm_config", &PyClass::lm_config)
@@ -36,6 +38,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
      .def_readwrite("blank_penalty", &PyClass::blank_penalty)
      .def_readwrite("rule_fsts", &PyClass::rule_fsts)
      .def_readwrite("rule_fars", &PyClass::rule_fars)
+      .def_readwrite("hr", &PyClass::hr)
      .def("__str__", &PyClass::ToString);
 }
 
@@ -58,7 +58,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -58,7 +58,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
58 const OnlineLMConfig &, const EndpointConfig &, 58 const OnlineLMConfig &, const EndpointConfig &,
59 const OnlineCtcFstDecoderConfig &, bool, 59 const OnlineCtcFstDecoderConfig &, bool,
60 const std::string &, int32_t, const std::string &, float, 60 const std::string &, int32_t, const std::string &, float,
61 - float, float, const std::string &, const std::string &, bool>(), 61 + float, float, const std::string &, const std::string &,
  62 + bool, const HomophoneReplacerConfig &>(),
62 py::arg("feat_config"), py::arg("model_config"), 63 py::arg("feat_config"), py::arg("model_config"),
63 py::arg("lm_config") = OnlineLMConfig(), 64 py::arg("lm_config") = OnlineLMConfig(),
64 py::arg("endpoint_config") = EndpointConfig(), 65 py::arg("endpoint_config") = EndpointConfig(),
@@ -67,7 +68,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -67,7 +68,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
67 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", 68 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
68 py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, 69 py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
69 py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", 70 py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
70 - py::arg("rule_fars") = "", py::arg("reset_encoder") = false) 71 + py::arg("rule_fars") = "", py::arg("reset_encoder") = false,
  72 + py::arg("hr") = HomophoneReplacerConfig{})
71 .def_readwrite("feat_config", &PyClass::feat_config) 73 .def_readwrite("feat_config", &PyClass::feat_config)
72 .def_readwrite("model_config", &PyClass::model_config) 74 .def_readwrite("model_config", &PyClass::model_config)
73 .def_readwrite("lm_config", &PyClass::lm_config) 75 .def_readwrite("lm_config", &PyClass::lm_config)
@@ -83,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -83,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
83 .def_readwrite("rule_fsts", &PyClass::rule_fsts) 85 .def_readwrite("rule_fsts", &PyClass::rule_fsts)
84 .def_readwrite("rule_fars", &PyClass::rule_fars) 86 .def_readwrite("rule_fars", &PyClass::rule_fars)
85 .def_readwrite("reset_encoder", &PyClass::reset_encoder) 87 .def_readwrite("reset_encoder", &PyClass::reset_encoder)
  88 + .def_readwrite("hr", &PyClass::hr)
86 .def("__str__", &PyClass::ToString); 89 .def("__str__", &PyClass::ToString);
87 } 90 }
88 91
@@ -10,6 +10,7 @@ @@ -10,6 +10,7 @@
10 #include "sherpa-onnx/python/csrc/display.h" 10 #include "sherpa-onnx/python/csrc/display.h"
11 #include "sherpa-onnx/python/csrc/endpoint.h" 11 #include "sherpa-onnx/python/csrc/endpoint.h"
12 #include "sherpa-onnx/python/csrc/features.h" 12 #include "sherpa-onnx/python/csrc/features.h"
  13 +#include "sherpa-onnx/python/csrc/homophone-replacer.h"
13 #include "sherpa-onnx/python/csrc/keyword-spotter.h" 14 #include "sherpa-onnx/python/csrc/keyword-spotter.h"
14 #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" 15 #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
15 #include "sherpa-onnx/python/csrc/offline-lm-config.h" 16 #include "sherpa-onnx/python/csrc/offline-lm-config.h"
@@ -51,6 +52,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -51,6 +52,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
51 PybindAudioTagging(&m); 52 PybindAudioTagging(&m);
52 PybindOfflinePunctuation(&m); 53 PybindOfflinePunctuation(&m);
53 PybindOnlinePunctuation(&m); 54 PybindOnlinePunctuation(&m);
  55 + PybindHomophoneReplacer(&m);
54 56
55 PybindFeatures(&m); 57 PybindFeatures(&m);
56 PybindOnlineCtcFstDecoderConfig(&m); 58 PybindOnlineCtcFstDecoderConfig(&m);
@@ -5,6 +5,7 @@ from typing import List, Optional @@ -5,6 +5,7 @@ from typing import List, Optional
5 5
6 from _sherpa_onnx import ( 6 from _sherpa_onnx import (
7 FeatureExtractorConfig, 7 FeatureExtractorConfig,
  8 + HomophoneReplacerConfig,
8 OfflineCtcFstDecoderConfig, 9 OfflineCtcFstDecoderConfig,
9 OfflineDolphinModelConfig, 10 OfflineDolphinModelConfig,
10 OfflineFireRedAsrModelConfig, 11 OfflineFireRedAsrModelConfig,
@@ -64,6 +65,9 @@ class OfflineRecognizer(object): @@ -64,6 +65,9 @@ class OfflineRecognizer(object):
64 rule_fars: str = "", 65 rule_fars: str = "",
65 lm: str = "", 66 lm: str = "",
66 lm_scale: float = 0.1, 67 lm_scale: float = 0.1,
  68 + hr_dict_dir: str = "",
  69 + hr_rule_fsts: str = "",
  70 + hr_lexicon: str = "",
67 ): 71 ):
68 """ 72 """
69 Please refer to 73 Please refer to
@@ -181,6 +185,11 @@ class OfflineRecognizer(object): @@ -181,6 +185,11 @@ class OfflineRecognizer(object):
181 blank_penalty=blank_penalty, 185 blank_penalty=blank_penalty,
182 rule_fsts=rule_fsts, 186 rule_fsts=rule_fsts,
183 rule_fars=rule_fars, 187 rule_fars=rule_fars,
  188 + hr=HomophoneReplacerConfig(
  189 + dict_dir=hr_dict_dir,
  190 + lexicon=hr_lexicon,
  191 + rule_fsts=hr_rule_fsts,
  192 + ),
184 ) 193 )
185 self.recognizer = _Recognizer(recognizer_config) 194 self.recognizer = _Recognizer(recognizer_config)
186 self.config = recognizer_config 195 self.config = recognizer_config
@@ -201,6 +210,9 @@ class OfflineRecognizer(object): @@ -201,6 +210,9 @@ class OfflineRecognizer(object):
201 use_itn: bool = False, 210 use_itn: bool = False,
202 rule_fsts: str = "", 211 rule_fsts: str = "",
203 rule_fars: str = "", 212 rule_fars: str = "",
  213 + hr_dict_dir: str = "",
  214 + hr_rule_fsts: str = "",
  215 + hr_lexicon: str = "",
204 ): 216 ):
205 """ 217 """
206 Please refer to 218 Please refer to
@@ -263,6 +275,11 @@ class OfflineRecognizer(object): @@ -263,6 +275,11 @@ class OfflineRecognizer(object):
263 decoding_method=decoding_method, 275 decoding_method=decoding_method,
264 rule_fsts=rule_fsts, 276 rule_fsts=rule_fsts,
265 rule_fars=rule_fars, 277 rule_fars=rule_fars,
  278 + hr=HomophoneReplacerConfig(
  279 + dict_dir=hr_dict_dir,
  280 + lexicon=hr_lexicon,
  281 + rule_fsts=hr_rule_fsts,
  282 + ),
266 ) 283 )
267 self.recognizer = _Recognizer(recognizer_config) 284 self.recognizer = _Recognizer(recognizer_config)
268 self.config = recognizer_config 285 self.config = recognizer_config
@@ -281,6 +298,9 @@ class OfflineRecognizer(object): @@ -281,6 +298,9 @@ class OfflineRecognizer(object):
281 provider: str = "cpu", 298 provider: str = "cpu",
282 rule_fsts: str = "", 299 rule_fsts: str = "",
283 rule_fars: str = "", 300 rule_fars: str = "",
  301 + hr_dict_dir: str = "",
  302 + hr_rule_fsts: str = "",
  303 + hr_lexicon: str = "",
284 ): 304 ):
285 """ 305 """
286 Please refer to 306 Please refer to
@@ -336,6 +356,11 @@ class OfflineRecognizer(object): @@ -336,6 +356,11 @@ class OfflineRecognizer(object):
336 decoding_method=decoding_method, 356 decoding_method=decoding_method,
337 rule_fsts=rule_fsts, 357 rule_fsts=rule_fsts,
338 rule_fars=rule_fars, 358 rule_fars=rule_fars,
  359 + hr=HomophoneReplacerConfig(
  360 + dict_dir=hr_dict_dir,
  361 + lexicon=hr_lexicon,
  362 + rule_fsts=hr_rule_fsts,
  363 + ),
339 ) 364 )
340 self.recognizer = _Recognizer(recognizer_config) 365 self.recognizer = _Recognizer(recognizer_config)
341 self.config = recognizer_config 366 self.config = recognizer_config
@@ -354,6 +379,9 @@ class OfflineRecognizer(object): @@ -354,6 +379,9 @@ class OfflineRecognizer(object):
354 provider: str = "cpu", 379 provider: str = "cpu",
355 rule_fsts: str = "", 380 rule_fsts: str = "",
356 rule_fars: str = "", 381 rule_fars: str = "",
  382 + hr_dict_dir: str = "",
  383 + hr_rule_fsts: str = "",
  384 + hr_lexicon: str = "",
357 ): 385 ):
358 """ 386 """
359 Please refer to 387 Please refer to
@@ -411,6 +439,9 @@ class OfflineRecognizer(object): @@ -411,6 +439,9 @@ class OfflineRecognizer(object):
411 decoding_method=decoding_method, 439 decoding_method=decoding_method,
412 rule_fsts=rule_fsts, 440 rule_fsts=rule_fsts,
413 rule_fars=rule_fars, 441 rule_fars=rule_fars,
  442 + hr=HomophoneReplacerConfig(
  443 + dict_dir=hr_dict_dir, lexicon=hr_lexicon, rule_fsts=hr_rule_fsts
  444 + ),
414 ) 445 )
415 self.recognizer = _Recognizer(recognizer_config) 446 self.recognizer = _Recognizer(recognizer_config)
416 self.config = recognizer_config 447 self.config = recognizer_config
@@ -429,6 +460,9 @@ class OfflineRecognizer(object): @@ -429,6 +460,9 @@ class OfflineRecognizer(object):
429 provider: str = "cpu", 460 provider: str = "cpu",
430 rule_fsts: str = "", 461 rule_fsts: str = "",
431 rule_fars: str = "", 462 rule_fars: str = "",
  463 + hr_dict_dir: str = "",
  464 + hr_rule_fsts: str = "",
  465 + hr_lexicon: str = "",
432 ): 466 ):
433 """ 467 """
434 Please refer to 468 Please refer to
@@ -483,6 +517,11 @@ class OfflineRecognizer(object): @@ -483,6 +517,11 @@ class OfflineRecognizer(object):
483 decoding_method=decoding_method, 517 decoding_method=decoding_method,
484 rule_fsts=rule_fsts, 518 rule_fsts=rule_fsts,
485 rule_fars=rule_fars, 519 rule_fars=rule_fars,
  520 + hr=HomophoneReplacerConfig(
  521 + dict_dir=hr_dict_dir,
  522 + lexicon=hr_lexicon,
  523 + rule_fsts=hr_rule_fsts,
  524 + ),
486 ) 525 )
487 self.recognizer = _Recognizer(recognizer_config) 526 self.recognizer = _Recognizer(recognizer_config)
488 self.config = recognizer_config 527 self.config = recognizer_config
@@ -501,6 +540,9 @@ class OfflineRecognizer(object): @@ -501,6 +540,9 @@ class OfflineRecognizer(object):
501 provider: str = "cpu", 540 provider: str = "cpu",
502 rule_fsts: str = "", 541 rule_fsts: str = "",
503 rule_fars: str = "", 542 rule_fars: str = "",
  543 + hr_dict_dir: str = "",
  544 + hr_rule_fsts: str = "",
  545 + hr_lexicon: str = "",
504 ): 546 ):
505 """ 547 """
506 Please refer to 548 Please refer to
@@ -557,6 +599,11 @@ class OfflineRecognizer(object): @@ -557,6 +599,11 @@ class OfflineRecognizer(object):
557 decoding_method=decoding_method, 599 decoding_method=decoding_method,
558 rule_fsts=rule_fsts, 600 rule_fsts=rule_fsts,
559 rule_fars=rule_fars, 601 rule_fars=rule_fars,
  602 + hr=HomophoneReplacerConfig(
  603 + dict_dir=hr_dict_dir,
  604 + lexicon=hr_lexicon,
  605 + rule_fsts=hr_rule_fsts,
  606 + ),
560 ) 607 )
561 self.recognizer = _Recognizer(recognizer_config) 608 self.recognizer = _Recognizer(recognizer_config)
562 self.config = recognizer_config 609 self.config = recognizer_config
@@ -577,6 +624,9 @@ class OfflineRecognizer(object): @@ -577,6 +624,9 @@ class OfflineRecognizer(object):
577 tail_paddings: int = -1, 624 tail_paddings: int = -1,
578 rule_fsts: str = "", 625 rule_fsts: str = "",
579 rule_fars: str = "", 626 rule_fars: str = "",
  627 + hr_dict_dir: str = "",
  628 + hr_rule_fsts: str = "",
  629 + hr_lexicon: str = "",
580 ): 630 ):
581 """ 631 """
582 Please refer to 632 Please refer to
@@ -647,6 +697,11 @@ class OfflineRecognizer(object): @@ -647,6 +697,11 @@ class OfflineRecognizer(object):
647 decoding_method=decoding_method, 697 decoding_method=decoding_method,
648 rule_fsts=rule_fsts, 698 rule_fsts=rule_fsts,
649 rule_fars=rule_fars, 699 rule_fars=rule_fars,
  700 + hr=HomophoneReplacerConfig(
  701 + dict_dir=hr_dict_dir,
  702 + lexicon=hr_lexicon,
  703 + rule_fsts=hr_rule_fsts,
  704 + ),
650 ) 705 )
651 self.recognizer = _Recognizer(recognizer_config) 706 self.recognizer = _Recognizer(recognizer_config)
652 self.config = recognizer_config 707 self.config = recognizer_config
@@ -664,6 +719,9 @@ class OfflineRecognizer(object): @@ -664,6 +719,9 @@ class OfflineRecognizer(object):
664 provider: str = "cpu", 719 provider: str = "cpu",
665 rule_fsts: str = "", 720 rule_fsts: str = "",
666 rule_fars: str = "", 721 rule_fars: str = "",
  722 + hr_dict_dir: str = "",
  723 + hr_rule_fsts: str = "",
  724 + hr_lexicon: str = "",
667 ): 725 ):
668 """ 726 """
669 Please refer to 727 Please refer to
@@ -719,6 +777,11 @@ class OfflineRecognizer(object): @@ -719,6 +777,11 @@ class OfflineRecognizer(object):
719 decoding_method=decoding_method, 777 decoding_method=decoding_method,
720 rule_fsts=rule_fsts, 778 rule_fsts=rule_fsts,
721 rule_fars=rule_fars, 779 rule_fars=rule_fars,
  780 + hr=HomophoneReplacerConfig(
  781 + dict_dir=hr_dict_dir,
  782 + lexicon=hr_lexicon,
  783 + rule_fsts=hr_rule_fsts,
  784 + ),
722 ) 785 )
723 self.recognizer = _Recognizer(recognizer_config) 786 self.recognizer = _Recognizer(recognizer_config)
724 self.config = recognizer_config 787 self.config = recognizer_config
@@ -738,6 +801,9 @@ class OfflineRecognizer(object): @@ -738,6 +801,9 @@ class OfflineRecognizer(object):
738 provider: str = "cpu", 801 provider: str = "cpu",
739 rule_fsts: str = "", 802 rule_fsts: str = "",
740 rule_fars: str = "", 803 rule_fars: str = "",
  804 + hr_dict_dir: str = "",
  805 + hr_rule_fsts: str = "",
  806 + hr_lexicon: str = "",
741 ): 807 ):
742 """ 808 """
743 Please refer to 809 Please refer to
@@ -800,6 +866,11 @@ class OfflineRecognizer(object): @@ -800,6 +866,11 @@ class OfflineRecognizer(object):
800 decoding_method=decoding_method, 866 decoding_method=decoding_method,
801 rule_fsts=rule_fsts, 867 rule_fsts=rule_fsts,
802 rule_fars=rule_fars, 868 rule_fars=rule_fars,
  869 + hr=HomophoneReplacerConfig(
  870 + dict_dir=hr_dict_dir,
  871 + lexicon=hr_lexicon,
  872 + rule_fsts=hr_rule_fsts,
  873 + ),
803 ) 874 )
804 self.recognizer = _Recognizer(recognizer_config) 875 self.recognizer = _Recognizer(recognizer_config)
805 self.config = recognizer_config 876 self.config = recognizer_config
@@ -818,6 +889,9 @@ class OfflineRecognizer(object): @@ -818,6 +889,9 @@ class OfflineRecognizer(object):
818 provider: str = "cpu", 889 provider: str = "cpu",
819 rule_fsts: str = "", 890 rule_fsts: str = "",
820 rule_fars: str = "", 891 rule_fars: str = "",
  892 + hr_dict_dir: str = "",
  893 + hr_rule_fsts: str = "",
  894 + hr_lexicon: str = "",
821 ): 895 ):
822 """ 896 """
823 Please refer to 897 Please refer to
@@ -873,6 +947,11 @@ class OfflineRecognizer(object): @@ -873,6 +947,11 @@ class OfflineRecognizer(object):
873 decoding_method=decoding_method, 947 decoding_method=decoding_method,
874 rule_fsts=rule_fsts, 948 rule_fsts=rule_fsts,
875 rule_fars=rule_fars, 949 rule_fars=rule_fars,
  950 + hr=HomophoneReplacerConfig(
  951 + dict_dir=hr_dict_dir,
  952 + lexicon=hr_lexicon,
  953 + rule_fsts=hr_rule_fsts,
  954 + ),
876 ) 955 )
877 self.recognizer = _Recognizer(recognizer_config) 956 self.recognizer = _Recognizer(recognizer_config)
878 self.config = recognizer_config 957 self.config = recognizer_config
@@ -891,6 +970,9 @@ class OfflineRecognizer(object): @@ -891,6 +970,9 @@ class OfflineRecognizer(object):
891 provider: str = "cpu", 970 provider: str = "cpu",
892 rule_fsts: str = "", 971 rule_fsts: str = "",
893 rule_fars: str = "", 972 rule_fars: str = "",
  973 + hr_dict_dir: str = "",
  974 + hr_rule_fsts: str = "",
  975 + hr_lexicon: str = "",
894 ): 976 ):
895 """ 977 """
896 Please refer to 978 Please refer to
@@ -947,6 +1029,11 @@ class OfflineRecognizer(object): @@ -947,6 +1029,11 @@ class OfflineRecognizer(object):
947 decoding_method=decoding_method, 1029 decoding_method=decoding_method,
948 rule_fsts=rule_fsts, 1030 rule_fsts=rule_fsts,
949 rule_fars=rule_fars, 1031 rule_fars=rule_fars,
  1032 + hr=HomophoneReplacerConfig(
  1033 + dict_dir=hr_dict_dir,
  1034 + lexicon=hr_lexicon,
  1035 + rule_fsts=hr_rule_fsts,
  1036 + ),
950 ) 1037 )
951 self.recognizer = _Recognizer(recognizer_config) 1038 self.recognizer = _Recognizer(recognizer_config)
952 self.config = recognizer_config 1039 self.config = recognizer_config
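
Each OfflineRecognizer factory method above now accepts hr_dict_dir, hr_rule_fsts, and hr_lexicon and wraps them into a HomophoneReplacerConfig on the recognizer config. A hedged usage sketch, assuming the from_sense_voice constructor (the method names sit outside the visible hunks) and using placeholder paths for all files:

import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
    model="./model.onnx",      # placeholder model path
    tokens="./tokens.txt",     # placeholder tokens file
    use_itn=True,
    # new in this change: forwarded into HomophoneReplacerConfig
    hr_dict_dir="./dict",
    hr_lexicon="./lexicon.txt",
    hr_rule_fsts="./replace.fst",
)
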
@@ -3,25 +3,26 @@ from pathlib import Path @@ -3,25 +3,26 @@ from pathlib import Path
3 from typing import List, Optional 3 from typing import List, Optional
4 4
5 from _sherpa_onnx import ( 5 from _sherpa_onnx import (
  6 + CudaConfig,
6 EndpointConfig, 7 EndpointConfig,
7 FeatureExtractorConfig, 8 FeatureExtractorConfig,
  9 + HomophoneReplacerConfig,
  10 + OnlineCtcFstDecoderConfig,
8 OnlineLMConfig, 11 OnlineLMConfig,
9 OnlineModelConfig, 12 OnlineModelConfig,
  13 + OnlineNeMoCtcModelConfig,
10 OnlineParaformerModelConfig, 14 OnlineParaformerModelConfig,
11 ) 15 )
12 from _sherpa_onnx import OnlineRecognizer as _Recognizer 16 from _sherpa_onnx import OnlineRecognizer as _Recognizer
13 from _sherpa_onnx import ( 17 from _sherpa_onnx import (
14 - CudaConfig,  
15 - TensorrtConfig,  
16 - ProviderConfig,  
17 OnlineRecognizerConfig, 18 OnlineRecognizerConfig,
18 OnlineRecognizerResult, 19 OnlineRecognizerResult,
19 OnlineStream, 20 OnlineStream,
20 OnlineTransducerModelConfig, 21 OnlineTransducerModelConfig,
21 OnlineWenetCtcModelConfig, 22 OnlineWenetCtcModelConfig,
22 - OnlineNeMoCtcModelConfig,  
23 OnlineZipformer2CtcModelConfig, 23 OnlineZipformer2CtcModelConfig,
24 - OnlineCtcFstDecoderConfig, 24 + ProviderConfig,
  25 + TensorrtConfig,
25 ) 26 )
26 27
27 28
@@ -82,9 +83,12 @@ class OnlineRecognizer(object): @@ -82,9 +83,12 @@ class OnlineRecognizer(object):
82 trt_detailed_build_log: bool = False, 83 trt_detailed_build_log: bool = False,
83 trt_engine_cache_enable: bool = True, 84 trt_engine_cache_enable: bool = True,
84 trt_timing_cache_enable: bool = True, 85 trt_timing_cache_enable: bool = True,
85 - trt_engine_cache_path: str ="",  
86 - trt_timing_cache_path: str ="", 86 + trt_engine_cache_path: str = "",
  87 + trt_timing_cache_path: str = "",
87 trt_dump_subgraphs: bool = False, 88 trt_dump_subgraphs: bool = False,
  89 + hr_dict_dir: str = "",
  90 + hr_rule_fsts: str = "",
  91 + hr_lexicon: str = "",
88 ): 92 ):
89 """ 93 """
90 Please refer to 94 Please refer to
@@ -228,27 +232,27 @@ class OnlineRecognizer(object): @@ -228,27 +232,27 @@ class OnlineRecognizer(object):
228 ) 232 )
229 233
230 cuda_config = CudaConfig( 234 cuda_config = CudaConfig(
231 - cudnn_conv_algo_search=cudnn_conv_algo_search, 235 + cudnn_conv_algo_search=cudnn_conv_algo_search,
232 ) 236 )
233 237
234 trt_config = TensorrtConfig( 238 trt_config = TensorrtConfig(
235 - trt_max_workspace_size=trt_max_workspace_size,  
236 - trt_max_partition_iterations=trt_max_partition_iterations,  
237 - trt_min_subgraph_size=trt_min_subgraph_size,  
238 - trt_fp16_enable=trt_fp16_enable,  
239 - trt_detailed_build_log=trt_detailed_build_log,  
240 - trt_engine_cache_enable=trt_engine_cache_enable,  
241 - trt_timing_cache_enable=trt_timing_cache_enable,  
242 - trt_engine_cache_path=trt_engine_cache_path,  
243 - trt_timing_cache_path=trt_timing_cache_path,  
244 - trt_dump_subgraphs=trt_dump_subgraphs, 239 + trt_max_workspace_size=trt_max_workspace_size,
  240 + trt_max_partition_iterations=trt_max_partition_iterations,
  241 + trt_min_subgraph_size=trt_min_subgraph_size,
  242 + trt_fp16_enable=trt_fp16_enable,
  243 + trt_detailed_build_log=trt_detailed_build_log,
  244 + trt_engine_cache_enable=trt_engine_cache_enable,
  245 + trt_timing_cache_enable=trt_timing_cache_enable,
  246 + trt_engine_cache_path=trt_engine_cache_path,
  247 + trt_timing_cache_path=trt_timing_cache_path,
  248 + trt_dump_subgraphs=trt_dump_subgraphs,
245 ) 249 )
246 250
247 provider_config = ProviderConfig( 251 provider_config = ProviderConfig(
248 - trt_config=trt_config,  
249 - cuda_config=cuda_config,  
250 - provider=provider,  
251 - device=device, 252 + trt_config=trt_config,
  253 + cuda_config=cuda_config,
  254 + provider=provider,
  255 + device=device,
252 ) 256 )
253 257
254 model_config = OnlineModelConfig( 258 model_config = OnlineModelConfig(
@@ -311,6 +315,11 @@ class OnlineRecognizer(object): @@ -311,6 +315,11 @@ class OnlineRecognizer(object):
311 rule_fsts=rule_fsts, 315 rule_fsts=rule_fsts,
312 rule_fars=rule_fars, 316 rule_fars=rule_fars,
313 reset_encoder=reset_encoder, 317 reset_encoder=reset_encoder,
  318 + hr=HomophoneReplacerConfig(
  319 + dict_dir=hr_dict_dir,
  320 + lexicon=hr_lexicon,
  321 + rule_fsts=hr_rule_fsts,
  322 + ),
314 ) 323 )
315 324
316 self.recognizer = _Recognizer(recognizer_config) 325 self.recognizer = _Recognizer(recognizer_config)
@@ -336,6 +345,9 @@ class OnlineRecognizer(object): @@ -336,6 +345,9 @@ class OnlineRecognizer(object):
336 rule_fsts: str = "", 345 rule_fsts: str = "",
337 rule_fars: str = "", 346 rule_fars: str = "",
338 device: int = 0, 347 device: int = 0,
  348 + hr_dict_dir: str = "",
  349 + hr_rule_fsts: str = "",
  350 + hr_lexicon: str = "",
339 ): 351 ):
340 """ 352 """
341 Please refer to 353 Please refer to
@@ -402,8 +414,8 @@ class OnlineRecognizer(object): @@ -402,8 +414,8 @@ class OnlineRecognizer(object):
402 ) 414 )
403 415
404 provider_config = ProviderConfig( 416 provider_config = ProviderConfig(
405 - provider=provider,  
406 - device=device, 417 + provider=provider,
  418 + device=device,
407 ) 419 )
408 420
409 model_config = OnlineModelConfig( 421 model_config = OnlineModelConfig(
@@ -434,6 +446,11 @@ class OnlineRecognizer(object): @@ -434,6 +446,11 @@ class OnlineRecognizer(object):
434 decoding_method=decoding_method, 446 decoding_method=decoding_method,
435 rule_fsts=rule_fsts, 447 rule_fsts=rule_fsts,
436 rule_fars=rule_fars, 448 rule_fars=rule_fars,
  449 + hr=HomophoneReplacerConfig(
  450 + dict_dir=hr_dict_dir,
  451 + lexicon=hr_lexicon,
  452 + rule_fsts=hr_rule_fsts,
  453 + ),
437 ) 454 )
438 455
439 self.recognizer = _Recognizer(recognizer_config) 456 self.recognizer = _Recognizer(recognizer_config)
@@ -460,6 +477,9 @@ class OnlineRecognizer(object): @@ -460,6 +477,9 @@ class OnlineRecognizer(object):
460 rule_fsts: str = "", 477 rule_fsts: str = "",
461 rule_fars: str = "", 478 rule_fars: str = "",
462 device: int = 0, 479 device: int = 0,
  480 + hr_dict_dir: str = "",
  481 + hr_rule_fsts: str = "",
  482 + hr_lexicon: str = "",
463 ): 483 ):
464 """ 484 """
465 Please refer to 485 Please refer to
@@ -526,8 +546,8 @@ class OnlineRecognizer(object): @@ -526,8 +546,8 @@ class OnlineRecognizer(object):
526 zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) 546 zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
527 547
528 provider_config = ProviderConfig( 548 provider_config = ProviderConfig(
529 - provider=provider,  
530 - device=device, 549 + provider=provider,
  550 + device=device,
531 ) 551 )
532 552
533 model_config = OnlineModelConfig( 553 model_config = OnlineModelConfig(
@@ -563,6 +583,11 @@ class OnlineRecognizer(object): @@ -563,6 +583,11 @@ class OnlineRecognizer(object):
563 decoding_method=decoding_method, 583 decoding_method=decoding_method,
564 rule_fsts=rule_fsts, 584 rule_fsts=rule_fsts,
565 rule_fars=rule_fars, 585 rule_fars=rule_fars,
  586 + hr=HomophoneReplacerConfig(
  587 + dict_dir=hr_dict_dir,
  588 + lexicon=hr_lexicon,
  589 + rule_fsts=hr_rule_fsts,
  590 + ),
566 ) 591 )
567 592
568 self.recognizer = _Recognizer(recognizer_config) 593 self.recognizer = _Recognizer(recognizer_config)
@@ -587,6 +612,9 @@ class OnlineRecognizer(object): @@ -587,6 +612,9 @@ class OnlineRecognizer(object):
587 rule_fsts: str = "", 612 rule_fsts: str = "",
588 rule_fars: str = "", 613 rule_fars: str = "",
589 device: int = 0, 614 device: int = 0,
  615 + hr_dict_dir: str = "",
  616 + hr_rule_fsts: str = "",
  617 + hr_lexicon: str = "",
590 ): 618 ):
591 """ 619 """
592 Please refer to 620 Please refer to
@@ -650,8 +678,8 @@ class OnlineRecognizer(object): @@ -650,8 +678,8 @@ class OnlineRecognizer(object):
650 ) 678 )
651 679
652 provider_config = ProviderConfig( 680 provider_config = ProviderConfig(
653 - provider=provider,  
654 - device=device, 681 + provider=provider,
  682 + device=device,
655 ) 683 )
656 684
657 model_config = OnlineModelConfig( 685 model_config = OnlineModelConfig(
@@ -681,6 +709,11 @@ class OnlineRecognizer(object): @@ -681,6 +709,11 @@ class OnlineRecognizer(object):
681 decoding_method=decoding_method, 709 decoding_method=decoding_method,
682 rule_fsts=rule_fsts, 710 rule_fsts=rule_fsts,
683 rule_fars=rule_fars, 711 rule_fars=rule_fars,
  712 + hr=HomophoneReplacerConfig(
  713 + dict_dir=hr_dict_dir,
  714 + lexicon=hr_lexicon,
  715 + rule_fsts=hr_rule_fsts,
  716 + ),
684 ) 717 )
685 718
686 self.recognizer = _Recognizer(recognizer_config) 719 self.recognizer = _Recognizer(recognizer_config)
@@ -707,6 +740,9 @@ class OnlineRecognizer(object): @@ -707,6 +740,9 @@ class OnlineRecognizer(object):
707 rule_fsts: str = "", 740 rule_fsts: str = "",
708 rule_fars: str = "", 741 rule_fars: str = "",
709 device: int = 0, 742 device: int = 0,
  743 + hr_dict_dir: str = "",
  744 + hr_rule_fsts: str = "",
  745 + hr_lexicon: str = "",
710 ): 746 ):
711 """ 747 """
712 Please refer to 748 Please refer to
@@ -775,8 +811,8 @@ class OnlineRecognizer(object): @@ -775,8 +811,8 @@ class OnlineRecognizer(object):
775 ) 811 )
776 812
777 provider_config = ProviderConfig( 813 provider_config = ProviderConfig(
778 - provider=provider,  
779 - device=device, 814 + provider=provider,
  815 + device=device,
780 ) 816 )
781 817
782 model_config = OnlineModelConfig( 818 model_config = OnlineModelConfig(
@@ -806,6 +842,11 @@ class OnlineRecognizer(object): @@ -806,6 +842,11 @@ class OnlineRecognizer(object):
806 decoding_method=decoding_method, 842 decoding_method=decoding_method,
807 rule_fsts=rule_fsts, 843 rule_fsts=rule_fsts,
808 rule_fars=rule_fars, 844 rule_fars=rule_fars,
  845 + hr=HomophoneReplacerConfig(
  846 + dict_dir=hr_dict_dir,
  847 + lexicon=hr_lexicon,
  848 + rule_fsts=hr_rule_fsts,
  849 + ),
809 ) 850 )
810 851
811 self.recognizer = _Recognizer(recognizer_config) 852 self.recognizer = _Recognizer(recognizer_config)
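
The streaming OnlineRecognizer factory methods gain the same three keyword arguments. A hedged sketch, assuming the from_transducer constructor (again, the method names are outside the visible hunks) with placeholder paths:

import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens="./tokens.txt",      # placeholder tokens file
    encoder="./encoder.onnx",   # placeholder model files
    decoder="./decoder.onnx",
    joiner="./joiner.onnx",
    # new in this change: forwarded into HomophoneReplacerConfig
    hr_dict_dir="./dict",
    hr_lexicon="./lexicon.txt",
    hr_rule_fsts="./replace.fst",
)
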