// online-recognizer.cc (3.9 KB)
// sherpa-onnx/csrc/online-recognizer.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023 Pingfeng Luo
#include "sherpa-onnx/csrc/online-recognizer.h"
#include <assert.h>
#include <algorithm>
#include <iomanip>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "nlohmann/json.hpp"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
namespace sherpa_onnx {
std::string OnlineRecognizerResult::AsJsonString() const {
using json = nlohmann::json;
json j;
j["text"] = text;
j["tokens"] = tokens;
j["start_time"] = start_time;
#if 1
// This branch chooses number of decimal points to keep in
// the return json string
std::ostringstream os;
os << "[";
std::string sep = "";
for (auto t : timestamps) {
os << sep << std::fixed << std::setprecision(2) << t;
sep = ", ";
}
os << "]";
j["timestamps"] = os.str();
#else
j["timestamps"] = timestamps;
#endif
j["segment"] = segment;
j["is_final"] = is_final;
return j.dump();
}
/// Registers every command-line option of the online recognizer with `po`.
///
/// The nested feature, model, endpoint and LM configs register their own
/// options first; the recognizer-level options follow.
///
/// @param po Option parser that receives the registrations.
void OnlineRecognizerConfig::Register(ParseOptions *po) {
  feat_config.Register(po);
  model_config.Register(po);
  endpoint_config.Register(po);
  lm_config.Register(po);

  po->Register("enable-endpoint", &enable_endpoint,
               "True to enable endpoint detection. False to disable it.");

  po->Register("max-active-paths", &max_active_paths,
               "beam size used in modified beam search.");

  po->Register("context-score", &context_score,
               "The bonus score for each token in context word/phrase. "
               "Used only when decoding_method is modified_beam_search");

  // Adjacent string literals are concatenated verbatim, so the separating
  // space has to be part of one of the pieces.
  po->Register("decoding-method", &decoding_method,
               "decoding method, "
               "now support greedy_search and modified_beam_search.");
}
/// Checks whether this configuration is usable.
///
/// For decoding_method == "modified_beam_search":
///  - max_active_paths (the beam size) must be greater than 0, whether or not
///    an LM model is configured;
///  - lm_config is validated only when lm_config.model is non-empty.
/// In every case the result finally depends on model_config.Validate().
///
/// @return true if all applicable checks pass, false otherwise. An error is
///         logged when max_active_paths is rejected.
bool OnlineRecognizerConfig::Validate() const {
  if (decoding_method == "modified_beam_search") {
    // The beam size is used by modified beam search itself, so it is checked
    // independently of whether an LM has been configured.
    if (max_active_paths <= 0) {
      SHERPA_ONNX_LOGE("max_active_paths must be greater than 0! Given: %d",
                       max_active_paths);
      return false;
    }

    // An empty model path means no LM was requested; skip its validation.
    if (!lm_config.model.empty() && !lm_config.Validate()) return false;
  }

  return model_config.Validate();
}
/// Builds a one-line, human-readable description of this configuration.
///
/// The text has the form
///   OnlineRecognizerConfig(feat_config=..., model_config=..., lm_config=...,
///   endpoint_config=..., enable_endpoint=True|False, max_active_paths=...,
///   context_score=..., decoding_method="...")
/// where the nested configs are rendered by their own ToString().
///
/// @return The formatted description.
std::string OnlineRecognizerConfig::ToString() const {
  // Render the nested configs up front, in the order they are printed.
  const auto feat_text = feat_config.ToString();
  const auto model_text = model_config.ToString();
  const auto lm_text = lm_config.ToString();
  const auto endpoint_text = endpoint_config.ToString();
  const char *endpoint_flag = enable_endpoint ? "True" : "False";

  std::ostringstream out;
  out << "OnlineRecognizerConfig("
      << "feat_config=" << feat_text << ", "
      << "model_config=" << model_text << ", "
      << "lm_config=" << lm_text << ", "
      << "endpoint_config=" << endpoint_text << ", "
      << "enable_endpoint=" << endpoint_flag << ", "
      << "max_active_paths=" << max_active_paths << ", "
      << "context_score=" << context_score << ", "
      << "decoding_method=\"" << decoding_method << "\")";
  return out.str();
}
// Constructs a recognizer from `config`. The implementation object is
// obtained from OnlineRecognizerImpl::Create(config) and stored in impl_;
// the member functions below forward to it.
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
: impl_(OnlineRecognizerImpl::Create(config)) {}
#if __ANDROID_API__ >= 9
// Android-only overload, compiled when __ANDROID_API__ >= 9.
// Passes both `mgr` and `config` to OnlineRecognizerImpl::Create(mgr, config)
// and stores the result in impl_.
// NOTE(review): presumably `mgr` lets the implementation read model files
// from the application's assets -- confirm against OnlineRecognizerImpl.
OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr,
const OnlineRecognizerConfig &config)
: impl_(OnlineRecognizerImpl::Create(mgr, config)) {}
#endif
// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably kept here (rather than in the header) so that the
// type held by impl_ is complete where it is destroyed -- confirm against
// online-recognizer.h.
OnlineRecognizer::~OnlineRecognizer() = default;
// Creates a new stream by forwarding to impl_->CreateStream().
// Ownership of the returned stream passes to the caller via std::unique_ptr.
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
return impl_->CreateStream();
}
// Creates a new stream with an associated context list by forwarding to
// impl_->CreateStream(context_list).
// NOTE(review): each inner vector is presumably the token-id sequence of one
// context word/phrase (cf. the "context-score" option) -- confirm with the
// implementation.
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
const std::vector<std::vector<int32_t>> &context_list) const {
return impl_->CreateStream(context_list);
}
// Returns whatever impl_->IsReady(s) reports for stream `s`.
// NOTE(review): presumably true means `s` has enough data to be decoded --
// confirm with the implementation.
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
// Forwards to impl_->DecodeStreams(ss, n).
// NOTE(review): `ss` presumably points to an array of `n` stream pointers
// that are decoded together -- confirm with the implementation.
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
// Returns the OnlineRecognizerResult that impl_->GetResult(s) produces for
// stream `s`.
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) const {
return impl_->GetResult(s);
}
// Returns whatever impl_->IsEndpoint(s) reports for stream `s`.
// NOTE(review): presumably tied to the endpoint detection controlled by the
// "enable-endpoint" option -- confirm with the implementation.
bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const {
return impl_->IsEndpoint(s);
}
// Forwards to impl_->Reset(s) for stream `s`.
void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); }
} // namespace sherpa_onnx