offline-recognizer.cc
5.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
// sherpa-onnx/csrc/offline-recognizer.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
// Registers every command-line option of OfflineRecognizerConfig with `po`,
// including the options of the nested feature, model, LM and CTC FST decoder
// configs.
//
// Note: the help texts below are built from adjacent string literals, which
// the compiler concatenates verbatim, so each fragment must carry its own
// trailing space/punctuation.
void OfflineRecognizerConfig::Register(ParseOptions *po) {
  feat_config.Register(po);
  model_config.Register(po);
  lm_config.Register(po);
  ctc_fst_decoder_config.Register(po);

  po->Register(
      "decoding-method", &decoding_method,
      "Decoding method. "
      "Valid values: greedy_search, modified_beam_search. "
      "modified_beam_search is applicable only for transducer models.");

  po->Register("max-active-paths", &max_active_paths,
               "Used only when decoding_method is modified_beam_search");

  po->Register("blank-penalty", &blank_penalty,
               "The penalty applied on blank symbol during decoding. "
               "Note: It is a positive value. "
               "Increasing value will lead to fewer deletions at the cost "
               "of more insertions. "
               "Currently only applicable for transducer models.");

  po->Register("hotwords-file", &hotwords_file,
               "The file containing hotwords, one word/phrase per line. "
               "For example, a file with the two lines "
               "'HELLO WORLD' and '你好世界'.");

  po->Register("hotwords-score", &hotwords_score,
               "The bonus score for each token in context word/phrase. "
               "Used only when decoding_method is modified_beam_search");

  po->Register(
      "rule-fsts", &rule_fsts,
      "If not empty, it specifies fsts for inverse text normalization. "
      "If there are multiple fsts, they are separated by a comma.");

  po->Register(
      "rule-fars", &rule_fars,
      "If not empty, it specifies fst archives for inverse text normalization. "
      "If there are multiple archives, they are separated by a comma.");
}
// Checks whether this configuration can be used to build a recognizer.
//
// Returns false as soon as one check fails. Checks performed directly in
// this function log the reason via SHERPA_ONNX_LOGE; the nested Validate()
// calls (lm_config, ctc_fst_decoder_config, model_config) are responsible for
// any messages of their own. If every other check passes, the result of
// model_config.Validate() is returned.
bool OfflineRecognizerConfig::Validate() const {
  if (decoding_method == "modified_beam_search") {
    // A non-positive number of active paths is meaningless for beam search,
    // whether or not an external LM is used, so check it unconditionally.
    if (max_active_paths <= 0) {
      SHERPA_ONNX_LOGE("max_active_paths should be positive! Given: %d",
                       max_active_paths);
      return false;
    }

    // The LM config is only validated when an LM model has been provided.
    if (!lm_config.model.empty() && !lm_config.Validate()) {
      return false;
    }
  }

  // Hotwords are only accepted together with modified_beam_search.
  if (!hotwords_file.empty() && decoding_method != "modified_beam_search") {
    SHERPA_ONNX_LOGE(
        "Please use --decoding-method=modified_beam_search if you"
        " provide --hotwords-file. Given --decoding-method='%s'",
        decoding_method.c_str());
    return false;
  }

  if (!ctc_fst_decoder_config.graph.empty() &&
      !ctc_fst_decoder_config.Validate()) {
    SHERPA_ONNX_LOGE("Errors in fst_decoder");
    return false;
  }

  if (!hotwords_file.empty() && !FileExists(hotwords_file)) {
    SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist",
                     hotwords_file.c_str());
    return false;
  }

  // rule_fsts / rule_fars are comma-separated lists; every entry must exist.
  if (!rule_fsts.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(rule_fsts, ",", false, &files);
    for (const auto &f : files) {
      if (!FileExists(f)) {
        SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str());
        return false;
      }
    }
  }

  if (!rule_fars.empty()) {
    std::vector<std::string> files;
    SplitStringToVector(rule_fars, ",", false, &files);
    for (const auto &f : files) {
      if (!FileExists(f)) {
        SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str());
        return false;
      }
    }
  }

  return model_config.Validate();
}
// Renders the whole configuration as a single human-readable line of the
// form `OfflineRecognizerConfig(field=value, ...)`. Nested configs are
// rendered through their own ToString(); string-valued fields are wrapped in
// double quotes.
std::string OfflineRecognizerConfig::ToString() const {
  std::ostringstream out;

  out << "OfflineRecognizerConfig("
      << "feat_config=" << feat_config.ToString() << ", "
      << "model_config=" << model_config.ToString() << ", "
      << "lm_config=" << lm_config.ToString() << ", "
      << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString()
      << ", "
      << "decoding_method=\"" << decoding_method << "\", "
      << "max_active_paths=" << max_active_paths << ", "
      << "hotwords_file=\"" << hotwords_file << "\", "
      << "hotwords_score=" << hotwords_score << ", "
      << "blank_penalty=" << blank_penalty << ", "
      << "rule_fsts=\"" << rule_fsts << "\", "
      << "rule_fars=\"" << rule_fars << "\")";

  return out.str();
}
template <typename Manager>
OfflineRecognizer::OfflineRecognizer(Manager *mgr,
const OfflineRecognizerConfig &config)
: impl_(OfflineRecognizerImpl::Create(mgr, config)) {}
OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
: impl_(OfflineRecognizerImpl::Create(config)) {}
OfflineRecognizer::~OfflineRecognizer() = default;
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream(
const std::string &hotwords) const {
return impl_->CreateStream(hotwords);
}
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
return impl_->CreateStream();
}
void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) {
impl_->SetConfig(config);
}
OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {
return impl_->GetConfig();
}
// Explicit instantiations of the resource-manager constructor template.
// The template is defined in this .cc file, so each supported Manager type
// must be instantiated here. Otherwise callers in other translation units
// would fail to link.
#if __ANDROID_API__ >= 9
// Android builds: constructor taking an AAssetManager
// (android/asset_manager.h).
template OfflineRecognizer::OfflineRecognizer(
AAssetManager *mgr, const OfflineRecognizerConfig &config);
#endif
#if __OHOS__
// OHOS builds: constructor taking a NativeResourceManager
// (rawfile/raw_file_manager.h).
template OfflineRecognizer::OfflineRecognizer(
NativeResourceManager *mgr, const OfflineRecognizerConfig &config);
#endif
} // namespace sherpa_onnx