Fangjun Kuang
Committed by GitHub

Support replacing homophonic phrases (#2153)

Showing 42 changed files with 834 additions and 134 deletions
... ... @@ -98,6 +98,29 @@ for m in model.onnx model.int8.onnx; do
done
done
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
tar xf dict.tar.bz2
rm dict.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
for m in model.onnx model.int8.onnx; do
for use_itn in 0 1; do
echo "$m $w $use_itn"
time $EXE \
--tokens=$repo/tokens.txt \
--sense-voice-model=$repo/$m \
--sense-voice-use-itn=$use_itn \
--hr-lexicon=./lexicon.txt \
--hr-dict-dir=./dict \
--hr-rule-fsts=./replace.fst \
./test-hr.wav
done
done
rm -rf dict replace.fst test-hr.wav lexicon.txt
# test wav reader for non-standard wav files
waves=(
... ...
... ... @@ -95,6 +95,18 @@ rm $name
ls -lh $repo
python3 ./python-api-examples/offline-sense-voice-ctc-decode-files.py
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
tar xf dict.tar.bz2
rm dict.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
python3 ./python-api-examples/offline-sense-voice-ctc-decode-files-with-hr.py
rm -rf dict replace.fst test-hr.wav lexicon.txt
if [[ $(uname) == Linux ]]; then
# It needs ffmpeg
log "generate subtitles (Chinese)"
... ...
function(download_kaldifst)
include(FetchContent)
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.11.tar.gz")
set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.11.tar.gz")
set(kaldifst_HASH "SHA256=b43b3332faa2961edc730e47995a58cd4e22ead21905d55b0c4a41375b4a525f")
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.13.tar.gz")
set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.13.tar.gz")
set(kaldifst_HASH "SHA256=f8dc15fdaf314d7c9c3551ad8c11ed15da0f34de36446798bbd1b90fa7946eb2")
# If you don't have access to the Internet,
# please pre-download kaldifst
set(possible_file_locations
$ENV{HOME}/Downloads/kaldifst-1.7.11.tar.gz
${CMAKE_SOURCE_DIR}/kaldifst-1.7.11.tar.gz
${CMAKE_BINARY_DIR}/kaldifst-1.7.11.tar.gz
/tmp/kaldifst-1.7.11.tar.gz
/star-fj/fangjun/download/github/kaldifst-1.7.11.tar.gz
$ENV{HOME}/Downloads/kaldifst-1.7.13.tar.gz
${CMAKE_SOURCE_DIR}/kaldifst-1.7.13.tar.gz
${CMAKE_BINARY_DIR}/kaldifst-1.7.13.tar.gz
/tmp/kaldifst-1.7.13.tar.gz
/star-fj/fangjun/download/github/kaldifst-1.7.13.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
#!/usr/bin/env python3
"""
This file shows how to use a non-streaming SenseVoice CTC model from
https://github.com/FunAudioLLM/SenseVoice
to decode files and how to post-process the recognition result with the
homophone replacer (HR).
Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
For instance,
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/dict.tar.bz2
tar xf dict.tar.bz2
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/replace.fst
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/test-hr.wav
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/hr-files/lexicon.txt
"""
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def create_recognizer():
model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx"
tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"
test_wav = "./test-hr.wav"
if not Path(model).is_file() or not Path(test_wav).is_file():
raise ValueError(
"""Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
and
https://github.com/k2-fsa/sherpa-onnx/releases/tag/hr-files
"""
)
return (
sherpa_onnx.OfflineRecognizer.from_sense_voice(
model=model,
tokens=tokens,
use_itn=True,
debug=True,
hr_lexicon="./lexicon.txt",
hr_dict_dir="./dict",
hr_rule_fsts="./replace.fst",
),
test_wav,
)
def main():
recognizer, wave_filename = create_recognizer()
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
audio = audio[:, 0] # only use the first channel
# audio is a 1-D float32 numpy array normalized to the range [-1, 1]
# sample_rate does not need to be 16000 Hz
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, audio)
recognizer.decode_stream(stream)
print(wave_filename)
print(stream.result)
if __name__ == "__main__":
main()
... ...
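The example above only switches on the new hr_* options. The sketch below illustrates in plain Python what the homophone replacer conceptually does to the recognition result: each decoded word is mapped to a pronunciation via lexicon.txt, and phrases whose pronunciations match a rule are rewritten to the preferred spelling. In sherpa-onnx the rewriting is driven by the rule FST in replace.fst and the segmentation by jieba; the dict-based rules and the sample entries below are purely illustrative.

# Conceptual sketch only; NOT the sherpa-onnx implementation.
# A plain dict keyed by pronunciation stands in for replace.fst,
# and the word list is assumed to be already segmented.
def replace_homophones(words, word2pron, pron2phrase):
    """words: segmented words; word2pron: lexicon mapping word -> pronunciation;
    pron2phrase: replacement rules keyed by pronunciation."""
    out = []
    for w in words:
        pron = word2pron.get(w, w)  # unknown words keep their surface form
        out.append(pron2phrase.get(pron, w))  # rewrite only if a rule matches
    return "".join(out)

# Hypothetical data, for illustration only.
word2pron = {"词典": "cidian", "磁典": "cidian"}
pron2phrase = {"cidian": "词典"}
print(replace_homophones(["磁典", "测试"], word2pron, pron2phrase))
# -> 词典测试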
... ... @@ -20,7 +20,9 @@ set(sources
features.cc
file-utils.cc
fst-utils.cc
homophone-replacer.cc
hypothesis.cc
jieba.cc
keyword-spotter-impl.cc
keyword-spotter.cc
offline-ctc-fst-decoder-config.cc
... ...
// sherpa-onnx/csrc/homophone-replacer.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/homophone-replacer.h"
#include <fstream>
#include <sstream>
#include <string>
#include <strstream>
#include <unordered_map>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/jieba.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
void HomophoneReplacerConfig::Register(ParseOptions *po) {
po->Register("hr-dict-dir", &dict_dir,
"The dict directory for jieba used by HomophoneReplacer");
po->Register("hr-lexicon", &lexicon,
"Path to lexicon.txt used by HomophoneReplacer.");
po->Register("hr-rule-fsts", &rule_fsts,
"Fst files for HomophoneReplacer. If there are multiple, they "
"are separated by a comma. E.g., a.fst,b.fst,c.fst");
}
bool HomophoneReplacerConfig::Validate() const {
if (!dict_dir.empty()) {
std::vector<std::string> required_files = {
"jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8",
"idf.utf8", "stop_words.utf8",
};
for (const auto &f : required_files) {
if (!FileExists(dict_dir + "/" + f)) {
SHERPA_ONNX_LOGE("'%s/%s' does not exist. Please check kokoro-dict-dir",
dict_dir.c_str(), f.c_str());
return false;
}
}
}
if (!lexicon.empty() && !FileExists(lexicon)) {
SHERPA_ONNX_LOGE("--hr-lexicon: '%s' does not exist", lexicon.c_str());
return false;
}
if (!rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(rule_fsts, ",", false, &files);
if (files.size() > 1) {
SHERPA_ONNX_LOGE("Only 1 file is supported now.");
SHERPA_ONNX_EXIT(-1);
}
for (const auto &f : files) {
if (!FileExists(f)) {
SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
return false;
}
}
}
return true;
}
std::string HomophoneReplacerConfig::ToString() const {
std::ostringstream os;
os << "HomophoneReplacerConfig(";
os << "dict_dir=\"" << dict_dir << "\", ";
os << "lexicon=\"" << lexicon << "\", ";
os << "rule_fsts=\"" << rule_fsts << "\")";
return os.str();
}
class HomophoneReplacer::Impl {
public:
explicit Impl(const HomophoneReplacerConfig &config) : config_(config) {
jieba_ = InitJieba(config.dict_dir);
{
std::ifstream is(config.lexicon);
InitLexicon(is);
}
if (!config.rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
replacer_list_.reserve(files.size());
for (const auto &f : files) {
if (config.debug) {
SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
}
replacer_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
}
}
}
template <typename Manager>
Impl(Manager *mgr, const HomophoneReplacerConfig &config) : config_(config) {
jieba_ = InitJieba(config.dict_dir);
{
auto buf = ReadFile(mgr, config.lexicon);
std::istrstream is(buf.data(), buf.size());
InitLexicon(is);
}
if (!config.rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
replacer_list_.reserve(files.size());
for (const auto &f : files) {
if (config.debug) {
SHERPA_ONNX_LOGE("hr rule fst: %s", f.c_str());
}
auto buf = ReadFile(mgr, f);
std::istrstream is(buf.data(), buf.size());
replacer_list_.push_back(
std::make_unique<kaldifst::TextNormalizer>(is));
}
}
}
std::string Apply(const std::string &text) const {
bool is_hmm = true;
std::vector<std::string> words;
jieba_->Cut(text, words, is_hmm);
if (config_.debug) {
SHERPA_ONNX_LOGE("Input text: '%s'", text.c_str());
std::ostringstream os;
os << "After jieba: ";
std::string sep;
for (const auto &w : words) {
os << sep << w;
sep = "_";
}
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
// convert words to pronunciations
std::vector<std::string> pronunciations;
for (const auto &w : words) {
auto p = ConvertWordToPronunciation(w);
if (config_.debug) {
SHERPA_ONNX_LOGE("%s %s", w.c_str(), p.c_str());
}
pronunciations.push_back(std::move(p));
}
std::string ans;
for (const auto &r : replacer_list_) {
ans = r->Normalize(words, pronunciations);
// TODO(fangjun): We support only 1 rule fst at present.
break;
}
return ans;
}
private:
std::string ConvertWordToPronunciation(const std::string &word) const {
if (word2pron_.count(word)) {
return word2pron_.at(word);
}
if (word.size() <= 3) {
// a single character (a Chinese character occupies 3 bytes in UTF-8)
// or a short non-Chinese token, so there is nothing further to split
return word;
}
std::vector<std::string> words = SplitUtf8(word);
std::string ans;
for (const auto &w : words) {
if (word2pron_.count(w)) {
ans.append(word2pron_.at(w));
} else {
ans.append(w);
}
}
return ans;
}
void InitLexicon(std::istream &is) {
std::string word;
std::string pron;
std::string p;
std::string line;
int32_t line_num = 0;
int32_t num_warn = 0;
while (std::getline(is, line)) {
++line_num;
std::istringstream iss(line);
pron.clear();
iss >> word;
ToLowerCase(&word);
if (word2pron_.count(word)) {
num_warn += 1;
if (num_warn < 10) {
SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
word.c_str(), line_num, line.c_str());
}
continue;
}
while (iss >> p) {
pron.append(std::move(p));
}
if (pron.empty()) {
SHERPA_ONNX_LOGE(
"Empty pronunciation for word '%s' at line %d:%s. Ignore it.",
word.c_str(), line_num, line.c_str());
continue;
}
word2pron_.insert({std::move(word), std::move(pron)});
}
}
private:
HomophoneReplacerConfig config_;
std::unique_ptr<cppjieba::Jieba> jieba_;
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> replacer_list_;
std::unordered_map<std::string, std::string> word2pron_;
};
HomophoneReplacer::HomophoneReplacer(const HomophoneReplacerConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
HomophoneReplacer::HomophoneReplacer(Manager *mgr,
const HomophoneReplacerConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
HomophoneReplacer::~HomophoneReplacer() = default;
std::string HomophoneReplacer::Apply(const std::string &text) const {
return impl_->Apply(text);
}
#if __ANDROID_API__ >= 9
template HomophoneReplacer::HomophoneReplacer(
AAssetManager *mgr, const HomophoneReplacerConfig &config);
#endif
#if __OHOS__
template HomophoneReplacer::HomophoneReplacer(
NativeResourceManager *mgr, const HomophoneReplacerConfig &config);
#endif
} // namespace sherpa_onnx
... ...
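InitLexicon above expects one entry per line in lexicon.txt: a word followed by its pronunciation tokens, all separated by whitespace. The tokens are concatenated into a single pronunciation string, duplicate words are skipped with a warning, and entries without a pronunciation are ignored. A rough Python counterpart, shown only to document the expected file layout (the function name is ours):

# Rough Python counterpart of HomophoneReplacer::Impl::InitLexicon.
def load_lexicon(path):
    word2pron = {}
    with open(path, encoding="utf-8") as f:
        for line_num, line in enumerate(f, 1):
            fields = line.split()
            if not fields:
                continue
            word, *pron = fields
            word = word.lower()
            if word in word2pron:
                print(f"Duplicated word {word} at line {line_num}; ignored")
                continue
            if not pron:
                print(f"Empty pronunciation for {word} at line {line_num}; ignored")
                continue
            word2pron[word] = "".join(pron)  # pronunciation tokens are concatenated
    return word2pron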
// sherpa-onnx/csrc/homophone-replacer.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
#define SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct HomophoneReplacerConfig {
std::string dict_dir;
std::string lexicon;
// comma separated fst files, e.g. a.fst,b.fst,c.fst
std::string rule_fsts;
bool debug = false;
HomophoneReplacerConfig() = default;
HomophoneReplacerConfig(const std::string &dict_dir,
const std::string &lexicon,
const std::string &rule_fsts, bool debug)
: dict_dir(dict_dir),
lexicon(lexicon),
rule_fsts(rule_fsts),
debug(debug) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
class HomophoneReplacer {
public:
explicit HomophoneReplacer(const HomophoneReplacerConfig &config);
template <typename Manager>
HomophoneReplacer(Manager *mgr, const HomophoneReplacerConfig &config);
~HomophoneReplacer();
std::string Apply(const std::string &text) const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_HOMOPHONE_REPLACER_H_
... ...
... ... @@ -19,8 +19,8 @@
#include "rawfile/raw_file_manager.h"
#endif
#include "cppjieba/Jieba.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/jieba.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
... ... @@ -41,20 +41,7 @@ class JiebaLexicon::Impl {
Impl(const std::string &lexicon, const std::string &tokens,
const std::string &dict_dir, bool debug)
: debug_(debug) {
std::string dict = dict_dir + "/jieba.dict.utf8";
std::string hmm = dict_dir + "/hmm_model.utf8";
std::string user_dict = dict_dir + "/user.dict.utf8";
std::string idf = dict_dir + "/idf.utf8";
std::string stop_word = dict_dir + "/stop_words.utf8";
AssertFileExists(dict);
AssertFileExists(hmm);
AssertFileExists(user_dict);
AssertFileExists(idf);
AssertFileExists(stop_word);
jieba_ =
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
jieba_ = InitJieba(dict_dir);
{
std::ifstream is(tokens);
... ... @@ -71,20 +58,7 @@ class JiebaLexicon::Impl {
Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens,
const std::string &dict_dir, bool debug)
: debug_(debug) {
std::string dict = dict_dir + "/jieba.dict.utf8";
std::string hmm = dict_dir + "/hmm_model.utf8";
std::string user_dict = dict_dir + "/user.dict.utf8";
std::string idf = dict_dir + "/idf.utf8";
std::string stop_word = dict_dir + "/stop_words.utf8";
AssertFileExists(dict);
AssertFileExists(hmm);
AssertFileExists(user_dict);
AssertFileExists(idf);
AssertFileExists(stop_word);
jieba_ =
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
jieba_ = InitJieba(dict_dir);
{
auto buf = ReadFile(mgr, tokens);
... ...
// sherpa-onnx/csrc/jieba.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/jieba.h"
#include "sherpa-onnx/csrc/file-utils.h"
namespace sherpa_onnx {
std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir) {
if (dict_dir.empty()) {
return {};
}
std::string dict = dict_dir + "/jieba.dict.utf8";
std::string hmm = dict_dir + "/hmm_model.utf8";
std::string user_dict = dict_dir + "/user.dict.utf8";
std::string idf = dict_dir + "/idf.utf8";
std::string stop_word = dict_dir + "/stop_words.utf8";
AssertFileExists(dict);
AssertFileExists(hmm);
AssertFileExists(user_dict);
AssertFileExists(idf);
AssertFileExists(stop_word);
return std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf,
stop_word);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/jieba.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_JIEBA_H_
#define SHERPA_ONNX_CSRC_JIEBA_H_
#include <memory>
#include <string>
#include "cppjieba/Jieba.hpp"
namespace sherpa_onnx {
std::unique_ptr<cppjieba::Jieba> InitJieba(const std::string &dict_dir);
}
#endif // SHERPA_ONNX_CSRC_JIEBA_H_
... ...
... ... @@ -22,11 +22,11 @@
#include <codecvt>
#include "cppjieba/Jieba.hpp"
#include "espeak-ng/speak_lib.h"
#include "phoneme_ids.hpp"
#include "phonemize.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/jieba.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/text-utils.h"
... ... @@ -47,7 +47,7 @@ class KokoroMultiLangLexicon::Impl {
InitLexicon(lexicon);
InitJieba(dict_dir);
jieba_ = InitJieba(dict_dir);
InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc
}
... ... @@ -62,7 +62,7 @@ class KokoroMultiLangLexicon::Impl {
InitLexicon(mgr, lexicon);
// we assume you have copied dict_dir and data_dir from assets to some path
InitJieba(dict_dir);
jieba_ = InitJieba(dict_dir);
InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc
}
... ... @@ -456,23 +456,6 @@ class KokoroMultiLangLexicon::Impl {
}
}
void InitJieba(const std::string &dict_dir) {
std::string dict = dict_dir + "/jieba.dict.utf8";
std::string hmm = dict_dir + "/hmm_model.utf8";
std::string user_dict = dict_dir + "/user.dict.utf8";
std::string idf = dict_dir + "/idf.utf8";
std::string stop_word = dict_dir + "/stop_words.utf8";
AssertFileExists(dict);
AssertFileExists(hmm);
AssertFileExists(user_dict);
AssertFileExists(idf);
AssertFileExists(stop_word);
jieba_ =
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
}
private:
OfflineTtsKokoroModelMetaData meta_data_;
... ...
... ... @@ -19,8 +19,8 @@
#include "rawfile/raw_file_manager.h"
#endif
#include "cppjieba/Jieba.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/jieba.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/symbol-table.h"
... ... @@ -34,20 +34,7 @@ class MeloTtsLexicon::Impl {
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
: meta_data_(meta_data), debug_(debug) {
std::string dict = dict_dir + "/jieba.dict.utf8";
std::string hmm = dict_dir + "/hmm_model.utf8";
std::string user_dict = dict_dir + "/user.dict.utf8";
std::string idf = dict_dir + "/idf.utf8";
std::string stop_word = dict_dir + "/stop_words.utf8";
AssertFileExists(dict);
AssertFileExists(hmm);
AssertFileExists(user_dict);
AssertFileExists(idf);
AssertFileExists(stop_word);
jieba_ =
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
jieba_ = InitJieba(dict_dir);
{
std::ifstream is(tokens);
... ... @@ -79,20 +66,7 @@ class MeloTtsLexicon::Impl {
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
: meta_data_(meta_data), debug_(debug) {
std::string dict = dict_dir + "/jieba.dict.utf8";
std::string hmm = dict_dir + "/hmm_model.utf8";
std::string user_dict = dict_dir + "/user.dict.utf8";
std::string idf = dict_dir + "/idf.utf8";
std::string stop_word = dict_dir + "/stop_words.utf8";
AssertFileExists(dict);
AssertFileExists(hmm);
AssertFileExists(user_dict);
AssertFileExists(idf);
AssertFileExists(stop_word);
jieba_ =
std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);
jieba_ = InitJieba(dict_dir);
{
auto buf = ReadFile(mgr, tokens);
... ...
... ... @@ -239,6 +239,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
ss[i]->SetResult(r);
}
}
... ... @@ -277,6 +278,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
s->SetResult(r);
}
... ...
... ... @@ -125,6 +125,7 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
auto r = Convert(results[0], symbol_table_);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
s->SetResult(r);
}
... ...
... ... @@ -408,6 +408,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
OfflineRecognizerImpl::OfflineRecognizerImpl(
const OfflineRecognizerConfig &config)
: config_(config) {
// TODO(fangjun): Refactor this function
if (!config.rule_fsts.empty()) {
std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
... ... @@ -448,6 +450,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl(
SHERPA_ONNX_LOGE("FST archives loaded!");
}
}
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
!config.hr.rule_fsts.empty()) {
auto hr_config = config.hr;
hr_config.debug = config.model_config.debug;
hr_ = std::make_unique<HomophoneReplacer>(hr_config);
}
}
template <typename Manager>
... ... @@ -495,6 +504,13 @@ OfflineRecognizerImpl::OfflineRecognizerImpl(
} // for (; !reader->Done(); reader->Next())
} // for (const auto &f : files)
} // if (!config.rule_fars.empty())
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
!config.hr.rule_fsts.empty()) {
auto hr_config = config.hr;
hr_config.debug = config.model_config.debug;
hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config);
}
}
std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
... ... @@ -510,6 +526,15 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
return text;
}
std::string OfflineRecognizerImpl::ApplyHomophoneReplacer(
std::string text) const {
if (hr_) {
text = hr_->Apply(text);
}
return text;
}
void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
config_ = config;
}
... ...
... ... @@ -10,6 +10,7 @@
#include <vector>
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/homophone-replacer.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-stream.h"
... ... @@ -48,12 +49,15 @@ class OfflineRecognizerImpl {
std::string ApplyInverseTextNormalization(std::string text) const;
std::string ApplyHomophoneReplacer(std::string text) const;
private:
OfflineRecognizerConfig config_;
// for inverse text normalization. Used only if
// config.rule_fsts is not empty or
// config.rule_fars is not empty
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
std::unique_ptr<HomophoneReplacer> hr_;
};
} // namespace sherpa_onnx
... ...
... ... @@ -121,6 +121,7 @@ class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl {
auto r = Convert(results[0], symbol_table_);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
s->SetResult(r);
} catch (const Ort::Exception &ex) {
SHERPA_ONNX_LOGE(
... ...
... ... @@ -197,6 +197,7 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
for (int32_t i = 0; i != n; ++i) {
auto r = Convert(results[i], symbol_table_);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
ss[i]->SetResult(r);
}
}
... ...
... ... @@ -222,6 +222,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
auto r = ConvertSenseVoiceResult(results[i], symbol_table_,
frame_shift_ms, subsampling_factor);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
ss[i]->SetResult(r);
}
}
... ... @@ -295,6 +296,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
subsampling_factor);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
s->SetResult(r);
}
... ...
... ... @@ -239,6 +239,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
ss[i]->SetResult(r);
}
... ...
... ... @@ -128,6 +128,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
ss[i]->SetResult(r);
}
... ...
... ... @@ -160,6 +160,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
std::string s = sym_table[i];
s = ApplyInverseTextNormalization(s);
s = ApplyHomophoneReplacer(std::move(s));
text += s;
r.tokens.push_back(s);
... ...
... ... @@ -28,6 +28,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
model_config.Register(po);
lm_config.Register(po);
ctc_fst_decoder_config.Register(po);
hr.Register(po);
po->Register(
"decoding-method", &decoding_method,
... ... @@ -120,6 +121,11 @@ bool OfflineRecognizerConfig::Validate() const {
}
}
if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() &&
!hr.Validate()) {
return false;
}
return model_config.Validate();
}
... ... @@ -137,7 +143,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "hotwords_score=" << hotwords_score << ", ";
os << "blank_penalty=" << blank_penalty << ", ";
os << "rule_fsts=\"" << rule_fsts << "\", ";
os << "rule_fars=\"" << rule_fars << "\")";
os << "rule_fars=\"" << rule_fars << "\", ";
os << "hr=" << hr.ToString() << ")";
return os.str();
}
... ...
... ... @@ -10,6 +10,7 @@
#include <vector>
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/homophone-replacer.h"
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
... ... @@ -40,6 +41,7 @@ struct OfflineRecognizerConfig {
// If there are multiple FST archives, they are applied from left to right.
std::string rule_fars;
HomophoneReplacerConfig hr;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
... ... @@ -52,7 +54,7 @@ struct OfflineRecognizerConfig {
const std::string &decoding_method, int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score,
float blank_penalty, const std::string &rule_fsts,
const std::string &rule_fars)
const std::string &rule_fars, const HomophoneReplacerConfig &hr)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
... ... @@ -63,7 +65,8 @@ struct OfflineRecognizerConfig {
hotwords_score(hotwords_score),
blank_penalty(blank_penalty),
rule_fsts(rule_fsts),
rule_fars(rule_fars) {}
rule_fars(rule_fars),
hr(hr) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -201,7 +201,8 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
auto r =
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(r.text);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
return r;
}
... ...
... ... @@ -192,6 +192,13 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
SHERPA_ONNX_LOGE("FST archives loaded!");
}
}
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
!config.hr.rule_fsts.empty()) {
auto hr_config = config.hr;
hr_config.debug = config.model_config.debug;
hr_ = std::make_unique<HomophoneReplacer>(hr_config);
}
}
template <typename Manager>
... ... @@ -239,6 +246,12 @@ OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr,
} // for (; !reader->Done(); reader->Next())
} // for (const auto &f : files)
} // if (!config.rule_fars.empty())
if (!config.hr.dict_dir.empty() && !config.hr.lexicon.empty() &&
!config.hr.rule_fsts.empty()) {
auto hr_config = config.hr;
hr_config.debug = config.model_config.debug;
hr_ = std::make_unique<HomophoneReplacer>(mgr, hr_config);
}
}
std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
... ... @@ -254,6 +267,15 @@ std::string OnlineRecognizerImpl::ApplyInverseTextNormalization(
return text;
}
std::string OnlineRecognizerImpl::ApplyHomophoneReplacer(
std::string text) const {
if (hr_) {
text = hr_->Apply(text);
}
return text;
}
#if __ANDROID_API__ >= 9
template OnlineRecognizerImpl::OnlineRecognizerImpl(
AAssetManager *mgr, const OnlineRecognizerConfig &config);
... ...
... ... @@ -10,6 +10,7 @@
#include <vector>
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/homophone-replacer.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h"
... ... @@ -57,6 +58,7 @@ class OnlineRecognizerImpl {
virtual void Reset(OnlineStream *s) const = 0;
std::string ApplyInverseTextNormalization(std::string text) const;
std::string ApplyHomophoneReplacer(std::string text) const;
private:
OnlineRecognizerConfig config_;
... ... @@ -64,6 +66,7 @@ class OnlineRecognizerImpl {
// config.rule_fsts is not empty or
// config.rule_fars is not empty
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> itn_list_;
std::unique_ptr<HomophoneReplacer> hr_;
};
} // namespace sherpa_onnx
... ...
... ... @@ -169,7 +169,8 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
auto decoder_result = s->GetParaformerResult();
auto r = Convert(decoder_result, sym_);
r.text = ApplyInverseTextNormalization(r.text);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
return r;
}
... ...
... ... @@ -349,6 +349,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
return r;
}
... ... @@ -391,15 +392,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
// (the encoder state buffers are kept)
for (const auto &it : last_result.hyps) {
auto h = it.second;
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
h.ys.end()),
r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, h.ys.end()),
h.log_prob});
}
r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
last_result.tokens.end());
r.tokens = std::vector<int64_t>(last_result.tokens.end() - context_size,
last_result.tokens.end());
} else {
if(config_.reset_encoder) {
if (config_.reset_encoder) {
// reset encoder states, use blanks as 'ys' context
s->SetStates(model_->GetEncoderInitStates());
}
... ...
... ... @@ -100,6 +100,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
subsampling_factor, s->GetCurrentSegment(),
s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
return r;
}
... ...
... ... @@ -88,6 +88,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
endpoint_config.Register(po);
lm_config.Register(po);
ctc_fst_decoder_config.Register(po);
hr.Register(po);
po->Register("enable-endpoint", &enable_endpoint,
"True to enable endpoint detection. False to disable it.");
... ... @@ -182,6 +183,11 @@ bool OnlineRecognizerConfig::Validate() const {
}
}
if (!hr.dict_dir.empty() && !hr.lexicon.empty() && !hr.rule_fsts.empty() &&
!hr.Validate()) {
return false;
}
return model_config.Validate();
}
... ... @@ -203,7 +209,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "temperature_scale=" << temperature_scale << ", ";
os << "rule_fsts=\"" << rule_fsts << "\", ";
os << "rule_fars=\"" << rule_fars << "\", ";
os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")";
os << "reset_encoder=" << (reset_encoder ? "True" : "False") << ", ";
os << "hr=" << hr.ToString() << ")";
return os.str();
}
... ...
... ... @@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/endpoint.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/homophone-replacer.h"
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
... ... @@ -107,6 +108,8 @@ struct OnlineRecognizerConfig {
// currently only in `OnlineRecognizerTransducerImpl`.
bool reset_encoder = false;
HomophoneReplacerConfig hr;
/// used only for modified_beam_search, if hotwords_buf is non-empty,
/// the hotwords will be loaded from the buffered string instead of from the
/// "hotwords_file"
... ... @@ -123,7 +126,7 @@ struct OnlineRecognizerConfig {
int32_t max_active_paths, const std::string &hotwords_file,
float hotwords_score, float blank_penalty, float temperature_scale,
const std::string &rule_fsts, const std::string &rule_fars,
bool reset_encoder)
bool reset_encoder, const HomophoneReplacerConfig &hr)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
... ... @@ -138,7 +141,8 @@ struct OnlineRecognizerConfig {
temperature_scale(temperature_scale),
rule_fsts(rule_fsts),
rule_fars(rule_fars),
reset_encoder(reset_encoder) {}
reset_encoder(reset_encoder),
hr(hr) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -89,7 +89,8 @@ class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl {
auto r =
ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(r.text);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
return r;
}
... ...
... ... @@ -177,6 +177,7 @@ class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl {
auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
return r;
}
... ...
... ... @@ -7,6 +7,7 @@ set(srcs
display.cc
endpoint.cc
features.cc
homophone-replacer.cc
keyword-spotter.cc
offline-ctc-fst-decoder-config.cc
offline-dolphin-model-config.cc
... ...
// sherpa-onnx/python/csrc/homophone-replacer.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/homophone-replacer.h"
#include <string>
#include "sherpa-onnx/csrc/homophone-replacer.h"
namespace sherpa_onnx {
void PybindHomophoneReplacer(py::module *m) {
using PyClass = HomophoneReplacerConfig;
py::class_<PyClass>(*m, "HomophoneReplacerConfig")
.def(py::init<>())
.def(py::init<const std::string &, const std::string &,
const std::string &, bool>(),
py::arg("dict_dir"), py::arg("lexicon"), py::arg("rule_fsts"),
py::arg("debug") = false)
.def_readwrite("dict_dir", &PyClass::dict_dir)
.def_readwrite("lexicon", &PyClass::lexicon)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("debug", &PyClass::debug)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
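A quick way to exercise the new binding is to construct the config directly. The import below goes through the private _sherpa_onnx extension module, exactly as the project's own Python wrapper does; whether the class is also re-exported at the package top level is not shown in this diff.

# Exercise the HomophoneReplacerConfig binding added above.
from _sherpa_onnx import HomophoneReplacerConfig

config = HomophoneReplacerConfig(
    dict_dir="./dict",
    lexicon="./lexicon.txt",
    rule_fsts="./replace.fst",
    debug=True,
)
print(config)  # __str__ is bound to HomophoneReplacerConfig::ToString()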
// sherpa-onnx/python/csrc/homophone-replacer.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
#define SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindHomophoneReplacer(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_HOMOPHONE_REPLACER_H_
... ...
... ... @@ -17,14 +17,16 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
const std::string &, int32_t, const std::string &, float,
float, const std::string &, const std::string &>(),
float, const std::string &, const std::string &,
const HomophoneReplacerConfig &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
py::arg("rule_fsts") = "", py::arg("rule_fars") = "",
py::arg("hr") = HomophoneReplacerConfig{})
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
... ... @@ -36,6 +38,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
.def_readwrite("hr", &PyClass::hr)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -58,7 +58,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool,
const std::string &, int32_t, const std::string &, float,
float, float, const std::string &, const std::string &, bool>(),
float, float, const std::string &, const std::string &,
bool, const HomophoneReplacerConfig &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
... ... @@ -67,7 +68,8 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
py::arg("rule_fars") = "", py::arg("reset_encoder") = false)
py::arg("rule_fars") = "", py::arg("reset_encoder") = false,
py::arg("hr") = HomophoneReplacerConfig{})
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
... ... @@ -83,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
.def_readwrite("reset_encoder", &PyClass::reset_encoder)
.def_readwrite("hr", &PyClass::hr)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -10,6 +10,7 @@
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/homophone-replacer.h"
#include "sherpa-onnx/python/csrc/keyword-spotter.h"
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
... ... @@ -51,6 +52,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindAudioTagging(&m);
PybindOfflinePunctuation(&m);
PybindOnlinePunctuation(&m);
PybindHomophoneReplacer(&m);
PybindFeatures(&m);
PybindOnlineCtcFstDecoderConfig(&m);
... ...
... ... @@ -5,6 +5,7 @@ from typing import List, Optional
from _sherpa_onnx import (
FeatureExtractorConfig,
HomophoneReplacerConfig,
OfflineCtcFstDecoderConfig,
OfflineDolphinModelConfig,
OfflineFireRedAsrModelConfig,
... ... @@ -64,6 +65,9 @@ class OfflineRecognizer(object):
rule_fars: str = "",
lm: str = "",
lm_scale: float = 0.1,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -181,6 +185,11 @@ class OfflineRecognizer(object):
blank_penalty=blank_penalty,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -201,6 +210,9 @@ class OfflineRecognizer(object):
use_itn: bool = False,
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -263,6 +275,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -281,6 +298,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -336,6 +356,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -354,6 +379,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -411,6 +439,9 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir, lexicon=hr_lexicon, rule_fsts=hr_rule_fsts
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -429,6 +460,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -483,6 +517,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -501,6 +540,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -557,6 +599,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -577,6 +624,9 @@ class OfflineRecognizer(object):
tail_paddings: int = -1,
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -647,6 +697,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -664,6 +719,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -719,6 +777,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -738,6 +801,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -800,6 +866,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -818,6 +889,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -873,6 +947,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ... @@ -891,6 +970,9 @@ class OfflineRecognizer(object):
provider: str = "cpu",
rule_fsts: str = "",
rule_fars: str = "",
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -947,6 +1029,11 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ...
... ... @@ -3,25 +3,26 @@ from pathlib import Path
from typing import List, Optional
from _sherpa_onnx import (
CudaConfig,
EndpointConfig,
FeatureExtractorConfig,
HomophoneReplacerConfig,
OnlineCtcFstDecoderConfig,
OnlineLMConfig,
OnlineModelConfig,
OnlineNeMoCtcModelConfig,
OnlineParaformerModelConfig,
)
from _sherpa_onnx import OnlineRecognizer as _Recognizer
from _sherpa_onnx import (
CudaConfig,
TensorrtConfig,
ProviderConfig,
OnlineRecognizerConfig,
OnlineRecognizerResult,
OnlineStream,
OnlineTransducerModelConfig,
OnlineWenetCtcModelConfig,
OnlineNeMoCtcModelConfig,
OnlineZipformer2CtcModelConfig,
OnlineCtcFstDecoderConfig,
ProviderConfig,
TensorrtConfig,
)
... ... @@ -82,9 +83,12 @@ class OnlineRecognizer(object):
trt_detailed_build_log: bool = False,
trt_engine_cache_enable: bool = True,
trt_timing_cache_enable: bool = True,
trt_engine_cache_path: str ="",
trt_timing_cache_path: str ="",
trt_engine_cache_path: str = "",
trt_timing_cache_path: str = "",
trt_dump_subgraphs: bool = False,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -228,27 +232,27 @@ class OnlineRecognizer(object):
)
cuda_config = CudaConfig(
cudnn_conv_algo_search=cudnn_conv_algo_search,
cudnn_conv_algo_search=cudnn_conv_algo_search,
)
trt_config = TensorrtConfig(
trt_max_workspace_size=trt_max_workspace_size,
trt_max_partition_iterations=trt_max_partition_iterations,
trt_min_subgraph_size=trt_min_subgraph_size,
trt_fp16_enable=trt_fp16_enable,
trt_detailed_build_log=trt_detailed_build_log,
trt_engine_cache_enable=trt_engine_cache_enable,
trt_timing_cache_enable=trt_timing_cache_enable,
trt_engine_cache_path=trt_engine_cache_path,
trt_timing_cache_path=trt_timing_cache_path,
trt_dump_subgraphs=trt_dump_subgraphs,
trt_max_workspace_size=trt_max_workspace_size,
trt_max_partition_iterations=trt_max_partition_iterations,
trt_min_subgraph_size=trt_min_subgraph_size,
trt_fp16_enable=trt_fp16_enable,
trt_detailed_build_log=trt_detailed_build_log,
trt_engine_cache_enable=trt_engine_cache_enable,
trt_timing_cache_enable=trt_timing_cache_enable,
trt_engine_cache_path=trt_engine_cache_path,
trt_timing_cache_path=trt_timing_cache_path,
trt_dump_subgraphs=trt_dump_subgraphs,
)
provider_config = ProviderConfig(
trt_config=trt_config,
cuda_config=cuda_config,
provider=provider,
device=device,
trt_config=trt_config,
cuda_config=cuda_config,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
... ... @@ -311,6 +315,11 @@ class OnlineRecognizer(object):
rule_fsts=rule_fsts,
rule_fars=rule_fars,
reset_encoder=reset_encoder,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
... ... @@ -336,6 +345,9 @@ class OnlineRecognizer(object):
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -402,8 +414,8 @@ class OnlineRecognizer(object):
)
provider_config = ProviderConfig(
provider=provider,
device=device,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
... ... @@ -434,6 +446,11 @@ class OnlineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
... ... @@ -460,6 +477,9 @@ class OnlineRecognizer(object):
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -526,8 +546,8 @@ class OnlineRecognizer(object):
zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
provider_config = ProviderConfig(
provider=provider,
device=device,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
... ... @@ -563,6 +583,11 @@ class OnlineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
... ... @@ -587,6 +612,9 @@ class OnlineRecognizer(object):
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -650,8 +678,8 @@ class OnlineRecognizer(object):
)
provider_config = ProviderConfig(
provider=provider,
device=device,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
... ... @@ -681,6 +709,11 @@ class OnlineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
... ... @@ -707,6 +740,9 @@ class OnlineRecognizer(object):
rule_fsts: str = "",
rule_fars: str = "",
device: int = 0,
hr_dict_dir: str = "",
hr_rule_fsts: str = "",
hr_lexicon: str = "",
):
"""
Please refer to
... ... @@ -775,8 +811,8 @@ class OnlineRecognizer(object):
)
provider_config = ProviderConfig(
provider=provider,
device=device,
provider=provider,
device=device,
)
model_config = OnlineModelConfig(
... ... @@ -806,6 +842,11 @@ class OnlineRecognizer(object):
decoding_method=decoding_method,
rule_fsts=rule_fsts,
rule_fars=rule_fars,
hr=HomophoneReplacerConfig(
dict_dir=hr_dict_dir,
lexicon=hr_lexicon,
rule_fsts=hr_rule_fsts,
),
)
self.recognizer = _Recognizer(recognizer_config)
... ...
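The streaming wrapper gains the same three keyword arguments as the offline one. A minimal sketch, assuming a locally downloaded streaming transducer (the model paths are placeholders) together with the hr files used earlier:

import sherpa_onnx

# Hypothetical model paths; any streaming transducer accepted by
# OnlineRecognizer.from_transducer would do.
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens="./tokens.txt",
    encoder="./encoder.onnx",
    decoder="./decoder.onnx",
    joiner="./joiner.onnx",
    hr_dict_dir="./dict",
    hr_lexicon="./lexicon.txt",
    hr_rule_fsts="./replace.fst",
)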