Committed by GitHub
Support replacing homophonic phrases (#2153)
Showing 42 changed files with 834 additions and 134 deletions.
| @@ -98,6 +98,29 @@ for m in model.onnx model.int8.onnx; do | @@ -98,6 +98,29 @@ for m in model.onnx model.int8.onnx; do | ||
| 98 | done | 98 | done |
| 99 | done | 99 | done |
| 100 | 100 | ||
| 101 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2 | ||
| 102 | +tar xf dict.tar.bz2 | ||
| 103 | +rm dict.tar.bz2 | ||
| 104 | + | ||
| 105 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst | ||
| 106 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav | ||
| 107 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt | ||
| 108 | + | ||
| 109 | +for m in model.onnx model.int8.onnx; do | ||
| 110 | + for use_itn in 0 1; do | ||
|  111  | +    echo "$m $use_itn" | ||
| 112 | + time $EXE \ | ||
| 113 | + --tokens=$repo/tokens.txt \ | ||
| 114 | + --sense-voice-model=$repo/$m \ | ||
| 115 | + --sense-voice-use-itn=$use_itn \ | ||
| 116 | + --hr-lexicon=./lexicon.txt \ | ||
| 117 | + --hr-dict-dir=./dict \ | ||
| 118 | + --hr-rule-fsts=./replace.fst \ | ||
| 119 | + ./test-hr.wav | ||
| 120 | + done | ||
| 121 | +done | ||
| 122 | + | ||
| 123 | +rm -rf dict replace.fst test-hr.wav lexicon.txt | ||
| 101 | 124 | ||
| 102 | # test wav reader for non-standard wav files | 125 | # test wav reader for non-standard wav files |
| 103 | waves=( | 126 | waves=( |
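The three extra downloads above drive the new homophone-replacement (hr) step: dict/ holds the jieba dictionaries used to segment the recognition result into words, lexicon.txt maps each word to its pronunciation (per InitLexicon in homophone-replacer.cc further down, each line is a word followed by one or more pronunciation tokens that get concatenated), and replace.fst rewrites pronunciation sequences back into the preferred written form, as HomophoneReplacer::Impl::Apply suggests. A hypothetical lexicon entry might look like `密码 mi4 ma3` (a word followed by per-character pinyin); the actual symbols only need to agree with what replace.fst expects.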
| @@ -95,6 +95,18 @@ rm $name | @@ -95,6 +95,18 @@ rm $name | ||
| 95 | ls -lh $repo | 95 | ls -lh $repo |
| 96 | python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py | 96 | python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py |
| 97 | 97 | ||
| 98 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2 | ||
| 99 | +tar xf dict.tar.bz2 | ||
| 100 | +rm dict.tar.bz2 | ||
| 101 | + | ||
| 102 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst | ||
| 103 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav | ||
| 104 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt | ||
| 105 | + | ||
| 106 | +python3 ./python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py | ||
| 107 | + | ||
| 108 | +rm -rf dict replace.fst test-hr.wav lexicon.txt | ||
| 109 | + | ||
| 98 | if [[ $(uname) == Linux ]]; then | 110 | if [[ $(uname) == Linux ]]; then |
| 99 | # It needs ffmpeg | 111 | # It needs ffmpeg |
| 100 | log "generate subtitles (Chinese)" | 112 | log "generate subtitles (Chinese)" |
| 1 | function(download_kaldifst) | 1 | function(download_kaldifst) |
| 2 | include(FetchContent) | 2 | include(FetchContent) |
| 3 | 3 | ||
| 4 | - set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.11.tar.gz") | ||
| 5 | - set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.11.tar.gz") | ||
| 6 | - set(kaldifst_HASH "SHA256=b43b3332faa2961edc730e47995a58cd4e22ead21905d55b0c4a41375b4a525f") | 4 | + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.13.tar.gz") |
| 5 | + set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.13.tar.gz") | ||
| 6 | + set(kaldifst_HASH "SHA256=f8dc15fdaf314d7c9c3551ad8c11ed15da0f34de36446798bbd1b90fa7946eb2") | ||
| 7 | 7 | ||
| 8 | # If you don't have access to the Internet, | 8 | # If you don't have access to the Internet, |
| 9 | # please pre-download kaldifst | 9 | # please pre-download kaldifst |
| 10 | set(possible_file_locations | 10 | set(possible_file_locations |
| 11 | - $ENV{HOME}/Downloads/kaldifst-1.7.11.tar.gz | ||
| 12 | - ${CMAKE_SOURCE_DIR}/kaldifst-1.7.11.tar.gz | ||
| 13 | - ${CMAKE_BINARY_DIR}/kaldifst-1.7.11.tar.gz | ||
| 14 | - /tmp/kaldifst-1.7.11.tar.gz | ||
| 15 | - /star-fj/fangjun/download/github/kaldifst-1.7.11.tar.gz | 11 | + $ENV{HOME}/Downloads/kaldifst-1.7.13.tar.gz |
| 12 | + ${CMAKE_SOURCE_DIR}/kaldifst-1.7.13.tar.gz | ||
| 13 | + ${CMAKE_BINARY_DIR}/kaldifst-1.7.13.tar.gz | ||
| 14 | + /tmp/kaldifst-1.7.13.tar.gz | ||
| 15 | + /star-fj/fangjun/download/github/kaldifst-1.7.13.tar.gz | ||
| 16 | ) | 16 | ) |
| 17 | 17 | ||
| 18 | foreach(f IN LISTS possible_file_locations) | 18 | foreach(f IN LISTS possible_file_locations) |
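The kaldifst bump from 1.7.11 to 1.7.13 presumably provides the TextNormalizer::Normalize overload that takes a (words, pronunciations) pair, which homophone-replacer.cc below calls as r->Normalize(words, pronunciations).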
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +""" | ||
| 4 | +This file shows how to use a non-streaming SenseVoice CTC model from | ||
| 5 | +https://github.com/FunAudioLLM/SenseVoice | ||
| 6 | +to decode files. | ||
| 7 | + | ||
| 8 | +Please download model files from | ||
| 9 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 10 | + | ||
| 11 | +For instance, | ||
| 12 | + | ||
| 13 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 14 | +tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 15 | +rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 | ||
| 16 | + | ||
| 17 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2 | ||
| 18 | +tar xf dict.tar.bz2 | ||
| 19 | + | ||
| 20 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst | ||
| 21 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav | ||
| 22 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt | ||
| 23 | +""" | ||
| 24 | + | ||
| 25 | +from pathlib import Path | ||
| 26 | + | ||
| 27 | +import sherpa_onnx | ||
| 28 | +import soundfile as sf | ||
| 29 | + | ||
| 30 | + | ||
| 31 | +def create_recognizer(): | ||
| 32 | + model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx" | ||
| 33 | + tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt" | ||
| 34 | + test_wav = "./test-hr.wav" | ||
| 35 | + | ||
| 36 | + if not Path(model).is_file() or not Path(test_wav).is_file(): | ||
| 37 | + raise ValueError( | ||
| 38 | + """Please download model files from | ||
| 39 | + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models | ||
| 40 | + and | ||
| 41 | + https://github.com/k2-fsa/sherpa-onnx/releases/tag/hr-files | ||
| 42 | + """ | ||
| 43 | + ) | ||
| 44 | + return ( | ||
| 45 | + sherpa_onnx.OfflineRecognizer.from_sense_voice( | ||
| 46 | + model=model, | ||
| 47 | + tokens=tokens, | ||
| 48 | + use_itn=True, | ||
| 49 | + debug=True, | ||
| 50 | + hr_lexicon="./lexicon.txt", | ||
| 51 | + hr_dict_dir="./dict", | ||
| 52 | + hr_rule_fsts="./replace.fst", | ||
| 53 | + ), | ||
| 54 | + test_wav, | ||
| 55 | + ) | ||
| 56 | + | ||
| 57 | + | ||
| 58 | +def main(): | ||
| 59 | + recognizer, wave_filename = create_recognizer() | ||
| 60 | + | ||
| 61 | + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) | ||
| 62 | + audio = audio[:, 0] # only use the first channel | ||
| 63 | + | ||
| 64 | + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] | ||
| 65 | + # sample_rate does not need to be 16000 Hz | ||
| 66 | + | ||
| 67 | + stream = recognizer.create_stream() | ||
| 68 | + stream.accept_waveform(sample_rate, audio) | ||
| 69 | + recognizer.decode_stream(stream) | ||
| 70 | + print(wave_filename) | ||
| 71 | + print(stream.result) | ||
| 72 | + | ||
| 73 | + | ||
| 74 | +if __name__ == "__main__": | ||
| 75 | + main() |
| @@ -20,7 +20,9 @@ set(sources | @@ -20,7 +20,9 @@ set(sources | ||
| 20 | features.cc | 20 | features.cc |
| 21 | file-utils.cc | 21 | file-utils.cc |
| 22 | fst-utils.cc | 22 | fst-utils.cc |
| 23 | + homophone-replacer.cc | ||
| 23 | hypothesis.cc | 24 | hypothesis.cc |
| 25 | + jieba.cc | ||
| 24 | keyword-spotter-impl.cc | 26 | keyword-spotter-impl.cc |
| 25 | keyword-spotter.cc | 27 | keyword-spotter.cc |
| 26 | offline-ctc-fst-decoder-config.cc | 28 | offline-ctc-fst-decoder-config.cc |
sherpa-onnx/csrc/homophone-replacer.cc (new file, mode 100644)
| 1 | +// sherpa-onnx/csrc/homophone-replacer.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/homophone-replacer.h" | ||
| 6 | + | ||
| 7 | +#include <fstream> | ||
| 8 | +#include <sstream> | ||
| 9 | +#include <string> | ||
| 10 | +#include <strstream> | ||
| 11 | +#include <unordered_map> | ||
| 12 | +#include <utility> | ||
| 13 | +#include <vector> | ||
| 14 | + | ||
| 15 | +#if __ANDROID_API__ >= 9 | ||
| 16 | +#include "android/asset_manager.h" | ||
| 17 | +#include "android/asset_manager_jni.h" | ||
| 18 | +#endif | ||
| 19 | + | ||
| 20 | +#if __OHOS__ | ||
| 21 | +#include "rawfile/raw_file_manager.h" | ||
| 22 | +#endif | ||
| 23 | + | ||
| 24 | +#include "kaldifst/csrc/text-normalizer.h" | ||
| 25 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 26 | +#include "sherpa-onnx/csrc/jieba.h" | ||
| 27 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 28 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 29 | + | ||
| 30 | +namespace sherpa_onnx { | ||
| 31 | + | ||
| 32 | +void HomophoneReplacerConfig::Register(ParseOptions *po) { | ||
| 33 | + po->Register("hr-dict-dir", &dict_dir, | ||
| 34 | + "The dict directory for jieba used by HomophoneReplacer"); | ||
| 35 | + | ||
| 36 | + po->Register("hr-lexicon", &lexicon, | ||
| 37 | + "Path to lexicon.txt used by HomophoneReplacer."); | ||
| 38 | + | ||
| 39 | + po->Register("hr-rule-fsts", &rule_fsts, | ||
| 40 | + "Fst files for HomophoneReplacer. If there are multiple, they " | ||
| 41 | + "are separated by a comma. E.g., a.fst,b.fst,c.fst"); | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +bool HomophoneReplacerConfig::Validate() const { | ||
| 45 | + if (!dict_dir.empty()) { | ||
| 46 | + std::vector<std::string> required_files = { | ||
| 47 | + "jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8", | ||
| 48 | + "idf.utf8", "stop_words.utf8", | ||
| 49 | + }; | ||
| 50 | + | ||
| 51 | + for (const auto &f : required_files) { | ||
| 52 | + if (!FileExists(dict_dir + "/" + f)) { | ||
|  53  | +        SHERPA_ONNX_LOGE("'%s/%s' does not exist. Please check --hr-dict-dir", | ||
| 54 | + dict_dir.c_str(), f.c_str()); | ||
| 55 | + return false; | ||
| 56 | + } | ||
| 57 | + } | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + if (!lexicon.empty() && !FileExists(lexicon)) { | ||
| 61 | + SHERPA_ONNX_LOGE("--hr-lexicon: '%s' does not exist", lexicon.c_str()); | ||
| 62 | + return false; | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + if (!rule_fsts.empty()) { | ||
| 66 | + std::vector<std::string> files; | ||
| 67 | + SplitStringToVector(rule_fsts, ",", false, &files); | ||
| 68 | + | ||
| 69 | + if (files.size() > 1) { | ||
| 70 | + SHERPA_ONNX_LOGE("Only 1 file is supported now."); | ||
| 71 | + SHERPA_ONNX_EXIT(-1); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + for (const auto &f : files) { | ||
| 75 | + if (!FileExists(f)) { | ||
| 76 | + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str()); | ||
| 77 | + return false; | ||
| 78 | + } | ||
| 79 | + } | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + return true; | ||
| 83 | +} | ||
| 84 | + | ||
| 85 | +std::string HomophoneReplacerConfig::ToString() const { | ||
| 86 | + std::ostringstream os; | ||
| 87 | + | ||
| 88 | + os << "HomophoneReplacerConfig("; | ||
| 89 | + os << "dict_dir=\"" << dict_dir << "\", "; | ||
| 90 | + os << "lexicon=\"" << lexicon << "\", "; | ||
| 91 | + os << "rule_fsts=\"" << rule_fsts << "\")"; | ||
| 92 | + | ||
| 93 | + return os.str(); | ||
| 94 | +} | ||
| 95 | + | ||
| 96 | +class HomophoneReplacer::Impl { | ||
| 97 | + public: | ||
| 98 | + explicit Impl(const HomophoneReplacerConfig &config) : config_(config) { | ||
| 99 | + jieba_ = InitJieba(config.dict_dir); | ||
| 100 | + | ||
| 101 | + { | ||
| 102 | + std::ifstream is(config.lexicon); | ||
| 103 | + InitLexicon(is); | ||
| 104 | + } | ||
| 105 | + | ||
| 106 | + if (!config.rule_fsts.empty()) { | ||
| 107 | + std::vector<std::string> files; | ||
| 108 | + SplitStringToVector(config.rule_fsts, ",", false, &files); | ||
| 109 | + replacer_list_.reserve(files.size()); | ||
| 110 | + for (const auto &f : files) { | ||
| 111 | + if (config.debug) { | ||
| 112 | + SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str()); | ||
| 113 | + } | ||
| 114 | + replacer_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f)); | ||
| 115 | + } | ||
| 116 | + } | ||
| 117 | + } | ||
| 118 | + | ||
| 119 | + template <typename Manager> | ||
| 120 | + Impl(Manager *mgr, const HomophoneReplacerConfig &config) : config_(config) { | ||
| 121 | + jieba_ = InitJieba(config.dict_dir); | ||
| 122 | + { | ||
| 123 | + auto buf = ReadFile(mgr, config.lexicon); | ||
| 124 | + | ||
| 125 | + std::istrstream is(buf.data(), buf.size()); | ||
| 126 | + InitLexicon(is); | ||
| 127 | + } | ||
| 128 | + | ||
| 129 | + if (!config.rule_fsts.empty()) { | ||
| 130 | + std::vector<std::string> files; | ||
| 131 | + SplitStringToVector(config.rule_fsts, ",", false, &files); | ||
| 132 | + replacer_list_.reserve(files.size()); | ||
| 133 | + for (const auto &f : files) { | ||
| 134 | + if (config.debug) { | ||
| 135 | + SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str()); | ||
| 136 | + } | ||
| 137 | + auto buf = ReadFile(mgr, f); | ||
| 138 | + std::istrstream is(buf.data(), buf.size()); | ||
| 139 | + replacer_list_.push_back( | ||
| 140 | + std::make_unique<kaldifst::TextNormalizer>(is)); | ||
| 141 | + } | ||
| 142 | + } | ||
| 143 | + } | ||
| 144 | + | ||
| 145 | + std::string Apply(const std::string &text) const { | ||
| 146 | + bool is_hmm = true; | ||
| 147 | + | ||
| 148 | + std::vector<std::string> words; | ||
| 149 | + jieba_->Cut(text, words, is_hmm); | ||
| 150 | + if (config_.debug) { | ||
| 151 | + SHERPA_ONNX_LOGE("Input text: '%s'", text.c_str()); | ||
| 152 | + std::ostringstream os; | ||
| 153 | + os << "After jieba: "; | ||
| 154 | + std::string sep; | ||
| 155 | + for (const auto &w : words) { | ||
| 156 | + os << sep << w; | ||
| 157 | + sep = "_"; | ||
| 158 | + } | ||
| 159 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 160 | + } | ||
| 161 | + | ||
| 162 | + // convert words to pronunciations | ||
| 163 | + std::vector<std::string> pronunciations; | ||
| 164 | + | ||
| 165 | + for (const auto &w : words) { | ||
| 166 | + auto p = ConvertWordToPronunciation(w); | ||
| 167 | + if (config_.debug) { | ||
| 168 | + SHERPA_ONNX_LOGE("%s %s", w.c_str(), p.c_str()); | ||
| 169 | + } | ||
| 170 | + pronunciations.push_back(std::move(p)); | ||
| 171 | + } | ||
| 172 | + | ||
| 173 | + std::string ans; | ||
| 174 | + for (const auto &r : replacer_list_) { | ||
| 175 | + ans = r->Normalize(words, pronunciations); | ||
| 176 | + // TODO(fangjun): We support only 1 rule fst at present. | ||
| 177 | + break; | ||
| 178 | + } | ||
| 179 | + | ||
| 180 | + return ans; | ||
| 181 | + } | ||
| 182 | + | ||
| 183 | + private: | ||
| 184 | + std::string ConvertWordToPronunciation(const std::string &word) const { | ||
| 185 | + if (word2pron_.count(word)) { | ||
| 186 | + return word2pron_.at(word); | ||
| 187 | + } | ||
| 188 | + | ||
| 189 | + if (word.size() <= 3) { | ||
| 190 | + // not a Chinese character | ||
| 191 | + return word; | ||
| 192 | + } | ||
| 193 | + | ||
| 194 | + std::vector<std::string> words = SplitUtf8(word); | ||
| 195 | + std::string ans; | ||
| 196 | + for (const auto &w : words) { | ||
| 197 | + if (word2pron_.count(w)) { | ||
| 198 | + ans.append(word2pron_.at(w)); | ||
| 199 | + } else { | ||
| 200 | + ans.append(w); | ||
| 201 | + } | ||
| 202 | + } | ||
| 203 | + | ||
| 204 | + return ans; | ||
| 205 | + } | ||
| 206 | + | ||
| 207 | + void InitLexicon(std::istream &is) { | ||
| 208 | + std::string word; | ||
| 209 | + std::string pron; | ||
| 210 | + std::string p; | ||
| 211 | + | ||
| 212 | + std::string line; | ||
| 213 | + int32_t line_num = 0; | ||
| 214 | + int32_t num_warn = 0; | ||
| 215 | + while (std::getline(is, line)) { | ||
| 216 | + ++line_num; | ||
| 217 | + std::istringstream iss(line); | ||
| 218 | + | ||
| 219 | + pron.clear(); | ||
| 220 | + iss >> word; | ||
| 221 | + ToLowerCase(&word); | ||
| 222 | + | ||
| 223 | + if (word2pron_.count(word)) { | ||
| 224 | + num_warn += 1; | ||
| 225 | + if (num_warn < 10) { | ||
| 226 | + SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.", | ||
| 227 | + word.c_str(), line_num, line.c_str()); | ||
| 228 | + } | ||
| 229 | + continue; | ||
| 230 | + } | ||
| 231 | + | ||
| 232 | + while (iss >> p) { | ||
| 233 | + pron.append(std::move(p)); | ||
| 234 | + } | ||
| 235 | + | ||
| 236 | + if (pron.empty()) { | ||
| 237 | + SHERPA_ONNX_LOGE( | ||
| 238 | + "Empty pronunciation for word '%s' at line %d:%s. Ignore it.", | ||
| 239 | + word.c_str(), line_num, line.c_str()); | ||
| 240 | + continue; | ||
| 241 | + } | ||
| 242 | + | ||
| 243 | + word2pron_.insert({std::move(word), std::move(pron)}); | ||
| 244 | + } | ||
| 245 | + } | ||
| 246 | + | ||
| 247 | + private: | ||
| 248 | + HomophoneReplacerConfig config_; | ||
| 249 | + std::unique_ptr<cppjieba::Jieba> jieba_; | ||
| 250 | + std::vector<std::unique_ptr<kaldifst::TextNormalizer>> replacer_list_; | ||
| 251 | + std::unordered_map<std::string, std::string> word2pron_; | ||
| 252 | +}; | ||
| 253 | + | ||
| 254 | +HomophoneReplacer::HomophoneReplacer(const HomophoneReplacerConfig &config) | ||
| 255 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 256 | + | ||
| 257 | +template <typename Manager> | ||
| 258 | +HomophoneReplacer::HomophoneReplacer(Manager *mgr, | ||
| 259 | + const HomophoneReplacerConfig &config) | ||
| 260 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 261 | + | ||
| 262 | +HomophoneReplacer::~HomophoneReplacer() = default; | ||
| 263 | + | ||
| 264 | +std::string HomophoneReplacer::Apply(const std::string &text) const { | ||
| 265 | + return impl_->Apply(text); | ||
| 266 | +} | ||
| 267 | + | ||
| 268 | +#if __ANDROID_API__ >= 9 | ||
| 269 | +template HomophoneReplacer::HomophoneReplacer( | ||
| 270 | + AAssetManager *mgr, const HomophoneReplacerConfig &config); | ||
| 271 | +#endif | ||
| 272 | + | ||
| 273 | +#if __OHOS__ | ||
| 274 | +template HomophoneReplacer::HomophoneReplacer( | ||
| 275 | + NativeResourceManager *mgr, const HomophoneReplacerConfig &config); | ||
| 276 | +#endif | ||
| 277 | + | ||
| 278 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/homophone-replacer.h (new file, mode 100644)
| 1 | +// sherpa-onnx/csrc/homophone-replacer.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +struct HomophoneReplacerConfig { | ||
| 16 | + std::string dict_dir; | ||
| 17 | + std::string lexicon; | ||
| 18 | + | ||
| 19 | + // comma separated fst files, e.g. a.fst,b.fst,c.fst | ||
| 20 | + std::string rule_fsts; | ||
| 21 | + | ||
|  22  | +  bool debug = false; | ||
| 23 | + | ||
| 24 | + HomophoneReplacerConfig() = default; | ||
| 25 | + | ||
| 26 | + HomophoneReplacerConfig(const std::string &dict_dir, | ||
| 27 | + const std::string &lexicon, | ||
| 28 | + const std::string &rule_fsts, bool debug) | ||
| 29 | + : dict_dir(dict_dir), | ||
| 30 | + lexicon(lexicon), | ||
| 31 | + rule_fsts(rule_fsts), | ||
| 32 | + debug(debug) {} | ||
| 33 | + | ||
| 34 | + void Register(ParseOptions *po); | ||
| 35 | + bool Validate() const; | ||
| 36 | + | ||
| 37 | + std::string ToString() const; | ||
| 38 | +}; | ||
| 39 | + | ||
| 40 | +class HomophoneReplacer { | ||
| 41 | + public: | ||
| 42 | + explicit HomophoneReplacer(const HomophoneReplacerConfig &config); | ||
| 43 | + | ||
| 44 | + template <typename Manager> | ||
| 45 | + HomophoneReplacer(Manager *mgr, const HomophoneReplacerConfig &config); | ||
| 46 | + | ||
| 47 | + ~HomophoneReplacer(); | ||
| 48 | + | ||
| 49 | + std::string Apply(const std::string &text) const; | ||
| 50 | + | ||
| 51 | + private: | ||
| 52 | + class Impl; | ||
| 53 | + std::unique_ptr<Impl> impl_; | ||
| 54 | +}; | ||
| 55 | + | ||
| 56 | +} // namespace sherpa_onnx | ||
| 57 | + | ||
| 58 | +#endif // SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_ |
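The header above is the complete public surface of the new class, so a standalone smoke test is straightforward. The following is an illustrative sketch only, not code from this PR; it assumes the dict/, lexicon.txt and replace.fst files from the hr-files release used in the CI script earlier are in the current directory.

#include <iostream>
#include <string>

#include "sherpa-onnx/csrc/homophone-replacer.h"

int main() {
  // Paths follow the CI script earlier in this diff.
  sherpa_onnx::HomophoneReplacerConfig config(
      /*dict_dir=*/"./dict", /*lexicon=*/"./lexicon.txt",
      /*rule_fsts=*/"./replace.fst", /*debug=*/true);
  if (!config.Validate()) {
    return 1;
  }

  sherpa_onnx::HomophoneReplacer hr(config);

  // Pass in a recognition result; phrases whose pronunciations match a rule
  // in replace.fst come back rewritten, everything else is left untouched.
  std::string fixed = hr.Apply("some recognition result");
  std::cout << fixed << "\n";
  return 0;
}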
| @@ -19,8 +19,8 @@ | @@ -19,8 +19,8 @@ | ||
| 19 | #include "rawfile/raw_file_manager.h" | 19 | #include "rawfile/raw_file_manager.h" |
| 20 | #endif | 20 | #endif |
| 21 | 21 | ||
| 22 | -#include "cppjieba/Jieba.hpp" | ||
| 23 | #include "sherpa-onnx/csrc/file-utils.h" | 22 | #include "sherpa-onnx/csrc/file-utils.h" |
| 23 | +#include "sherpa-onnx/csrc/jieba.h" | ||
| 24 | #include "sherpa-onnx/csrc/macros.h" | 24 | #include "sherpa-onnx/csrc/macros.h" |
| 25 | #include "sherpa-onnx/csrc/onnx-utils.h" | 25 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 26 | #include "sherpa-onnx/csrc/symbol-table.h" | 26 | #include "sherpa-onnx/csrc/symbol-table.h" |
| @@ -41,20 +41,7 @@ class JiebaLexicon::Impl { | @@ -41,20 +41,7 @@ class JiebaLexicon::Impl { | ||
| 41 | Impl(const std::string &lexicon, const std::string &tokens, | 41 | Impl(const std::string &lexicon, const std::string &tokens, |
| 42 | const std::string &dict_dir, bool debug) | 42 | const std::string &dict_dir, bool debug) |
| 43 | : debug_(debug) { | 43 | : debug_(debug) { |
| 44 | - std::string dict = dict_dir + "/jieba.dict.utf8"; | ||
| 45 | - std::string hmm = dict_dir + "/hmm_model.utf8"; | ||
| 46 | - std::string user_dict = dict_dir + "/user.dict.utf8"; | ||
| 47 | - std::string idf = dict_dir + "/idf.utf8"; | ||
| 48 | - std::string stop_word = dict_dir + "/stop_words.utf8"; | ||
| 49 | - | ||
| 50 | - AssertFileExists(dict); | ||
| 51 | - AssertFileExists(hmm); | ||
| 52 | - AssertFileExists(user_dict); | ||
| 53 | - AssertFileExists(idf); | ||
| 54 | - AssertFileExists(stop_word); | ||
| 55 | - | ||
| 56 | - jieba_ = | ||
| 57 | - std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word); | 44 | + jieba_ = InitJieba(dict_dir); |
| 58 | 45 | ||
| 59 | { | 46 | { |
| 60 | std::ifstream is(tokens); | 47 | std::ifstream is(tokens); |
| @@ -71,20 +58,7 @@ class JiebaLexicon::Impl { | @@ -71,20 +58,7 @@ class JiebaLexicon::Impl { | ||
| 71 | Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens, | 58 | Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens, |
| 72 | const std::string &dict_dir, bool debug) | 59 | const std::string &dict_dir, bool debug) |
| 73 | : debug_(debug) { | 60 | : debug_(debug) { |
| 74 | - std::string dict = dict_dir + "/jieba.dict.utf8"; | ||
| 75 | - std::string hmm = dict_dir + "/hmm_model.utf8"; | ||
| 76 | - std::string user_dict = dict_dir + "/user.dict.utf8"; | ||
| 77 | - std::string idf = dict_dir + "/idf.utf8"; | ||
| 78 | - std::string stop_word = dict_dir + "/stop_words.utf8"; | ||
| 79 | - | ||
| 80 | - AssertFileExists(dict); | ||
| 81 | - AssertFileExists(hmm); | ||
| 82 | - AssertFileExists(user_dict); | ||
| 83 | - AssertFileExists(idf); | ||
| 84 | - AssertFileExists(stop_word); | ||
| 85 | - | ||
| 86 | - jieba_ = | ||
| 87 | - std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word); | 61 | + jieba_ = InitJieba(dict_dir); |
| 88 | 62 | ||
| 89 | { | 63 | { |
| 90 | auto buf = ReadFile(mgr, tokens); | 64 | auto buf = ReadFile(mgr, tokens); |
sherpa-onnx/csrc/jieba.cc (new file, mode 100644)
| 1 | +// sherpa-onnx/csrc/jieba.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/jieba.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir) { | ||
| 12 | + if (dict_dir.empty()) { | ||
| 13 | + return {}; | ||
| 14 | + } | ||
| 15 | + | ||
| 16 | + std::string dict = dict_dir + "/jieba.dict.utf8"; | ||
| 17 | + std::string hmm = dict_dir + "/hmm_model.utf8"; | ||
| 18 | + std::string user_dict = dict_dir + "/user.dict.utf8"; | ||
| 19 | + std::string idf = dict_dir + "/idf.utf8"; | ||
| 20 | + std::string stop_word = dict_dir + "/stop_words.utf8"; | ||
| 21 | + | ||
| 22 | + AssertFileExists(dict); | ||
| 23 | + AssertFileExists(hmm); | ||
| 24 | + AssertFileExists(user_dict); | ||
| 25 | + AssertFileExists(idf); | ||
| 26 | + AssertFileExists(stop_word); | ||
| 27 | + | ||
| 28 | + return std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, | ||
| 29 | + stop_word); | ||
| 30 | +} | ||
| 31 | + | ||
| 32 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/jieba.h (new file, mode 100644)
| 1 | +// sherpa-onnx/csrc/jieba.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_JIEBA_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_JIEBA_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +#include "cppjieba/Jieba.hpp" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir); | ||
| 16 | +} | ||
| 17 | + | ||
| 18 | +#endif // SHERPA_ONNX_CSRC_JIEBA_H_ |
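A small usage sketch of this shared helper (illustration only, not from the PR): InitJieba aborts if any of the five dictionary files is missing and returns an empty pointer when dict_dir is empty, which is why the callers in this diff only invoke it with a configured directory.

#include <string>
#include <vector>

#include "sherpa-onnx/csrc/jieba.h"

int main() {
  auto jieba = sherpa_onnx::InitJieba("./dict");  // aborts if a dict file is missing

  std::vector<std::string> words;
  // Same call pattern as HomophoneReplacer::Impl::Apply above.
  jieba->Cut("这是一个测试", words, /*hmm=*/true);
  return 0;
}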
| @@ -22,11 +22,11 @@ | @@ -22,11 +22,11 @@ | ||
| 22 | 22 | ||
| 23 | #include <codecvt> | 23 | #include <codecvt> |
| 24 | 24 | ||
| 25 | -#include "cppjieba/Jieba.hpp" | ||
| 26 | #include "espeak-ng/speak_lib.h" | 25 | #include "espeak-ng/speak_lib.h" |
| 27 | #include "phoneme_ids.hpp" | 26 | #include "phoneme_ids.hpp" |
| 28 | #include "phonemize.hpp" | 27 | #include "phonemize.hpp" |
| 29 | #include "sherpa-onnx/csrc/file-utils.h" | 28 | #include "sherpa-onnx/csrc/file-utils.h" |
| 29 | +#include "sherpa-onnx/csrc/jieba.h" | ||
| 30 | #include "sherpa-onnx/csrc/onnx-utils.h" | 30 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 31 | #include "sherpa-onnx/csrc/symbol-table.h" | 31 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 32 | #include "sherpa-onnx/csrc/text-utils.h" | 32 | #include "sherpa-onnx/csrc/text-utils.h" |
| @@ -47,7 +47,7 @@ class KokoroMultiLangLexicon::Impl { | @@ -47,7 +47,7 @@ class KokoroMultiLangLexicon::Impl { | ||
| 47 | 47 | ||
| 48 | InitLexicon(lexicon); | 48 | InitLexicon(lexicon); |
| 49 | 49 | ||
| 50 | - InitJieba(dict_dir); | 50 | + jieba_ = InitJieba(dict_dir); |
| 51 | 51 | ||
| 52 | InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc | 52 | InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc |
| 53 | } | 53 | } |
| @@ -62,7 +62,7 @@ class KokoroMultiLangLexicon::Impl { | @@ -62,7 +62,7 @@ class KokoroMultiLangLexicon::Impl { | ||
| 62 | InitLexicon(mgr, lexicon); | 62 | InitLexicon(mgr, lexicon); |
| 63 | 63 | ||
| 64 | // we assume you have copied dict_dir and data_dir from assets to some path | 64 | // we assume you have copied dict_dir and data_dir from assets to some path |
| 65 | - InitJieba(dict_dir); | 65 | + jieba_ = InitJieba(dict_dir); |
| 66 | 66 | ||
| 67 | InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc | 67 | InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc |
| 68 | } | 68 | } |
| @@ -456,23 +456,6 @@ class KokoroMultiLangLexicon::Impl { | @@ -456,23 +456,6 @@ class KokoroMultiLangLexicon::Impl { | ||
| 456 | } | 456 | } |
| 457 | } | 457 | } |
| 458 | 458 | ||
| 459 | - void InitJieba(const std::string &dict_dir) { | ||
| 460 | - std::string dict = dict_dir + "/jieba.dict.utf8"; | ||
| 461 | - std::string hmm = dict_dir + "/hmm_model.utf8"; | ||
| 462 | - std::string user_dict = dict_dir + "/user.dict.utf8"; | ||
| 463 | - std::string idf = dict_dir + "/idf.utf8"; | ||
| 464 | - std::string stop_word = dict_dir + "/stop_words.utf8"; | ||
| 465 | - | ||
| 466 | - AssertFileExists(dict); | ||
| 467 | - AssertFileExists(hmm); | ||
| 468 | - AssertFileExists(user_dict); | ||
| 469 | - AssertFileExists(idf); | ||
| 470 | - AssertFileExists(stop_word); | ||
| 471 | - | ||
| 472 | - jieba_ = | ||
| 473 | - std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word); | ||
| 474 | - } | ||
| 475 | - | ||
| 476 | private: | 459 | private: |
| 477 | OfflineTtsKokoroModelMetaData meta_data_; | 460 | OfflineTtsKokoroModelMetaData meta_data_; |
| 478 | 461 |
| @@ -19,8 +19,8 @@ | @@ -19,8 +19,8 @@ | ||
| 19 | #include "rawfile/raw_file_manager.h" | 19 | #include "rawfile/raw_file_manager.h" |
| 20 | #endif | 20 | #endif |
| 21 | 21 | ||
| 22 | -#include "cppjieba/Jieba.hpp" | ||
| 23 | #include "sherpa-onnx/csrc/file-utils.h" | 22 | #include "sherpa-onnx/csrc/file-utils.h" |
| 23 | +#include "sherpa-onnx/csrc/jieba.h" | ||
| 24 | #include "sherpa-onnx/csrc/macros.h" | 24 | #include "sherpa-onnx/csrc/macros.h" |
| 25 | #include "sherpa-onnx/csrc/onnx-utils.h" | 25 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 26 | #include "sherpa-onnx/csrc/symbol-table.h" | 26 | #include "sherpa-onnx/csrc/symbol-table.h" |
| @@ -34,20 +34,7 @@ class MeloTtsLexicon::Impl { | @@ -34,20 +34,7 @@ class MeloTtsLexicon::Impl { | ||
| 34 | const std::string &dict_dir, | 34 | const std::string &dict_dir, |
| 35 | const OfflineTtsVitsModelMetaData &meta_data, bool debug) | 35 | const OfflineTtsVitsModelMetaData &meta_data, bool debug) |
| 36 | : meta_data_(meta_data), debug_(debug) { | 36 | : meta_data_(meta_data), debug_(debug) { |
| 37 | - std::string dict = dict_dir + "/jieba.dict.utf8"; | ||
| 38 | - std::string hmm = dict_dir + "/hmm_model.utf8"; | ||
| 39 | - std::string user_dict = dict_dir + "/user.dict.utf8"; | ||
| 40 | - std::string idf = dict_dir + "/idf.utf8"; | ||
| 41 | - std::string stop_word = dict_dir + "/stop_words.utf8"; | ||
| 42 | - | ||
| 43 | - AssertFileExists(dict); | ||
| 44 | - AssertFileExists(hmm); | ||
| 45 | - AssertFileExists(user_dict); | ||
| 46 | - AssertFileExists(idf); | ||
| 47 | - AssertFileExists(stop_word); | ||
| 48 | - | ||
| 49 | - jieba_ = | ||
| 50 | - std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word); | 37 | + jieba_ = InitJieba(dict_dir); |
| 51 | 38 | ||
| 52 | { | 39 | { |
| 53 | std::ifstream is(tokens); | 40 | std::ifstream is(tokens); |
| @@ -79,20 +66,7 @@ class MeloTtsLexicon::Impl { | @@ -79,20 +66,7 @@ class MeloTtsLexicon::Impl { | ||
| 79 | const std::string &dict_dir, | 66 | const std::string &dict_dir, |
| 80 | const OfflineTtsVitsModelMetaData &meta_data, bool debug) | 67 | const OfflineTtsVitsModelMetaData &meta_data, bool debug) |
| 81 | : meta_data_(meta_data), debug_(debug) { | 68 | : meta_data_(meta_data), debug_(debug) { |
| 82 | - std::string dict = dict_dir + "/jieba.dict.utf8"; | ||
| 83 | - std::string hmm = dict_dir + "/hmm_model.utf8"; | ||
| 84 | - std::string user_dict = dict_dir + "/user.dict.utf8"; | ||
| 85 | - std::string idf = dict_dir + "/idf.utf8"; | ||
| 86 | - std::string stop_word = dict_dir + "/stop_words.utf8"; | ||
| 87 | - | ||
| 88 | - AssertFileExists(dict); | ||
| 89 | - AssertFileExists(hmm); | ||
| 90 | - AssertFileExists(user_dict); | ||
| 91 | - AssertFileExists(idf); | ||
| 92 | - AssertFileExists(stop_word); | ||
| 93 | - | ||
| 94 | - jieba_ = | ||
| 95 | - std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word); | 69 | + jieba_ = InitJieba(dict_dir); |
| 96 | 70 | ||
| 97 | { | 71 | { |
| 98 | auto buf = ReadFile(mgr, tokens); | 72 | auto buf = ReadFile(mgr, tokens); |
| @@ -239,6 +239,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -239,6 +239,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 239 | auto r = Convert(results[i], symbol_table_, frame_shift_ms, | 239 | auto r = Convert(results[i], symbol_table_, frame_shift_ms, |
| 240 | model_->SubsamplingFactor()); | 240 | model_->SubsamplingFactor()); |
| 241 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 241 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 242 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 242 | ss[i]->SetResult(r); | 243 | ss[i]->SetResult(r); |
| 243 | } | 244 | } |
| 244 | } | 245 | } |
| @@ -277,6 +278,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -277,6 +278,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 277 | auto r = Convert(results[0], symbol_table_, frame_shift_ms, | 278 | auto r = Convert(results[0], symbol_table_, frame_shift_ms, |
| 278 | model_->SubsamplingFactor()); | 279 | model_->SubsamplingFactor()); |
| 279 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 280 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 281 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 280 | s->SetResult(r); | 282 | s->SetResult(r); |
| 281 | } | 283 | } |
| 282 | 284 |
| @@ -125,6 +125,7 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl { | @@ -125,6 +125,7 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl { | ||
| 125 | auto r = Convert(results[0], symbol_table_); | 125 | auto r = Convert(results[0], symbol_table_); |
| 126 | 126 | ||
| 127 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 127 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 128 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 128 | s->SetResult(r); | 129 | s->SetResult(r); |
| 129 | } | 130 | } |
| 130 | 131 |
| @@ -408,6 +408,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -408,6 +408,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 408 | OfflineRecognizerImpl::OfflineRecognizerImpl( | 408 | OfflineRecognizerImpl::OfflineRecognizerImpl( |
| 409 | const OfflineRecognizerConfig &config) | 409 | const OfflineRecognizerConfig &config) |
| 410 | : config_(config) { | 410 | : config_(config) { |
| 411 | + // TODO(fangjun): Refactor this function | ||
| 412 | + | ||
| 411 | if (!config.rule_fsts.empty()) { | 413 | if (!config.rule_fsts.empty()) { |
| 412 | std::vector<std::string> files; | 414 | std::vector<std::string> files; |
| 413 | SplitStringToVector(config.rule_fsts, ",", false, &files); | 415 | SplitStringToVector(config.rule_fsts, ",", false, &files); |
| @@ -448,6 +450,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl( | @@ -448,6 +450,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl( | ||
| 448 | SHERPA_ONNX_LOGE("FST archives loaded!"); | 450 | SHERPA_ONNX_LOGE("FST archives loaded!"); |
| 449 | } | 451 | } |
| 450 | } | 452 | } |
| 453 | + | ||
| 454 | + if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() && | ||
| 455 | + !config.hr.rule_fsts.empty()) { | ||
| 456 | + auto hr_config = config.hr; | ||
| 457 | + hr_config.debug = config.model_config.debug; | ||
| 458 | + hr_ = std::make_unique<HomophoneReplacer>(hr_config); | ||
| 459 | + } | ||
| 451 | } | 460 | } |
| 452 | 461 | ||
| 453 | template <typename Manager> | 462 | template <typename Manager> |
| @@ -495,6 +504,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl( | @@ -495,6 +504,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl( | ||
| 495 | } // for (; !reader->Done(); reader->Next()) | 504 | } // for (; !reader->Done(); reader->Next()) |
| 496 | } // for (const auto &f : files) | 505 | } // for (const auto &f : files) |
| 497 | } // if (!config.rule_fars.empty()) | 506 | } // if (!config.rule_fars.empty()) |
| 507 | + | ||
| 508 | + if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() && | ||
| 509 | + !config.hr.rule_fsts.empty()) { | ||
| 510 | + auto hr_config = config.hr; | ||
| 511 | + hr_config.debug = config.model_config.debug; | ||
| 512 | + hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config); | ||
| 513 | + } | ||
| 498 | } | 514 | } |
| 499 | 515 | ||
| 500 | std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( | 516 | std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( |
| @@ -510,6 +526,15 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( | @@ -510,6 +526,15 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( | ||
| 510 | return text; | 526 | return text; |
| 511 | } | 527 | } |
| 512 | 528 | ||
| 529 | +std::string OfflineRecognizerImpl::ApplyHomophoneReplacer( | ||
| 530 | + std::string text) const { | ||
| 531 | + if (hr_) { | ||
| 532 | + text = hr_->Apply(text); | ||
| 533 | + } | ||
| 534 | + | ||
| 535 | + return text; | ||
| 536 | +} | ||
| 537 | + | ||
| 513 | void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) { | 538 | void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) { |
| 514 | config_ = config; | 539 | config_ = config; |
| 515 | } | 540 | } |
| @@ -10,6 +10,7 @@ | @@ -10,6 +10,7 @@ | ||
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | #include "kaldifst/csrc/text-normalizer.h" | 12 | #include "kaldifst/csrc/text-normalizer.h" |
| 13 | +#include "sherpa-onnx/csrc/homophone-replacer.h" | ||
| 13 | #include "sherpa-onnx/csrc/macros.h" | 14 | #include "sherpa-onnx/csrc/macros.h" |
| 14 | #include "sherpa-onnx/csrc/offline-recognizer.h" | 15 | #include "sherpa-onnx/csrc/offline-recognizer.h" |
| 15 | #include "sherpa-onnx/csrc/offline-stream.h" | 16 | #include "sherpa-onnx/csrc/offline-stream.h" |
| @@ -48,12 +49,15 @@ class OfflineRecognizerImpl { | @@ -48,12 +49,15 @@ class OfflineRecognizerImpl { | ||
| 48 | 49 | ||
| 49 | std::string ApplyInverseTextNormalization(std::string text) const; | 50 | std::string ApplyInverseTextNormalization(std::string text) const; |
| 50 | 51 | ||
| 52 | + std::string ApplyHomophoneReplacer(std::string text) const; | ||
| 53 | + | ||
| 51 | private: | 54 | private: |
| 52 | OfflineRecognizerConfig config_; | 55 | OfflineRecognizerConfig config_; |
| 53 | // for inverse text normalization. Used only if | 56 | // for inverse text normalization. Used only if |
| 54 | // config.rule_fsts is not empty or | 57 | // config.rule_fsts is not empty or |
| 55 | // config.rule_fars is not empty | 58 | // config.rule_fars is not empty |
| 56 | std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_; | 59 | std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_; |
| 60 | + std::unique_ptr<HomophoneReplacer> hr_; | ||
| 57 | }; | 61 | }; |
| 58 | 62 | ||
| 59 | } // namespace sherpa_onnx | 63 | } // namespace sherpa_onnx |
| @@ -121,6 +121,7 @@ class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl { | @@ -121,6 +121,7 @@ class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl { | ||
| 121 | 121 | ||
| 122 | auto r = Convert(results[0], symbol_table_); | 122 | auto r = Convert(results[0], symbol_table_); |
| 123 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 123 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 124 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 124 | s->SetResult(r); | 125 | s->SetResult(r); |
| 125 | } catch (const Ort::Exception &ex) { | 126 | } catch (const Ort::Exception &ex) { |
| 126 | SHERPA_ONNX_LOGE( | 127 | SHERPA_ONNX_LOGE( |
| @@ -197,6 +197,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { | @@ -197,6 +197,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { | ||
| 197 | for (int32_t i = 0; i != n; ++i) { | 197 | for (int32_t i = 0; i != n; ++i) { |
| 198 | auto r = Convert(results[i], symbol_table_); | 198 | auto r = Convert(results[i], symbol_table_); |
| 199 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 199 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 200 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 200 | ss[i]->SetResult(r); | 201 | ss[i]->SetResult(r); |
| 201 | } | 202 | } |
| 202 | } | 203 | } |
| @@ -222,6 +222,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { | @@ -222,6 +222,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { | ||
| 222 | auto r = ConvertSenseVoiceResult(results[i], symbol_table_, | 222 | auto r = ConvertSenseVoiceResult(results[i], symbol_table_, |
| 223 | frame_shift_ms, subsampling_factor); | 223 | frame_shift_ms, subsampling_factor); |
| 224 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 224 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 225 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 225 | ss[i]->SetResult(r); | 226 | ss[i]->SetResult(r); |
| 226 | } | 227 | } |
| 227 | } | 228 | } |
| @@ -295,6 +296,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { | @@ -295,6 +296,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { | ||
| 295 | subsampling_factor); | 296 | subsampling_factor); |
| 296 | 297 | ||
| 297 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 298 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 299 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 298 | s->SetResult(r); | 300 | s->SetResult(r); |
| 299 | } | 301 | } |
| 300 | 302 |
| @@ -239,6 +239,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -239,6 +239,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 239 | auto r = Convert(results[i], symbol_table_, frame_shift_ms, | 239 | auto r = Convert(results[i], symbol_table_, frame_shift_ms, |
| 240 | model_->SubsamplingFactor()); | 240 | model_->SubsamplingFactor()); |
| 241 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 241 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 242 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 242 | 243 | ||
| 243 | ss[i]->SetResult(r); | 244 | ss[i]->SetResult(r); |
| 244 | } | 245 | } |
| @@ -128,6 +128,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { | @@ -128,6 +128,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { | ||
| 128 | auto r = Convert(results[i], symbol_table_, frame_shift_ms, | 128 | auto r = Convert(results[i], symbol_table_, frame_shift_ms, |
| 129 | model_->SubsamplingFactor()); | 129 | model_->SubsamplingFactor()); |
| 130 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 130 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 131 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 131 | 132 | ||
| 132 | ss[i]->SetResult(r); | 133 | ss[i]->SetResult(r); |
| 133 | } | 134 | } |
| @@ -160,6 +160,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -160,6 +160,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 160 | 160 | ||
| 161 | std::string s = sym_table[i]; | 161 | std::string s = sym_table[i]; |
| 162 | s = ApplyInverseTextNormalization(s); | 162 | s = ApplyInverseTextNormalization(s); |
| 163 | + s = ApplyHomophoneReplacer(std::move(s)); | ||
| 163 | 164 | ||
| 164 | text += s; | 165 | text += s; |
| 165 | r.tokens.push_back(s); | 166 | r.tokens.push_back(s); |
| @@ -28,6 +28,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | @@ -28,6 +28,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | ||
| 28 | model_config.Register(po); | 28 | model_config.Register(po); |
| 29 | lm_config.Register(po); | 29 | lm_config.Register(po); |
| 30 | ctc_fst_decoder_config.Register(po); | 30 | ctc_fst_decoder_config.Register(po); |
| 31 | + hr.Register(po); | ||
| 31 | 32 | ||
| 32 | po->Register( | 33 | po->Register( |
| 33 | "decoding-method", &decoding_method, | 34 | "decoding-method", &decoding_method, |
| @@ -120,6 +121,11 @@ bool OfflineRecognizerConfig::Validate() const { | @@ -120,6 +121,11 @@ bool OfflineRecognizerConfig::Validate() const { | ||
| 120 | } | 121 | } |
| 121 | } | 122 | } |
| 122 | 123 | ||
| 124 | + if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() && | ||
| 125 | + !hr.Validate()) { | ||
| 126 | + return false; | ||
| 127 | + } | ||
| 128 | + | ||
| 123 | return model_config.Validate(); | 129 | return model_config.Validate(); |
| 124 | } | 130 | } |
| 125 | 131 | ||
| @@ -137,7 +143,8 @@ std::string OfflineRecognizerConfig::ToString() const { | @@ -137,7 +143,8 @@ std::string OfflineRecognizerConfig::ToString() const { | ||
| 137 | os << "hotwords_score=" << hotwords_score << ", "; | 143 | os << "hotwords_score=" << hotwords_score << ", "; |
| 138 | os << "blank_penalty=" << blank_penalty << ", "; | 144 | os << "blank_penalty=" << blank_penalty << ", "; |
| 139 | os << "rule_fsts=\"" << rule_fsts << "\", "; | 145 | os << "rule_fsts=\"" << rule_fsts << "\", "; |
| 140 | - os << "rule_fars=\"" << rule_fars << "\")"; | 146 | + os << "rule_fars=\"" << rule_fars << "\", "; |
| 147 | + os << "hr=" << hr.ToString() << ")"; | ||
| 141 | 148 | ||
| 142 | return os.str(); | 149 | return os.str(); |
| 143 | } | 150 | } |
| @@ -10,6 +10,7 @@ | @@ -10,6 +10,7 @@ | ||
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | #include "sherpa-onnx/csrc/features.h" | 12 | #include "sherpa-onnx/csrc/features.h" |
| 13 | +#include "sherpa-onnx/csrc/homophone-replacer.h" | ||
| 13 | #include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" | 14 | #include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" |
| 14 | #include "sherpa-onnx/csrc/offline-lm-config.h" | 15 | #include "sherpa-onnx/csrc/offline-lm-config.h" |
| 15 | #include "sherpa-onnx/csrc/offline-model-config.h" | 16 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| @@ -40,6 +41,7 @@ struct OfflineRecognizerConfig { | @@ -40,6 +41,7 @@ struct OfflineRecognizerConfig { | ||
| 40 | 41 | ||
| 41 | // If there are multiple FST archives, they are applied from left to right. | 42 | // If there are multiple FST archives, they are applied from left to right. |
| 42 | std::string rule_fars; | 43 | std::string rule_fars; |
| 44 | + HomophoneReplacerConfig hr; | ||
| 43 | 45 | ||
| 44 | // only greedy_search is implemented | 46 | // only greedy_search is implemented |
| 45 | // TODO(fangjun): Implement modified_beam_search | 47 | // TODO(fangjun): Implement modified_beam_search |
| @@ -52,7 +54,7 @@ struct OfflineRecognizerConfig { | @@ -52,7 +54,7 @@ struct OfflineRecognizerConfig { | ||
| 52 | const std::string &decoding_method, int32_t max_active_paths, | 54 | const std::string &decoding_method, int32_t max_active_paths, |
| 53 | const std::string &hotwords_file, float hotwords_score, | 55 | const std::string &hotwords_file, float hotwords_score, |
| 54 | float blank_penalty, const std::string &rule_fsts, | 56 | float blank_penalty, const std::string &rule_fsts, |
| 55 | - const std::string &rule_fars) | 57 | + const std::string &rule_fars, const HomophoneReplacerConfig &hr) |
| 56 | : feat_config(feat_config), | 58 | : feat_config(feat_config), |
| 57 | model_config(model_config), | 59 | model_config(model_config), |
| 58 | lm_config(lm_config), | 60 | lm_config(lm_config), |
| @@ -63,7 +65,8 @@ struct OfflineRecognizerConfig { | @@ -63,7 +65,8 @@ struct OfflineRecognizerConfig { | ||
| 63 | hotwords_score(hotwords_score), | 65 | hotwords_score(hotwords_score), |
| 64 | blank_penalty(blank_penalty), | 66 | blank_penalty(blank_penalty), |
| 65 | rule_fsts(rule_fsts), | 67 | rule_fsts(rule_fsts), |
| 66 | - rule_fars(rule_fars) {} | 68 | + rule_fars(rule_fars), |
| 69 | + hr(hr) {} | ||
| 67 | 70 | ||
| 68 | void Register(ParseOptions *po); | 71 | void Register(ParseOptions *po); |
| 69 | bool Validate() const; | 72 | bool Validate() const; |
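The new hr member is the programmatic counterpart of the --hr-dict-dir/--hr-lexicon/--hr-rule-fsts flags. A hedged sketch of how an API user might fill it in C++ follows; the model/feature configuration is elided, and only the hr fields are taken from this diff.

sherpa_onnx::OfflineRecognizerConfig config;
// ... populate config.model_config / config.feat_config as usual ...
config.hr.dict_dir = "./dict";        // jieba dictionaries
config.hr.lexicon = "./lexicon.txt";  // word -> pronunciation mapping
config.hr.rule_fsts = "./replace.fst";
// As the OfflineRecognizerImpl constructor above shows, the replacer is only
// created when all three fields are non-empty.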
| @@ -201,7 +201,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -201,7 +201,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 201 | auto r = | 201 | auto r = |
| 202 | ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, | 202 | ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, |
| 203 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | 203 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); |
| 204 | - r.text = ApplyInverseTextNormalization(r.text); | 204 | + r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 205 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 205 | return r; | 206 | return r; |
| 206 | } | 207 | } |
| 207 | 208 |
| @@ -192,6 +192,13 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config) | @@ -192,6 +192,13 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config) | ||
| 192 | SHERPA_ONNX_LOGE("FST archives loaded!"); | 192 | SHERPA_ONNX_LOGE("FST archives loaded!"); |
| 193 | } | 193 | } |
| 194 | } | 194 | } |
| 195 | + | ||
| 196 | + if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() && | ||
| 197 | + !config.hr.rule_fsts.empty()) { | ||
| 198 | + auto hr_config = config.hr; | ||
| 199 | + hr_config.debug = config.model_config.debug; | ||
| 200 | + hr_ = std::make_unique<HomophoneReplacer>(hr_config); | ||
| 201 | + } | ||
| 195 | } | 202 | } |
| 196 | 203 | ||
| 197 | template <typename Manager> | 204 | template <typename Manager> |
| @@ -239,6 +246,12 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr, | @@ -239,6 +246,12 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr, | ||
| 239 | } // for (; !reader->Done(); reader->Next()) | 246 | } // for (; !reader->Done(); reader->Next()) |
| 240 | } // for (const auto &f : files) | 247 | } // for (const auto &f : files) |
| 241 | } // if (!config.rule_fars.empty()) | 248 | } // if (!config.rule_fars.empty()) |
| 249 | + if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() && | ||
| 250 | + !config.hr.rule_fsts.empty()) { | ||
| 251 | + auto hr_config = config.hr; | ||
| 252 | + hr_config.debug = config.model_config.debug; | ||
| 253 | + hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config); | ||
| 254 | + } | ||
| 242 | } | 255 | } |
| 243 | 256 | ||
| 244 | std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( | 257 | std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( |
| @@ -254,6 +267,15 @@ std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( | @@ -254,6 +267,15 @@ std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( | ||
| 254 | return text; | 267 | return text; |
| 255 | } | 268 | } |
| 256 | 269 | ||
| 270 | +std::string OnlineRecognizerImpl::ApplyHomophoneReplacer( | ||
| 271 | + std::string text) const { | ||
| 272 | + if (hr_) { | ||
| 273 | + text = hr_->Apply(text); | ||
| 274 | + } | ||
| 275 | + | ||
| 276 | + return text; | ||
| 277 | +} | ||
| 278 | + | ||
| 257 | #if __ANDROID_API__ >= 9 | 279 | #if __ANDROID_API__ >= 9 |
| 258 | template OnlineRecognizerImpl::OnlineRecognizerImpl( | 280 | template OnlineRecognizerImpl::OnlineRecognizerImpl( |
| 259 | AAssetManager *mgr, const OnlineRecognizerConfig &config); | 281 | AAssetManager *mgr, const OnlineRecognizerConfig &config); |
| @@ -10,6 +10,7 @@ | @@ -10,6 +10,7 @@ | ||
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | #include "kaldifst/csrc/text-normalizer.h" | 12 | #include "kaldifst/csrc/text-normalizer.h" |
| 13 | +#include "sherpa-onnx/csrc/homophone-replacer.h" | ||
| 13 | #include "sherpa-onnx/csrc/macros.h" | 14 | #include "sherpa-onnx/csrc/macros.h" |
| 14 | #include "sherpa-onnx/csrc/online-recognizer.h" | 15 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| 15 | #include "sherpa-onnx/csrc/online-stream.h" | 16 | #include "sherpa-onnx/csrc/online-stream.h" |
| @@ -57,6 +58,7 @@ class OnlineRecognizerImpl { | @@ -57,6 +58,7 @@ class OnlineRecognizerImpl { | ||
| 57 | virtual void Reset(OnlineStream *s) const = 0; | 58 | virtual void Reset(OnlineStream *s) const = 0; |
| 58 | 59 | ||
| 59 | std::string ApplyInverseTextNormalization(std::string text) const; | 60 | std::string ApplyInverseTextNormalization(std::string text) const; |
| 61 | + std::string ApplyHomophoneReplacer(std::string text) const; | ||
| 60 | 62 | ||
| 61 | private: | 63 | private: |
| 62 | OnlineRecognizerConfig config_; | 64 | OnlineRecognizerConfig config_; |
| @@ -64,6 +66,7 @@ class OnlineRecognizerImpl { | @@ -64,6 +66,7 @@ class OnlineRecognizerImpl { | ||
| 64 | // config.rule_fsts is not empty or | 66 | // config.rule_fsts is not empty or |
| 65 | // config.rule_fars is not empty | 67 | // config.rule_fars is not empty |
| 66 | std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_; | 68 | std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_; |
| 69 | + std::unique_ptr<HomophoneReplacer> hr_; | ||
| 67 | }; | 70 | }; |
| 68 | 71 | ||
| 69 | } // namespace sherpa_onnx | 72 | } // namespace sherpa_onnx |
| @@ -169,7 +169,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | @@ -169,7 +169,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | ||
| 169 | auto decoder_result = s->GetParaformerResult(); | 169 | auto decoder_result = s->GetParaformerResult(); |
| 170 | 170 | ||
| 171 | auto r = Convert(decoder_result, sym_); | 171 | auto r = Convert(decoder_result, sym_); |
| 172 | - r.text = ApplyInverseTextNormalization(r.text); | 172 | + r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 173 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 173 | return r; | 174 | return r; |
| 174 | } | 175 | } |
| 175 | 176 |
| @@ -349,6 +349,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -349,6 +349,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 349 | auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, | 349 | auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, |
| 350 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | 350 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); |
| 351 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 351 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 352 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 352 | return r; | 353 | return r; |
| 353 | } | 354 | } |
| 354 | 355 | ||
| @@ -391,15 +392,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -391,15 +392,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 391 | // (the encoder state buffers are kept) | 392 | // (the encoder state buffers are kept) |
| 392 | for (const auto &it : last_result.hyps) { | 393 | for (const auto &it : last_result.hyps) { |
| 393 | auto h = it.second; | 394 | auto h = it.second; |
| 394 | - r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, | ||
| 395 | - h.ys.end()), | 395 | + r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, h.ys.end()), |
| 396 | h.log_prob}); | 396 | h.log_prob}); |
| 397 | } | 397 | } |
| 398 | 398 | ||
| 399 | - r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size, | ||
| 400 | - last_result.tokens.end()); | 399 | + r.tokens = std::vector<int64_t>(last_result.tokens.end() - context_size, |
| 400 | + last_result.tokens.end()); | ||
| 401 | } else { | 401 | } else { |
| 402 | - if(config_.reset_encoder) { | 402 | + if (config_.reset_encoder) { |
| 403 | // reset encoder states, use blanks as 'ys' context | 403 | // reset encoder states, use blanks as 'ys' context |
| 404 | s->SetStates(model_->GetEncoderInitStates()); | 404 | s->SetStates(model_->GetEncoderInitStates()); |
| 405 | } | 405 | } |
| @@ -100,6 +100,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | @@ -100,6 +100,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { | ||
| 100 | subsampling_factor, s->GetCurrentSegment(), | 100 | subsampling_factor, s->GetCurrentSegment(), |
| 101 | s->GetNumFramesSinceStart()); | 101 | s->GetNumFramesSinceStart()); |
| 102 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 102 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 103 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 103 | return r; | 104 | return r; |
| 104 | } | 105 | } |
| 105 | 106 |
| @@ -88,6 +88,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -88,6 +88,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 88 | endpoint_config.Register(po); | 88 | endpoint_config.Register(po); |
| 89 | lm_config.Register(po); | 89 | lm_config.Register(po); |
| 90 | ctc_fst_decoder_config.Register(po); | 90 | ctc_fst_decoder_config.Register(po); |
| 91 | + hr.Register(po); | ||
| 91 | 92 | ||
| 92 | po->Register("enable-endpoint", &enable_endpoint, | 93 | po->Register("enable-endpoint", &enable_endpoint, |
| 93 | "True to enable endpoint detection. False to disable it."); | 94 | "True to enable endpoint detection. False to disable it."); |
| @@ -182,6 +183,11 @@ bool OnlineRecognizerConfig::Validate() const { | @@ -182,6 +183,11 @@ bool OnlineRecognizerConfig::Validate() const { | ||
| 182 | } | 183 | } |
| 183 | } | 184 | } |
| 184 | 185 | ||
| 186 | + if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() && | ||
| 187 | + !hr.Validate()) { | ||
| 188 | + return false; | ||
| 189 | + } | ||
| 190 | + | ||
| 185 | return model_config.Validate(); | 191 | return model_config.Validate(); |
| 186 | } | 192 | } |
| 187 | 193 | ||
| @@ -203,7 +209,8 @@ std::string OnlineRecognizerConfig::ToString() const { | @@ -203,7 +209,8 @@ std::string OnlineRecognizerConfig::ToString() const { | ||
| 203 | os << "temperature_scale=" << temperature_scale << ", "; | 209 | os << "temperature_scale=" << temperature_scale << ", "; |
| 204 | os << "rule_fsts=\"" << rule_fsts << "\", "; | 210 | os << "rule_fsts=\"" << rule_fsts << "\", "; |
| 205 | os << "rule_fars=\"" << rule_fars << "\", "; | 211 | os << "rule_fars=\"" << rule_fars << "\", "; |
| 206 | - os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")"; | 212 | + os << "reset_encoder=" << (reset_encoder ? "True" : "False") << ", "; |
| 213 | + os << "hr=" << hr.ToString() << ")"; | ||
| 207 | 214 | ||
| 208 | return os.str(); | 215 | return os.str(); |
| 209 | } | 216 | } |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | 11 | ||
| 12 | #include "sherpa-onnx/csrc/endpoint.h" | 12 | #include "sherpa-onnx/csrc/endpoint.h" |
| 13 | #include "sherpa-onnx/csrc/features.h" | 13 | #include "sherpa-onnx/csrc/features.h" |
| 14 | +#include "sherpa-onnx/csrc/homophone-replacer.h" | ||
| 14 | #include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" | 15 | #include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" |
| 15 | #include "sherpa-onnx/csrc/online-lm-config.h" | 16 | #include "sherpa-onnx/csrc/online-lm-config.h" |
| 16 | #include "sherpa-onnx/csrc/online-model-config.h" | 17 | #include "sherpa-onnx/csrc/online-model-config.h" |
| @@ -107,6 +108,8 @@ struct OnlineRecognizerConfig { | @@ -107,6 +108,8 @@ struct OnlineRecognizerConfig { | ||
| 107 | // currently only in `OnlineRecognizerTransducerImpl`. | 108 | // currently only in `OnlineRecognizerTransducerImpl`. |
| 108 | bool reset_encoder = false; | 109 | bool reset_encoder = false; |
| 109 | 110 | ||
| 111 | + HomophoneReplacerConfig hr; | ||
| 112 | + | ||
| 110 | /// used only for modified_beam_search, if hotwords_buf is non-empty, | 113 | /// used only for modified_beam_search, if hotwords_buf is non-empty, |
| 111 | /// the hotwords will be loaded from the buffered string instead of from the | 114 | /// the hotwords will be loaded from the buffered string instead of from the |
| 112 | /// "hotwords_file" | 115 | /// "hotwords_file" |
| @@ -123,7 +126,7 @@ struct OnlineRecognizerConfig { | @@ -123,7 +126,7 @@ struct OnlineRecognizerConfig { | ||
| 123 | int32_t max_active_paths, const std::string &hotwords_file, | 126 | int32_t max_active_paths, const std::string &hotwords_file, |
| 124 | float hotwords_score, float blank_penalty, float temperature_scale, | 127 | float hotwords_score, float blank_penalty, float temperature_scale, |
| 125 | const std::string &rule_fsts, const std::string &rule_fars, | 128 | const std::string &rule_fsts, const std::string &rule_fars, |
| 126 | - bool reset_encoder) | 129 | + bool reset_encoder, const HomophoneReplacerConfig &hr) |
| 127 | : feat_config(feat_config), | 130 | : feat_config(feat_config), |
| 128 | model_config(model_config), | 131 | model_config(model_config), |
| 129 | lm_config(lm_config), | 132 | lm_config(lm_config), |
| @@ -138,7 +141,8 @@ struct OnlineRecognizerConfig { | @@ -138,7 +141,8 @@ struct OnlineRecognizerConfig { | ||
| 138 | temperature_scale(temperature_scale), | 141 | temperature_scale(temperature_scale), |
| 139 | rule_fsts(rule_fsts), | 142 | rule_fsts(rule_fsts), |
| 140 | rule_fars(rule_fars), | 143 | rule_fars(rule_fars), |
| 141 | - reset_encoder(reset_encoder) {} | 144 | + reset_encoder(reset_encoder), |
| 145 | + hr(hr) {} | ||
| 142 | 146 | ||
| 143 | void Register(ParseOptions *po); | 147 | void Register(ParseOptions *po); |
| 144 | bool Validate() const; | 148 | bool Validate() const; |
| @@ -89,7 +89,8 @@ class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl { | @@ -89,7 +89,8 @@ class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl { | ||
| 89 | auto r = | 89 | auto r = |
| 90 | ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, | 90 | ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, |
| 91 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | 91 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); |
| 92 | - r.text = ApplyInverseTextNormalization(r.text); | 92 | + r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 93 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 93 | return r; | 94 | return r; |
| 94 | } | 95 | } |
| 95 | 96 |
| @@ -177,6 +177,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { | @@ -177,6 +177,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { | ||
| 177 | auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, | 177 | auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, |
| 178 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | 178 | s->GetCurrentSegment(), s->GetNumFramesSinceStart()); |
| 179 | r.text = ApplyInverseTextNormalization(std::move(r.text)); | 179 | r.text = ApplyInverseTextNormalization(std::move(r.text)); |
| 180 | + r.text = ApplyHomophoneReplacer(std::move(r.text)); | ||
| 180 | return r; | 181 | return r; |
| 181 | } | 182 | } |
| 182 | 183 |
| @@ -7,6 +7,7 @@ set(srcs | @@ -7,6 +7,7 @@ set(srcs | ||
| 7 | display.cc | 7 | display.cc |
| 8 | endpoint.cc | 8 | endpoint.cc |
| 9 | features.cc | 9 | features.cc |
| 10 | + homophone-replacer.cc | ||
| 10 | keyword-spotter.cc | 11 | keyword-spotter.cc |
| 11 | offline-ctc-fst-decoder-config.cc | 12 | offline-ctc-fst-decoder-config.cc |
| 12 | offline-dolphin-model-config.cc | 13 | offline-dolphin-model-config.cc |
| 1 | +// sherpa-onnx/python/csrc/homophone-replacer.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/homophone-replacer.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/homophone-replacer.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void PybindHomophoneReplacer(py::module *m) { | ||
| 14 | + using PyClass = HomophoneReplacerConfig; | ||
| 15 | + py::class_<PyClass>(*m, "HomophoneReplacerConfig") | ||
| 16 | + .def(py::init<>()) | ||
| 17 | + .def(py::init<const std::string &, const std::string &, | ||
| 18 | + const std::string &, bool>(), | ||
| 19 | + py::arg("dict_dir"), py::arg("lexicon"), py::arg("rule_fsts"), | ||
| 20 | + py::arg("debug") = false) | ||
| 21 | + .def_readwrite("dict_dir", &PyClass::dict_dir) | ||
| 22 | + .def_readwrite("lexicon", &PyClass::lexicon) | ||
| 23 | + .def_readwrite("rule_fsts", &PyClass::rule_fsts) | ||
| 24 | + .def_readwrite("debug", &PyClass::debug) | ||
| 25 | + .def("__str__", &PyClass::ToString); | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +} // namespace sherpa_onnx |
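The new binding above exposes HomophoneReplacerConfig to Python with a keyword constructor, read/write fields, and __str__ mapped to ToString(). A minimal sketch of driving the binding directly; the file paths are placeholders, and it assumes the class is re-exported at the top level of the sherpa_onnx package (otherwise import it from the internal _sherpa_onnx extension):

    from sherpa_onnx import HomophoneReplacerConfig

    hr = HomophoneReplacerConfig(
        dict_dir="./dict",          # directory with the unpacked dictionary files
        lexicon="./lexicon.txt",    # lexicon file used by the replacer
        rule_fsts="./replace.fst",  # replacement rule FST(s)
        debug=False,                # optional; defaults to False
    )
    hr.debug = True  # fields are writable thanks to def_readwrite
    print(hr)        # __str__ calls HomophoneReplacerConfig::ToString()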
sherpa-onnx/python/csrc/homophone-replacer.h
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/homophone-replacer.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindHomophoneReplacer(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_ |
| @@ -17,14 +17,16 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | @@ -17,14 +17,16 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | ||
| 17 | .def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &, | 17 | .def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &, |
| 18 | const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &, | 18 | const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &, |
| 19 | const std::string &, int32_t, const std::string &, float, | 19 | const std::string &, int32_t, const std::string &, float, |
| 20 | - float, const std::string &, const std::string &>(), | 20 | + float, const std::string &, const std::string &, |
| 21 | + const HomophoneReplacerConfig &>(), | ||
| 21 | py::arg("feat_config"), py::arg("model_config"), | 22 | py::arg("feat_config"), py::arg("model_config"), |
| 22 | py::arg("lm_config") = OfflineLMConfig(), | 23 | py::arg("lm_config") = OfflineLMConfig(), |
| 23 | py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), | 24 | py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), |
| 24 | py::arg("decoding_method") = "greedy_search", | 25 | py::arg("decoding_method") = "greedy_search", |
| 25 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | 26 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 26 | py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0, | 27 | py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0, |
| 27 | - py::arg("rule_fsts") = "", py::arg("rule_fars") = "") | 28 | + py::arg("rule_fsts") = "", py::arg("rule_fars") = "", |
| 29 | + py::arg("hr") = HomophoneReplacerConfig{}) | ||
| 28 | .def_readwrite("feat_config", &PyClass::feat_config) | 30 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 29 | .def_readwrite("model_config", &PyClass::model_config) | 31 | .def_readwrite("model_config", &PyClass::model_config) |
| 30 | .def_readwrite("lm_config", &PyClass::lm_config) | 32 | .def_readwrite("lm_config", &PyClass::lm_config) |
| @@ -36,6 +38,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | @@ -36,6 +38,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | ||
| 36 | .def_readwrite("blank_penalty", &PyClass::blank_penalty) | 38 | .def_readwrite("blank_penalty", &PyClass::blank_penalty) |
| 37 | .def_readwrite("rule_fsts", &PyClass::rule_fsts) | 39 | .def_readwrite("rule_fsts", &PyClass::rule_fsts) |
| 38 | .def_readwrite("rule_fars", &PyClass::rule_fars) | 40 | .def_readwrite("rule_fars", &PyClass::rule_fars) |
| 41 | + .def_readwrite("hr", &PyClass::hr) | ||
| 39 | .def("__str__", &PyClass::ToString); | 42 | .def("__str__", &PyClass::ToString); |
| 40 | } | 43 | } |
| 41 | 44 |
| @@ -58,7 +58,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -58,7 +58,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 58 | const OnlineLMConfig &, const EndpointConfig &, | 58 | const OnlineLMConfig &, const EndpointConfig &, |
| 59 | const OnlineCtcFstDecoderConfig &, bool, | 59 | const OnlineCtcFstDecoderConfig &, bool, |
| 60 | const std::string &, int32_t, const std::string &, float, | 60 | const std::string &, int32_t, const std::string &, float, |
| 61 | - float, float, const std::string &, const std::string &, bool>(), | 61 | + float, float, const std::string &, const std::string &, |
| 62 | + bool, const HomophoneReplacerConfig &>(), | ||
| 62 | py::arg("feat_config"), py::arg("model_config"), | 63 | py::arg("feat_config"), py::arg("model_config"), |
| 63 | py::arg("lm_config") = OnlineLMConfig(), | 64 | py::arg("lm_config") = OnlineLMConfig(), |
| 64 | py::arg("endpoint_config") = EndpointConfig(), | 65 | py::arg("endpoint_config") = EndpointConfig(), |
| @@ -67,7 +68,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -67,7 +68,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 67 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | 68 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 68 | py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, | 69 | py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, |
| 69 | py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", | 70 | py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", |
| 70 | - py::arg("rule_fars") = "", py::arg("reset_encoder") = false) | 71 | + py::arg("rule_fars") = "", py::arg("reset_encoder") = false, |
| 72 | + py::arg("hr") = HomophoneReplacerConfig{}) | ||
| 71 | .def_readwrite("feat_config", &PyClass::feat_config) | 73 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 72 | .def_readwrite("model_config", &PyClass::model_config) | 74 | .def_readwrite("model_config", &PyClass::model_config) |
| 73 | .def_readwrite("lm_config", &PyClass::lm_config) | 75 | .def_readwrite("lm_config", &PyClass::lm_config) |
| @@ -83,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -83,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 83 | .def_readwrite("rule_fsts", &PyClass::rule_fsts) | 85 | .def_readwrite("rule_fsts", &PyClass::rule_fsts) |
| 84 | .def_readwrite("rule_fars", &PyClass::rule_fars) | 86 | .def_readwrite("rule_fars", &PyClass::rule_fars) |
| 85 | .def_readwrite("reset_encoder", &PyClass::reset_encoder) | 87 | .def_readwrite("reset_encoder", &PyClass::reset_encoder) |
| 88 | + .def_readwrite("hr", &PyClass::hr) | ||
| 86 | .def("__str__", &PyClass::ToString); | 89 | .def("__str__", &PyClass::ToString); |
| 87 | } | 90 | } |
| 88 | 91 |
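Both recognizer config constructors now take an hr argument that defaults to an empty HomophoneReplacerConfig, and the field is exposed for reading and writing. A small sketch for inspecting it, assuming `recognizer` was already built with one of the Python factory methods shown further below:

    # The Python wrappers keep the config they were built from as `.config`,
    # so the new `hr` field can be inspected after construction.
    cfg = recognizer.config
    print(cfg.hr)  # HomophoneReplacerConfig with dict_dir/lexicon/rule_fsts/debug
    print(cfg)     # for the online config, ToString() now ends with "hr=..."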
| @@ -10,6 +10,7 @@ | @@ -10,6 +10,7 @@ | ||
| 10 | #include "sherpa-onnx/python/csrc/display.h" | 10 | #include "sherpa-onnx/python/csrc/display.h" |
| 11 | #include "sherpa-onnx/python/csrc/endpoint.h" | 11 | #include "sherpa-onnx/python/csrc/endpoint.h" |
| 12 | #include "sherpa-onnx/python/csrc/features.h" | 12 | #include "sherpa-onnx/python/csrc/features.h" |
| 13 | +#include "sherpa-onnx/python/csrc/homophone-replacer.h" | ||
| 13 | #include "sherpa-onnx/python/csrc/keyword-spotter.h" | 14 | #include "sherpa-onnx/python/csrc/keyword-spotter.h" |
| 14 | #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" | 15 | #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" |
| 15 | #include "sherpa-onnx/python/csrc/offline-lm-config.h" | 16 | #include "sherpa-onnx/python/csrc/offline-lm-config.h" |
| @@ -51,6 +52,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -51,6 +52,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 51 | PybindAudioTagging(&m); | 52 | PybindAudioTagging(&m); |
| 52 | PybindOfflinePunctuation(&m); | 53 | PybindOfflinePunctuation(&m); |
| 53 | PybindOnlinePunctuation(&m); | 54 | PybindOnlinePunctuation(&m); |
| 55 | + PybindHomophoneReplacer(&m); | ||
| 54 | 56 | ||
| 55 | PybindFeatures(&m); | 57 | PybindFeatures(&m); |
| 56 | PybindOnlineCtcFstDecoderConfig(&m); | 58 | PybindOnlineCtcFstDecoderConfig(&m); |
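Once PybindHomophoneReplacer runs during module initialization, the class becomes importable from the compiled extension. A quick smoke check, assuming the extension module is named _sherpa_onnx as in the imports below:

    import _sherpa_onnx

    # Default-constructed config; printing it goes through the bound ToString().
    print(_sherpa_onnx.HomophoneReplacerConfig())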
| @@ -5,6 +5,7 @@ from typing import List, Optional | @@ -5,6 +5,7 @@ from typing import List, Optional | ||
| 5 | 5 | ||
| 6 | from _sherpa_onnx import ( | 6 | from _sherpa_onnx import ( |
| 7 | FeatureExtractorConfig, | 7 | FeatureExtractorConfig, |
| 8 | + HomophoneReplacerConfig, | ||
| 8 | OfflineCtcFstDecoderConfig, | 9 | OfflineCtcFstDecoderConfig, |
| 9 | OfflineDolphinModelConfig, | 10 | OfflineDolphinModelConfig, |
| 10 | OfflineFireRedAsrModelConfig, | 11 | OfflineFireRedAsrModelConfig, |
| @@ -64,6 +65,9 @@ class OfflineRecognizer(object): | @@ -64,6 +65,9 @@ class OfflineRecognizer(object): | ||
| 64 | rule_fars: str = "", | 65 | rule_fars: str = "", |
| 65 | lm: str = "", | 66 | lm: str = "", |
| 66 | lm_scale: float = 0.1, | 67 | lm_scale: float = 0.1, |
| 68 | + hr_dict_dir: str = "", | ||
| 69 | + hr_rule_fsts: str = "", | ||
| 70 | + hr_lexicon: str = "", | ||
| 67 | ): | 71 | ): |
| 68 | """ | 72 | """ |
| 69 | Please refer to | 73 | Please refer to |
| @@ -181,6 +185,11 @@ class OfflineRecognizer(object): | @@ -181,6 +185,11 @@ class OfflineRecognizer(object): | ||
| 181 | blank_penalty=blank_penalty, | 185 | blank_penalty=blank_penalty, |
| 182 | rule_fsts=rule_fsts, | 186 | rule_fsts=rule_fsts, |
| 183 | rule_fars=rule_fars, | 187 | rule_fars=rule_fars, |
| 188 | + hr=HomophoneReplacerConfig( | ||
| 189 | + dict_dir=hr_dict_dir, | ||
| 190 | + lexicon=hr_lexicon, | ||
| 191 | + rule_fsts=hr_rule_fsts, | ||
| 192 | + ), | ||
| 184 | ) | 193 | ) |
| 185 | self.recognizer = _Recognizer(recognizer_config) | 194 | self.recognizer = _Recognizer(recognizer_config) |
| 186 | self.config = recognizer_config | 195 | self.config = recognizer_config |
| @@ -201,6 +210,9 @@ class OfflineRecognizer(object): | @@ -201,6 +210,9 @@ class OfflineRecognizer(object): | ||
| 201 | use_itn: bool = False, | 210 | use_itn: bool = False, |
| 202 | rule_fsts: str = "", | 211 | rule_fsts: str = "", |
| 203 | rule_fars: str = "", | 212 | rule_fars: str = "", |
| 213 | + hr_dict_dir: str = "", | ||
| 214 | + hr_rule_fsts: str = "", | ||
| 215 | + hr_lexicon: str = "", | ||
| 204 | ): | 216 | ): |
| 205 | """ | 217 | """ |
| 206 | Please refer to | 218 | Please refer to |
| @@ -263,6 +275,11 @@ class OfflineRecognizer(object): | @@ -263,6 +275,11 @@ class OfflineRecognizer(object): | ||
| 263 | decoding_method=decoding_method, | 275 | decoding_method=decoding_method, |
| 264 | rule_fsts=rule_fsts, | 276 | rule_fsts=rule_fsts, |
| 265 | rule_fars=rule_fars, | 277 | rule_fars=rule_fars, |
| 278 | + hr=HomophoneReplacerConfig( | ||
| 279 | + dict_dir=hr_dict_dir, | ||
| 280 | + lexicon=hr_lexicon, | ||
| 281 | + rule_fsts=hr_rule_fsts, | ||
| 282 | + ), | ||
| 266 | ) | 283 | ) |
| 267 | self.recognizer = _Recognizer(recognizer_config) | 284 | self.recognizer = _Recognizer(recognizer_config) |
| 268 | self.config = recognizer_config | 285 | self.config = recognizer_config |
| @@ -281,6 +298,9 @@ class OfflineRecognizer(object): | @@ -281,6 +298,9 @@ class OfflineRecognizer(object): | ||
| 281 | provider: str = "cpu", | 298 | provider: str = "cpu", |
| 282 | rule_fsts: str = "", | 299 | rule_fsts: str = "", |
| 283 | rule_fars: str = "", | 300 | rule_fars: str = "", |
| 301 | + hr_dict_dir: str = "", | ||
| 302 | + hr_rule_fsts: str = "", | ||
| 303 | + hr_lexicon: str = "", | ||
| 284 | ): | 304 | ): |
| 285 | """ | 305 | """ |
| 286 | Please refer to | 306 | Please refer to |
| @@ -336,6 +356,11 @@ class OfflineRecognizer(object): | @@ -336,6 +356,11 @@ class OfflineRecognizer(object): | ||
| 336 | decoding_method=decoding_method, | 356 | decoding_method=decoding_method, |
| 337 | rule_fsts=rule_fsts, | 357 | rule_fsts=rule_fsts, |
| 338 | rule_fars=rule_fars, | 358 | rule_fars=rule_fars, |
| 359 | + hr=HomophoneReplacerConfig( | ||
| 360 | + dict_dir=hr_dict_dir, | ||
| 361 | + lexicon=hr_lexicon, | ||
| 362 | + rule_fsts=hr_rule_fsts, | ||
| 363 | + ), | ||
| 339 | ) | 364 | ) |
| 340 | self.recognizer = _Recognizer(recognizer_config) | 365 | self.recognizer = _Recognizer(recognizer_config) |
| 341 | self.config = recognizer_config | 366 | self.config = recognizer_config |
| @@ -354,6 +379,9 @@ class OfflineRecognizer(object): | @@ -354,6 +379,9 @@ class OfflineRecognizer(object): | ||
| 354 | provider: str = "cpu", | 379 | provider: str = "cpu", |
| 355 | rule_fsts: str = "", | 380 | rule_fsts: str = "", |
| 356 | rule_fars: str = "", | 381 | rule_fars: str = "", |
| 382 | + hr_dict_dir: str = "", | ||
| 383 | + hr_rule_fsts: str = "", | ||
| 384 | + hr_lexicon: str = "", | ||
| 357 | ): | 385 | ): |
| 358 | """ | 386 | """ |
| 359 | Please refer to | 387 | Please refer to |
| @@ -411,6 +439,9 @@ class OfflineRecognizer(object): | @@ -411,6 +439,9 @@ class OfflineRecognizer(object): | ||
| 411 | decoding_method=decoding_method, | 439 | decoding_method=decoding_method, |
| 412 | rule_fsts=rule_fsts, | 440 | rule_fsts=rule_fsts, |
| 413 | rule_fars=rule_fars, | 441 | rule_fars=rule_fars, |
| 442 | + hr=HomophoneReplacerConfig( | ||
| 443 | + dict_dir=hr_dict_dir, lexicon=hr_lexicon, rule_fsts=hr_rule_fsts | ||
| 444 | + ), | ||
| 414 | ) | 445 | ) |
| 415 | self.recognizer = _Recognizer(recognizer_config) | 446 | self.recognizer = _Recognizer(recognizer_config) |
| 416 | self.config = recognizer_config | 447 | self.config = recognizer_config |
| @@ -429,6 +460,9 @@ class OfflineRecognizer(object): | @@ -429,6 +460,9 @@ class OfflineRecognizer(object): | ||
| 429 | provider: str = "cpu", | 460 | provider: str = "cpu", |
| 430 | rule_fsts: str = "", | 461 | rule_fsts: str = "", |
| 431 | rule_fars: str = "", | 462 | rule_fars: str = "", |
| 463 | + hr_dict_dir: str = "", | ||
| 464 | + hr_rule_fsts: str = "", | ||
| 465 | + hr_lexicon: str = "", | ||
| 432 | ): | 466 | ): |
| 433 | """ | 467 | """ |
| 434 | Please refer to | 468 | Please refer to |
| @@ -483,6 +517,11 @@ class OfflineRecognizer(object): | @@ -483,6 +517,11 @@ class OfflineRecognizer(object): | ||
| 483 | decoding_method=decoding_method, | 517 | decoding_method=decoding_method, |
| 484 | rule_fsts=rule_fsts, | 518 | rule_fsts=rule_fsts, |
| 485 | rule_fars=rule_fars, | 519 | rule_fars=rule_fars, |
| 520 | + hr=HomophoneReplacerConfig( | ||
| 521 | + dict_dir=hr_dict_dir, | ||
| 522 | + lexicon=hr_lexicon, | ||
| 523 | + rule_fsts=hr_rule_fsts, | ||
| 524 | + ), | ||
| 486 | ) | 525 | ) |
| 487 | self.recognizer = _Recognizer(recognizer_config) | 526 | self.recognizer = _Recognizer(recognizer_config) |
| 488 | self.config = recognizer_config | 527 | self.config = recognizer_config |
| @@ -501,6 +540,9 @@ class OfflineRecognizer(object): | @@ -501,6 +540,9 @@ class OfflineRecognizer(object): | ||
| 501 | provider: str = "cpu", | 540 | provider: str = "cpu", |
| 502 | rule_fsts: str = "", | 541 | rule_fsts: str = "", |
| 503 | rule_fars: str = "", | 542 | rule_fars: str = "", |
| 543 | + hr_dict_dir: str = "", | ||
| 544 | + hr_rule_fsts: str = "", | ||
| 545 | + hr_lexicon: str = "", | ||
| 504 | ): | 546 | ): |
| 505 | """ | 547 | """ |
| 506 | Please refer to | 548 | Please refer to |
| @@ -557,6 +599,11 @@ class OfflineRecognizer(object): | @@ -557,6 +599,11 @@ class OfflineRecognizer(object): | ||
| 557 | decoding_method=decoding_method, | 599 | decoding_method=decoding_method, |
| 558 | rule_fsts=rule_fsts, | 600 | rule_fsts=rule_fsts, |
| 559 | rule_fars=rule_fars, | 601 | rule_fars=rule_fars, |
| 602 | + hr=HomophoneReplacerConfig( | ||
| 603 | + dict_dir=hr_dict_dir, | ||
| 604 | + lexicon=hr_lexicon, | ||
| 605 | + rule_fsts=hr_rule_fsts, | ||
| 606 | + ), | ||
| 560 | ) | 607 | ) |
| 561 | self.recognizer = _Recognizer(recognizer_config) | 608 | self.recognizer = _Recognizer(recognizer_config) |
| 562 | self.config = recognizer_config | 609 | self.config = recognizer_config |
| @@ -577,6 +624,9 @@ class OfflineRecognizer(object): | @@ -577,6 +624,9 @@ class OfflineRecognizer(object): | ||
| 577 | tail_paddings: int = -1, | 624 | tail_paddings: int = -1, |
| 578 | rule_fsts: str = "", | 625 | rule_fsts: str = "", |
| 579 | rule_fars: str = "", | 626 | rule_fars: str = "", |
| 627 | + hr_dict_dir: str = "", | ||
| 628 | + hr_rule_fsts: str = "", | ||
| 629 | + hr_lexicon: str = "", | ||
| 580 | ): | 630 | ): |
| 581 | """ | 631 | """ |
| 582 | Please refer to | 632 | Please refer to |
| @@ -647,6 +697,11 @@ class OfflineRecognizer(object): | @@ -647,6 +697,11 @@ class OfflineRecognizer(object): | ||
| 647 | decoding_method=decoding_method, | 697 | decoding_method=decoding_method, |
| 648 | rule_fsts=rule_fsts, | 698 | rule_fsts=rule_fsts, |
| 649 | rule_fars=rule_fars, | 699 | rule_fars=rule_fars, |
| 700 | + hr=HomophoneReplacerConfig( | ||
| 701 | + dict_dir=hr_dict_dir, | ||
| 702 | + lexicon=hr_lexicon, | ||
| 703 | + rule_fsts=hr_rule_fsts, | ||
| 704 | + ), | ||
| 650 | ) | 705 | ) |
| 651 | self.recognizer = _Recognizer(recognizer_config) | 706 | self.recognizer = _Recognizer(recognizer_config) |
| 652 | self.config = recognizer_config | 707 | self.config = recognizer_config |
| @@ -664,6 +719,9 @@ class OfflineRecognizer(object): | @@ -664,6 +719,9 @@ class OfflineRecognizer(object): | ||
| 664 | provider: str = "cpu", | 719 | provider: str = "cpu", |
| 665 | rule_fsts: str = "", | 720 | rule_fsts: str = "", |
| 666 | rule_fars: str = "", | 721 | rule_fars: str = "", |
| 722 | + hr_dict_dir: str = "", | ||
| 723 | + hr_rule_fsts: str = "", | ||
| 724 | + hr_lexicon: str = "", | ||
| 667 | ): | 725 | ): |
| 668 | """ | 726 | """ |
| 669 | Please refer to | 727 | Please refer to |
| @@ -719,6 +777,11 @@ class OfflineRecognizer(object): | @@ -719,6 +777,11 @@ class OfflineRecognizer(object): | ||
| 719 | decoding_method=decoding_method, | 777 | decoding_method=decoding_method, |
| 720 | rule_fsts=rule_fsts, | 778 | rule_fsts=rule_fsts, |
| 721 | rule_fars=rule_fars, | 779 | rule_fars=rule_fars, |
| 780 | + hr=HomophoneReplacerConfig( | ||
| 781 | + dict_dir=hr_dict_dir, | ||
| 782 | + lexicon=hr_lexicon, | ||
| 783 | + rule_fsts=hr_rule_fsts, | ||
| 784 | + ), | ||
| 722 | ) | 785 | ) |
| 723 | self.recognizer = _Recognizer(recognizer_config) | 786 | self.recognizer = _Recognizer(recognizer_config) |
| 724 | self.config = recognizer_config | 787 | self.config = recognizer_config |
| @@ -738,6 +801,9 @@ class OfflineRecognizer(object): | @@ -738,6 +801,9 @@ class OfflineRecognizer(object): | ||
| 738 | provider: str = "cpu", | 801 | provider: str = "cpu", |
| 739 | rule_fsts: str = "", | 802 | rule_fsts: str = "", |
| 740 | rule_fars: str = "", | 803 | rule_fars: str = "", |
| 804 | + hr_dict_dir: str = "", | ||
| 805 | + hr_rule_fsts: str = "", | ||
| 806 | + hr_lexicon: str = "", | ||
| 741 | ): | 807 | ): |
| 742 | """ | 808 | """ |
| 743 | Please refer to | 809 | Please refer to |
| @@ -800,6 +866,11 @@ class OfflineRecognizer(object): | @@ -800,6 +866,11 @@ class OfflineRecognizer(object): | ||
| 800 | decoding_method=decoding_method, | 866 | decoding_method=decoding_method, |
| 801 | rule_fsts=rule_fsts, | 867 | rule_fsts=rule_fsts, |
| 802 | rule_fars=rule_fars, | 868 | rule_fars=rule_fars, |
| 869 | + hr=HomophoneReplacerConfig( | ||
| 870 | + dict_dir=hr_dict_dir, | ||
| 871 | + lexicon=hr_lexicon, | ||
| 872 | + rule_fsts=hr_rule_fsts, | ||
| 873 | + ), | ||
| 803 | ) | 874 | ) |
| 804 | self.recognizer = _Recognizer(recognizer_config) | 875 | self.recognizer = _Recognizer(recognizer_config) |
| 805 | self.config = recognizer_config | 876 | self.config = recognizer_config |
| @@ -818,6 +889,9 @@ class OfflineRecognizer(object): | @@ -818,6 +889,9 @@ class OfflineRecognizer(object): | ||
| 818 | provider: str = "cpu", | 889 | provider: str = "cpu", |
| 819 | rule_fsts: str = "", | 890 | rule_fsts: str = "", |
| 820 | rule_fars: str = "", | 891 | rule_fars: str = "", |
| 892 | + hr_dict_dir: str = "", | ||
| 893 | + hr_rule_fsts: str = "", | ||
| 894 | + hr_lexicon: str = "", | ||
| 821 | ): | 895 | ): |
| 822 | """ | 896 | """ |
| 823 | Please refer to | 897 | Please refer to |
| @@ -873,6 +947,11 @@ class OfflineRecognizer(object): | @@ -873,6 +947,11 @@ class OfflineRecognizer(object): | ||
| 873 | decoding_method=decoding_method, | 947 | decoding_method=decoding_method, |
| 874 | rule_fsts=rule_fsts, | 948 | rule_fsts=rule_fsts, |
| 875 | rule_fars=rule_fars, | 949 | rule_fars=rule_fars, |
| 950 | + hr=HomophoneReplacerConfig( | ||
| 951 | + dict_dir=hr_dict_dir, | ||
| 952 | + lexicon=hr_lexicon, | ||
| 953 | + rule_fsts=hr_rule_fsts, | ||
| 954 | + ), | ||
| 876 | ) | 955 | ) |
| 877 | self.recognizer = _Recognizer(recognizer_config) | 956 | self.recognizer = _Recognizer(recognizer_config) |
| 878 | self.config = recognizer_config | 957 | self.config = recognizer_config |
| @@ -891,6 +970,9 @@ class OfflineRecognizer(object): | @@ -891,6 +970,9 @@ class OfflineRecognizer(object): | ||
| 891 | provider: str = "cpu", | 970 | provider: str = "cpu", |
| 892 | rule_fsts: str = "", | 971 | rule_fsts: str = "", |
| 893 | rule_fars: str = "", | 972 | rule_fars: str = "", |
| 973 | + hr_dict_dir: str = "", | ||
| 974 | + hr_rule_fsts: str = "", | ||
| 975 | + hr_lexicon: str = "", | ||
| 894 | ): | 976 | ): |
| 895 | """ | 977 | """ |
| 896 | Please refer to | 978 | Please refer to |
| @@ -947,6 +1029,11 @@ class OfflineRecognizer(object): | @@ -947,6 +1029,11 @@ class OfflineRecognizer(object): | ||
| 947 | decoding_method=decoding_method, | 1029 | decoding_method=decoding_method, |
| 948 | rule_fsts=rule_fsts, | 1030 | rule_fsts=rule_fsts, |
| 949 | rule_fars=rule_fars, | 1031 | rule_fars=rule_fars, |
| 1032 | + hr=HomophoneReplacerConfig( | ||
| 1033 | + dict_dir=hr_dict_dir, | ||
| 1034 | + lexicon=hr_lexicon, | ||
| 1035 | + rule_fsts=hr_rule_fsts, | ||
| 1036 | + ), | ||
| 950 | ) | 1037 | ) |
| 951 | self.recognizer = _Recognizer(recognizer_config) | 1038 | self.recognizer = _Recognizer(recognizer_config) |
| 952 | self.config = recognizer_config | 1039 | self.config = recognizer_config |
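Every OfflineRecognizer factory method above gains the hr_dict_dir, hr_rule_fsts, and hr_lexicon keyword arguments and bundles them into a HomophoneReplacerConfig. A usage sketch; the factory name from_sense_voice, the soundfile dependency, and all file paths are assumptions rather than part of this diff:

    import soundfile as sf
    import sherpa_onnx

    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model="./model.int8.onnx",
        tokens="./tokens.txt",
        use_itn=True,
        hr_dict_dir="./dict",
        hr_lexicon="./lexicon.txt",
        hr_rule_fsts="./replace.fst",
    )

    # Decode one wave file; homophone replacement is applied to the final text.
    samples, sample_rate = sf.read("./test-hr.wav", dtype="float32", always_2d=True)
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples[:, 0])
    recognizer.decode_stream(stream)
    print(stream.result.text)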
| @@ -3,25 +3,26 @@ from pathlib import Path | @@ -3,25 +3,26 @@ from pathlib import Path | ||
| 3 | from typing import List, Optional | 3 | from typing import List, Optional |
| 4 | 4 | ||
| 5 | from _sherpa_onnx import ( | 5 | from _sherpa_onnx import ( |
| 6 | + CudaConfig, | ||
| 6 | EndpointConfig, | 7 | EndpointConfig, |
| 7 | FeatureExtractorConfig, | 8 | FeatureExtractorConfig, |
| 9 | + HomophoneReplacerConfig, | ||
| 10 | + OnlineCtcFstDecoderConfig, | ||
| 8 | OnlineLMConfig, | 11 | OnlineLMConfig, |
| 9 | OnlineModelConfig, | 12 | OnlineModelConfig, |
| 13 | + OnlineNeMoCtcModelConfig, | ||
| 10 | OnlineParaformerModelConfig, | 14 | OnlineParaformerModelConfig, |
| 11 | ) | 15 | ) |
| 12 | from _sherpa_onnx import OnlineRecognizer as _Recognizer | 16 | from _sherpa_onnx import OnlineRecognizer as _Recognizer |
| 13 | from _sherpa_onnx import ( | 17 | from _sherpa_onnx import ( |
| 14 | - CudaConfig, | ||
| 15 | - TensorrtConfig, | ||
| 16 | - ProviderConfig, | ||
| 17 | OnlineRecognizerConfig, | 18 | OnlineRecognizerConfig, |
| 18 | OnlineRecognizerResult, | 19 | OnlineRecognizerResult, |
| 19 | OnlineStream, | 20 | OnlineStream, |
| 20 | OnlineTransducerModelConfig, | 21 | OnlineTransducerModelConfig, |
| 21 | OnlineWenetCtcModelConfig, | 22 | OnlineWenetCtcModelConfig, |
| 22 | - OnlineNeMoCtcModelConfig, | ||
| 23 | OnlineZipformer2CtcModelConfig, | 23 | OnlineZipformer2CtcModelConfig, |
| 24 | - OnlineCtcFstDecoderConfig, | 24 | + ProviderConfig, |
| 25 | + TensorrtConfig, | ||
| 25 | ) | 26 | ) |
| 26 | 27 | ||
| 27 | 28 | ||
| @@ -82,9 +83,12 @@ class OnlineRecognizer(object): | @@ -82,9 +83,12 @@ class OnlineRecognizer(object): | ||
| 82 | trt_detailed_build_log: bool = False, | 83 | trt_detailed_build_log: bool = False, |
| 83 | trt_engine_cache_enable: bool = True, | 84 | trt_engine_cache_enable: bool = True, |
| 84 | trt_timing_cache_enable: bool = True, | 85 | trt_timing_cache_enable: bool = True, |
| 85 | - trt_engine_cache_path: str ="", | ||
| 86 | - trt_timing_cache_path: str ="", | 86 | + trt_engine_cache_path: str = "", |
| 87 | + trt_timing_cache_path: str = "", | ||
| 87 | trt_dump_subgraphs: bool = False, | 88 | trt_dump_subgraphs: bool = False, |
| 89 | + hr_dict_dir: str = "", | ||
| 90 | + hr_rule_fsts: str = "", | ||
| 91 | + hr_lexicon: str = "", | ||
| 88 | ): | 92 | ): |
| 89 | """ | 93 | """ |
| 90 | Please refer to | 94 | Please refer to |
| @@ -228,27 +232,27 @@ class OnlineRecognizer(object): | @@ -228,27 +232,27 @@ class OnlineRecognizer(object): | ||
| 228 | ) | 232 | ) |
| 229 | 233 | ||
| 230 | cuda_config = CudaConfig( | 234 | cuda_config = CudaConfig( |
| 231 | - cudnn_conv_algo_search=cudnn_conv_algo_search, | 235 | + cudnn_conv_algo_search=cudnn_conv_algo_search, |
| 232 | ) | 236 | ) |
| 233 | 237 | ||
| 234 | trt_config = TensorrtConfig( | 238 | trt_config = TensorrtConfig( |
| 235 | - trt_max_workspace_size=trt_max_workspace_size, | ||
| 236 | - trt_max_partition_iterations=trt_max_partition_iterations, | ||
| 237 | - trt_min_subgraph_size=trt_min_subgraph_size, | ||
| 238 | - trt_fp16_enable=trt_fp16_enable, | ||
| 239 | - trt_detailed_build_log=trt_detailed_build_log, | ||
| 240 | - trt_engine_cache_enable=trt_engine_cache_enable, | ||
| 241 | - trt_timing_cache_enable=trt_timing_cache_enable, | ||
| 242 | - trt_engine_cache_path=trt_engine_cache_path, | ||
| 243 | - trt_timing_cache_path=trt_timing_cache_path, | ||
| 244 | - trt_dump_subgraphs=trt_dump_subgraphs, | 239 | + trt_max_workspace_size=trt_max_workspace_size, |
| 240 | + trt_max_partition_iterations=trt_max_partition_iterations, | ||
| 241 | + trt_min_subgraph_size=trt_min_subgraph_size, | ||
| 242 | + trt_fp16_enable=trt_fp16_enable, | ||
| 243 | + trt_detailed_build_log=trt_detailed_build_log, | ||
| 244 | + trt_engine_cache_enable=trt_engine_cache_enable, | ||
| 245 | + trt_timing_cache_enable=trt_timing_cache_enable, | ||
| 246 | + trt_engine_cache_path=trt_engine_cache_path, | ||
| 247 | + trt_timing_cache_path=trt_timing_cache_path, | ||
| 248 | + trt_dump_subgraphs=trt_dump_subgraphs, | ||
| 245 | ) | 249 | ) |
| 246 | 250 | ||
| 247 | provider_config = ProviderConfig( | 251 | provider_config = ProviderConfig( |
| 248 | - trt_config=trt_config, | ||
| 249 | - cuda_config=cuda_config, | ||
| 250 | - provider=provider, | ||
| 251 | - device=device, | 252 | + trt_config=trt_config, |
| 253 | + cuda_config=cuda_config, | ||
| 254 | + provider=provider, | ||
| 255 | + device=device, | ||
| 252 | ) | 256 | ) |
| 253 | 257 | ||
| 254 | model_config = OnlineModelConfig( | 258 | model_config = OnlineModelConfig( |
| @@ -311,6 +315,11 @@ class OnlineRecognizer(object): | @@ -311,6 +315,11 @@ class OnlineRecognizer(object): | ||
| 311 | rule_fsts=rule_fsts, | 315 | rule_fsts=rule_fsts, |
| 312 | rule_fars=rule_fars, | 316 | rule_fars=rule_fars, |
| 313 | reset_encoder=reset_encoder, | 317 | reset_encoder=reset_encoder, |
| 318 | + hr=HomophoneReplacerConfig( | ||
| 319 | + dict_dir=hr_dict_dir, | ||
| 320 | + lexicon=hr_lexicon, | ||
| 321 | + rule_fsts=hr_rule_fsts, | ||
| 322 | + ), | ||
| 314 | ) | 323 | ) |
| 315 | 324 | ||
| 316 | self.recognizer = _Recognizer(recognizer_config) | 325 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -336,6 +345,9 @@ class OnlineRecognizer(object): | @@ -336,6 +345,9 @@ class OnlineRecognizer(object): | ||
| 336 | rule_fsts: str = "", | 345 | rule_fsts: str = "", |
| 337 | rule_fars: str = "", | 346 | rule_fars: str = "", |
| 338 | device: int = 0, | 347 | device: int = 0, |
| 348 | + hr_dict_dir: str = "", | ||
| 349 | + hr_rule_fsts: str = "", | ||
| 350 | + hr_lexicon: str = "", | ||
| 339 | ): | 351 | ): |
| 340 | """ | 352 | """ |
| 341 | Please refer to | 353 | Please refer to |
| @@ -402,8 +414,8 @@ class OnlineRecognizer(object): | @@ -402,8 +414,8 @@ class OnlineRecognizer(object): | ||
| 402 | ) | 414 | ) |
| 403 | 415 | ||
| 404 | provider_config = ProviderConfig( | 416 | provider_config = ProviderConfig( |
| 405 | - provider=provider, | ||
| 406 | - device=device, | 417 | + provider=provider, |
| 418 | + device=device, | ||
| 407 | ) | 419 | ) |
| 408 | 420 | ||
| 409 | model_config = OnlineModelConfig( | 421 | model_config = OnlineModelConfig( |
| @@ -434,6 +446,11 @@ class OnlineRecognizer(object): | @@ -434,6 +446,11 @@ class OnlineRecognizer(object): | ||
| 434 | decoding_method=decoding_method, | 446 | decoding_method=decoding_method, |
| 435 | rule_fsts=rule_fsts, | 447 | rule_fsts=rule_fsts, |
| 436 | rule_fars=rule_fars, | 448 | rule_fars=rule_fars, |
| 449 | + hr=HomophoneReplacerConfig( | ||
| 450 | + dict_dir=hr_dict_dir, | ||
| 451 | + lexicon=hr_lexicon, | ||
| 452 | + rule_fsts=hr_rule_fsts, | ||
| 453 | + ), | ||
| 437 | ) | 454 | ) |
| 438 | 455 | ||
| 439 | self.recognizer = _Recognizer(recognizer_config) | 456 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -460,6 +477,9 @@ class OnlineRecognizer(object): | @@ -460,6 +477,9 @@ class OnlineRecognizer(object): | ||
| 460 | rule_fsts: str = "", | 477 | rule_fsts: str = "", |
| 461 | rule_fars: str = "", | 478 | rule_fars: str = "", |
| 462 | device: int = 0, | 479 | device: int = 0, |
| 480 | + hr_dict_dir: str = "", | ||
| 481 | + hr_rule_fsts: str = "", | ||
| 482 | + hr_lexicon: str = "", | ||
| 463 | ): | 483 | ): |
| 464 | """ | 484 | """ |
| 465 | Please refer to | 485 | Please refer to |
| @@ -526,8 +546,8 @@ class OnlineRecognizer(object): | @@ -526,8 +546,8 @@ class OnlineRecognizer(object): | ||
| 526 | zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) | 546 | zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) |
| 527 | 547 | ||
| 528 | provider_config = ProviderConfig( | 548 | provider_config = ProviderConfig( |
| 529 | - provider=provider, | ||
| 530 | - device=device, | 549 | + provider=provider, |
| 550 | + device=device, | ||
| 531 | ) | 551 | ) |
| 532 | 552 | ||
| 533 | model_config = OnlineModelConfig( | 553 | model_config = OnlineModelConfig( |
| @@ -563,6 +583,11 @@ class OnlineRecognizer(object): | @@ -563,6 +583,11 @@ class OnlineRecognizer(object): | ||
| 563 | decoding_method=decoding_method, | 583 | decoding_method=decoding_method, |
| 564 | rule_fsts=rule_fsts, | 584 | rule_fsts=rule_fsts, |
| 565 | rule_fars=rule_fars, | 585 | rule_fars=rule_fars, |
| 586 | + hr=HomophoneReplacerConfig( | ||
| 587 | + dict_dir=hr_dict_dir, | ||
| 588 | + lexicon=hr_lexicon, | ||
| 589 | + rule_fsts=hr_rule_fsts, | ||
| 590 | + ), | ||
| 566 | ) | 591 | ) |
| 567 | 592 | ||
| 568 | self.recognizer = _Recognizer(recognizer_config) | 593 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -587,6 +612,9 @@ class OnlineRecognizer(object): | @@ -587,6 +612,9 @@ class OnlineRecognizer(object): | ||
| 587 | rule_fsts: str = "", | 612 | rule_fsts: str = "", |
| 588 | rule_fars: str = "", | 613 | rule_fars: str = "", |
| 589 | device: int = 0, | 614 | device: int = 0, |
| 615 | + hr_dict_dir: str = "", | ||
| 616 | + hr_rule_fsts: str = "", | ||
| 617 | + hr_lexicon: str = "", | ||
| 590 | ): | 618 | ): |
| 591 | """ | 619 | """ |
| 592 | Please refer to | 620 | Please refer to |
| @@ -650,8 +678,8 @@ class OnlineRecognizer(object): | @@ -650,8 +678,8 @@ class OnlineRecognizer(object): | ||
| 650 | ) | 678 | ) |
| 651 | 679 | ||
| 652 | provider_config = ProviderConfig( | 680 | provider_config = ProviderConfig( |
| 653 | - provider=provider, | ||
| 654 | - device=device, | 681 | + provider=provider, |
| 682 | + device=device, | ||
| 655 | ) | 683 | ) |
| 656 | 684 | ||
| 657 | model_config = OnlineModelConfig( | 685 | model_config = OnlineModelConfig( |
| @@ -681,6 +709,11 @@ class OnlineRecognizer(object): | @@ -681,6 +709,11 @@ class OnlineRecognizer(object): | ||
| 681 | decoding_method=decoding_method, | 709 | decoding_method=decoding_method, |
| 682 | rule_fsts=rule_fsts, | 710 | rule_fsts=rule_fsts, |
| 683 | rule_fars=rule_fars, | 711 | rule_fars=rule_fars, |
| 712 | + hr=HomophoneReplacerConfig( | ||
| 713 | + dict_dir=hr_dict_dir, | ||
| 714 | + lexicon=hr_lexicon, | ||
| 715 | + rule_fsts=hr_rule_fsts, | ||
| 716 | + ), | ||
| 684 | ) | 717 | ) |
| 685 | 718 | ||
| 686 | self.recognizer = _Recognizer(recognizer_config) | 719 | self.recognizer = _Recognizer(recognizer_config) |
| @@ -707,6 +740,9 @@ class OnlineRecognizer(object): | @@ -707,6 +740,9 @@ class OnlineRecognizer(object): | ||
| 707 | rule_fsts: str = "", | 740 | rule_fsts: str = "", |
| 708 | rule_fars: str = "", | 741 | rule_fars: str = "", |
| 709 | device: int = 0, | 742 | device: int = 0, |
| 743 | + hr_dict_dir: str = "", | ||
| 744 | + hr_rule_fsts: str = "", | ||
| 745 | + hr_lexicon: str = "", | ||
| 710 | ): | 746 | ): |
| 711 | """ | 747 | """ |
| 712 | Please refer to | 748 | Please refer to |
| @@ -775,8 +811,8 @@ class OnlineRecognizer(object): | @@ -775,8 +811,8 @@ class OnlineRecognizer(object): | ||
| 775 | ) | 811 | ) |
| 776 | 812 | ||
| 777 | provider_config = ProviderConfig( | 813 | provider_config = ProviderConfig( |
| 778 | - provider=provider, | ||
| 779 | - device=device, | 814 | + provider=provider, |
| 815 | + device=device, | ||
| 780 | ) | 816 | ) |
| 781 | 817 | ||
| 782 | model_config = OnlineModelConfig( | 818 | model_config = OnlineModelConfig( |
| @@ -806,6 +842,11 @@ class OnlineRecognizer(object): | @@ -806,6 +842,11 @@ class OnlineRecognizer(object): | ||
| 806 | decoding_method=decoding_method, | 842 | decoding_method=decoding_method, |
| 807 | rule_fsts=rule_fsts, | 843 | rule_fsts=rule_fsts, |
| 808 | rule_fars=rule_fars, | 844 | rule_fars=rule_fars, |
| 845 | + hr=HomophoneReplacerConfig( | ||
| 846 | + dict_dir=hr_dict_dir, | ||
| 847 | + lexicon=hr_lexicon, | ||
| 848 | + rule_fsts=hr_rule_fsts, | ||
| 849 | + ), | ||
| 809 | ) | 850 | ) |
| 810 | 851 | ||
| 811 | self.recognizer = _Recognizer(recognizer_config) | 852 | self.recognizer = _Recognizer(recognizer_config) |
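The OnlineRecognizer factory methods receive the same three hr_* keyword arguments. A streaming sketch; the factory name from_transducer, the model file names, and the soundfile dependency are assumptions:

    import soundfile as sf
    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens="./tokens.txt",
        encoder="./encoder.onnx",
        decoder="./decoder.onnx",
        joiner="./joiner.onnx",
        hr_dict_dir="./dict",
        hr_lexicon="./lexicon.txt",
        hr_rule_fsts="./replace.fst",
    )

    samples, sample_rate = sf.read("./test-hr.wav", dtype="float32", always_2d=True)
    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples[:, 0])
    stream.input_finished()
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)
    print(recognizer.get_result(stream))  # final text, with homophone rules applied last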