Fangjun Kuang
Committed by GitHub

Fix style issues for online punctuation source files (#1225)

@@ -58,6 +58,7 @@ def get_binaries(): @@ -58,6 +58,7 @@ def get_binaries():
58 "sherpa-onnx-offline-tts", 58 "sherpa-onnx-offline-tts",
59 "sherpa-onnx-offline-tts-play", 59 "sherpa-onnx-offline-tts-play",
60 "sherpa-onnx-offline-websocket-server", 60 "sherpa-onnx-offline-websocket-server",
  61 + "sherpa-onnx-online-punctuation",
61 "sherpa-onnx-online-websocket-client", 62 "sherpa-onnx-online-websocket-client",
62 "sherpa-onnx-online-websocket-server", 63 "sherpa-onnx-online-websocket-server",
63 "sherpa-onnx-vad-microphone", 64 "sherpa-onnx-vad-microphone",
@@ -35,8 +35,11 @@ class OnlineCNNBiLSTMModel::Impl { @@ -35,8 +35,11 @@ class OnlineCNNBiLSTMModel::Impl {
35 } 35 }
36 #endif 36 #endif
37 37
38 - std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) {  
39 - std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)}; 38 + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids,
  39 + Ort::Value valid_ids,
  40 + Ort::Value label_lens) {
  41 + std::array<Ort::Value, 3> inputs = {
  42 + std::move(token_ids), std::move(valid_ids), std::move(label_lens)};
40 43
41 auto ans = 44 auto ans =
42 sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), 45 sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
@@ -117,18 +120,18 @@ OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( @@ -117,18 +120,18 @@ OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
117 120
118 OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default; 121 OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default;
119 122
120 -std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids,  
121 - Ort::Value valid_ids,  
122 - Ort::Value label_lens) const {  
123 - return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens)); 123 +std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(
  124 + Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const {
  125 + return impl_->Forward(std::move(token_ids), std::move(valid_ids),
  126 + std::move(label_lens));
124 } 127 }
125 128
126 OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const { 129 OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const {
127 return impl_->Allocator(); 130 return impl_->Allocator();
128 } 131 }
129 132
130 -const OnlineCNNBiLSTMModelMetaData &  
131 -OnlineCNNBiLSTMModel::GetModelMetadata() const { 133 +const OnlineCNNBiLSTMModelMetaData &OnlineCNNBiLSTMModel::GetModelMetadata()
  134 + const {
132 return impl_->GetModelMetadata(); 135 return impl_->GetModelMetadata();
133 } 136 }
134 137
@@ -23,12 +23,11 @@ namespace sherpa_onnx { @@ -23,12 +23,11 @@ namespace sherpa_onnx {
23 */ 23 */
24 class OnlineCNNBiLSTMModel { 24 class OnlineCNNBiLSTMModel {
25 public: 25 public:
26 - explicit OnlineCNNBiLSTMModel(  
27 - const OnlinePunctuationModelConfig &config); 26 + explicit OnlineCNNBiLSTMModel(const OnlinePunctuationModelConfig &config);
28 27
29 #if __ANDROID_API__ >= 9 28 #if __ANDROID_API__ >= 9
30 OnlineCNNBiLSTMModel(AAssetManager *mgr, 29 OnlineCNNBiLSTMModel(AAssetManager *mgr,
31 - const OnlinePunctuationModelConfig &config); 30 + const OnlinePunctuationModelConfig &config);
32 #endif 31 #endif
33 32
34 ~OnlineCNNBiLSTMModel(); 33 ~OnlineCNNBiLSTMModel();
@@ -43,7 +42,9 @@ class OnlineCNNBiLSTMModel { @@ -43,7 +42,9 @@ class OnlineCNNBiLSTMModel {
43 * - case_logits: A 2-D tensor of shape (T', num_cases). 42 * - case_logits: A 2-D tensor of shape (T', num_cases).
44 * - punct_logits: A 2-D tensor of shape (T', num_puncts). 43 * - punct_logits: A 2-D tensor of shape (T', num_puncts).
45 */ 44 */
46 - std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const; 45 + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids,
  46 + Ort::Value valid_ids,
  47 + Ort::Value label_lens) const;
47 48
48 /** Return an allocator for allocating memory 49 /** Return an allocator for allocating memory
49 */ 50 */
@@ -7,27 +7,28 @@ @@ -7,27 +7,28 @@
7 7
8 #include <math.h> 8 #include <math.h>
9 9
  10 +#include <algorithm>
10 #include <memory> 11 #include <memory>
11 #include <string> 12 #include <string>
12 #include <utility> 13 #include <utility>
13 #include <vector> 14 #include <vector>
14 -#include <algorithm>  
15 15
16 #if __ANDROID_API__ >= 9 16 #if __ANDROID_API__ >= 9
17 #include "android/asset_manager.h" 17 #include "android/asset_manager.h"
18 #include "android/asset_manager_jni.h" 18 #include "android/asset_manager_jni.h"
19 #endif 19 #endif
20 20
  21 +#include <chrono> // NOLINT
  22 +
21 #include "sherpa-onnx/csrc/macros.h" 23 #include "sherpa-onnx/csrc/macros.h"
22 #include "sherpa-onnx/csrc/math.h" 24 #include "sherpa-onnx/csrc/math.h"
  25 +#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
23 #include "sherpa-onnx/csrc/online-cnn-bilstm-model.h" 26 #include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"
24 #include "sherpa-onnx/csrc/online-punctuation-impl.h" 27 #include "sherpa-onnx/csrc/online-punctuation-impl.h"
25 #include "sherpa-onnx/csrc/online-punctuation.h" 28 #include "sherpa-onnx/csrc/online-punctuation.h"
26 -#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"  
27 -#include "sherpa-onnx/csrc/text-utils.h"  
28 #include "sherpa-onnx/csrc/onnx-utils.h" 29 #include "sherpa-onnx/csrc/onnx-utils.h"
  30 +#include "sherpa-onnx/csrc/text-utils.h"
29 #include "ssentencepiece/csrc/ssentencepiece.h" 31 #include "ssentencepiece/csrc/ssentencepiece.h"
30 -#include <chrono> // NOLINT  
31 32
32 namespace sherpa_onnx { 33 namespace sherpa_onnx {
33 34
@@ -35,25 +36,24 @@ static const int32_t kMaxSeqLen = 200; @@ -35,25 +36,24 @@ static const int32_t kMaxSeqLen = 200;
35 36
36 class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { 37 class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
37 public: 38 public:
38 - explicit OnlinePunctuationCNNBiLSTMImpl(  
39 - const OnlinePunctuationConfig &config) 39 + explicit OnlinePunctuationCNNBiLSTMImpl(const OnlinePunctuationConfig &config)
40 : config_(config), model_(config.model) { 40 : config_(config), model_(config.model) {
41 - if (!config_.model.bpe_vocab.empty()) {  
42 - bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(  
43 - config_.model.bpe_vocab);  
44 - }  
45 - } 41 + if (!config_.model.bpe_vocab.empty()) {
  42 + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
  43 + config_.model.bpe_vocab);
  44 + }
  45 + }
46 46
47 #if __ANDROID_API__ >= 9 47 #if __ANDROID_API__ >= 9
48 OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr, 48 OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr,
49 - const OnlinePunctuationConfig &config) 49 + const OnlinePunctuationConfig &config)
50 : config_(config), model_(mgr, config.model) { 50 : config_(config), model_(mgr, config.model) {
51 - if (!config_.model.bpe_vocab.empty()) {  
52 - auto buf = ReadFile(mgr, config_.model.bpe_vocab);  
53 - std::istringstream iss(std::string(buf.begin(), buf.end()));  
54 - bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);  
55 - }  
56 - } 51 + if (!config_.model.bpe_vocab.empty()) {
  52 + auto buf = ReadFile(mgr, config_.model.bpe_vocab);
  53 + std::istringstream iss(std::string(buf.begin(), buf.end()));
  54 + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
  55 + }
  56 + }
57 #endif 57 #endif
58 58
59 std::string AddPunctuationWithCase(const std::string &text) const override { 59 std::string AddPunctuationWithCase(const std::string &text) const override {
@@ -61,9 +61,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { @@ -61,9 +61,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
61 return {}; 61 return {};
62 } 62 }
63 63
64 - std::vector<int32_t> tokens_list; // N * kMaxSeqLen  
65 - std::vector<int32_t> valids_list; // N * kMaxSeqLen  
66 - std::vector<int32_t> label_len_list; // N 64 + std::vector<int32_t> tokens_list; // N * kMaxSeqLen
  65 + std::vector<int32_t> valids_list; // N * kMaxSeqLen
  66 + std::vector<int32_t> label_len_list; // N
67 67
68 EncodeSentences(text, tokens_list, valids_list, label_len_list); 68 EncodeSentences(text, tokens_list, valids_list, label_len_list);
69 69
@@ -75,34 +75,43 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { @@ -75,34 +75,43 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
75 int32_t n = label_len_list.size(); 75 int32_t n = label_len_list.size();
76 76
77 std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen}; 77 std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen};
78 - Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(),  
79 - token_ids_shape.data(), token_ids_shape.size()); 78 + Ort::Value token_ids = Ort::Value::CreateTensor(
  79 + memory_info, tokens_list.data(), tokens_list.size(),
  80 + token_ids_shape.data(), token_ids_shape.size());
80 81
81 std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen}; 82 std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen};
82 - Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(),  
83 - valid_ids_shape.data(), valid_ids_shape.size()); 83 + Ort::Value valid_ids = Ort::Value::CreateTensor(
  84 + memory_info, valids_list.data(), valids_list.size(),
  85 + valid_ids_shape.data(), valid_ids_shape.size());
84 86
85 std::array<int64_t, 1> label_len_shape = {n}; 87 std::array<int64_t, 1> label_len_shape = {n};
86 - Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(),  
87 - label_len_shape.data(), label_len_shape.size()); 88 + Ort::Value label_len = Ort::Value::CreateTensor(
  89 + memory_info, label_len_list.data(), label_len_list.size(),
  90 + label_len_shape.data(), label_len_shape.size());
88 91
89 - auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len)); 92 + auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids),
  93 + std::move(label_len));
90 94
91 std::vector<int32_t> case_pred; 95 std::vector<int32_t> case_pred;
92 std::vector<int32_t> punct_pred; 96 std::vector<int32_t> punct_pred;
93 - const float* active_case_logits = pair.first.GetTensorData<float>();  
94 - const float* active_punct_logits = pair.second.GetTensorData<float>();  
95 - std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape(); 97 + const float *active_case_logits = pair.first.GetTensorData<float>();
  98 + const float *active_punct_logits = pair.second.GetTensorData<float>();
  99 + std::vector<int64_t> case_logits_shape =
  100 + pair.first.GetTensorTypeAndShapeInfo().GetShape();
96 101
97 for (int32_t i = 0; i < case_logits_shape[0]; ++i) { 102 for (int32_t i = 0; i < case_logits_shape[0]; ++i) {
98 - const float* p_cur_case = active_case_logits + i * meta_data.num_cases; 103 + const float *p_cur_case = active_case_logits + i * meta_data.num_cases;
99 auto index_case = static_cast<int32_t>(std::distance( 104 auto index_case = static_cast<int32_t>(std::distance(
100 - p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); 105 + p_cur_case,
  106 + std::max_element(p_cur_case, p_cur_case + meta_data.num_cases)));
101 case_pred.push_back(index_case); 107 case_pred.push_back(index_case);
102 108
103 - const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations; 109 + const float *p_cur_punct =
  110 + active_punct_logits + i * meta_data.num_punctuations;
104 auto index_punct = static_cast<int32_t>(std::distance( 111 auto index_punct = static_cast<int32_t>(std::distance(
105 - p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations))); 112 + p_cur_punct,
  113 + std::max_element(p_cur_punct,
  114 + p_cur_punct + meta_data.num_punctuations)));
106 punct_pred.push_back(index_punct); 115 punct_pred.push_back(index_punct);
107 } 116 }
108 117
@@ -112,60 +121,60 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { @@ -112,60 +121,60 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
112 } 121 }
113 122
114 private: 123 private:
115 - void EncodeSentences(const std::string& text,  
116 - std::vector<int32_t>& tokens_list,  
117 - std::vector<int32_t>& valids_list,  
118 - std::vector<int32_t>& label_len_list) const { 124 + void EncodeSentences(const std::string &text,
  125 + std::vector<int32_t> &tokens_list, // NOLINT
  126 + std::vector<int32_t> &valids_list, // NOLINT
  127 + std::vector<int32_t> &label_len_list) const { // NOLINT
119 std::vector<int32_t> tokens; 128 std::vector<int32_t> tokens;
120 std::vector<int32_t> valids; 129 std::vector<int32_t> valids;
121 int32_t label_len = 0; 130 int32_t label_len = 0;
122 131
123 - tokens.push_back(1); // hardcode 1 now, 1 - <s> 132 + tokens.push_back(1); // hardcode 1 now, 1 - <s>
124 valids.push_back(1); 133 valids.push_back(1);
125 134
126 std::stringstream ss(text); 135 std::stringstream ss(text);
127 std::string word; 136 std::string word;
128 while (ss >> word) { 137 while (ss >> word) {
129 - std::vector<int32_t> word_tokens;  
130 - bpe_encoder_->Encode(word, &word_tokens); 138 + std::vector<int32_t> word_tokens;
  139 + bpe_encoder_->Encode(word, &word_tokens);
131 140
132 - int32_t seq_len = tokens.size() + word_tokens.size();  
133 - if (seq_len > kMaxSeqLen - 1) {  
134 - tokens.push_back(2); // hardcode 2 now, 2 - </s>  
135 - valids.push_back(1); 141 + int32_t seq_len = tokens.size() + word_tokens.size();
  142 + if (seq_len > kMaxSeqLen - 1) {
  143 + tokens.push_back(2); // hardcode 2 now, 2 - </s>
  144 + valids.push_back(1);
136 145
137 - label_len = std::count(valids.begin(), valids.end(), 1); 146 + label_len = std::count(valids.begin(), valids.end(), 1);
138 147
139 - if (tokens.size() < kMaxSeqLen) {  
140 - tokens.resize(kMaxSeqLen, 0);  
141 - valids.resize(kMaxSeqLen, 0);  
142 - } 148 + if (tokens.size() < kMaxSeqLen) {
  149 + tokens.resize(kMaxSeqLen, 0);
  150 + valids.resize(kMaxSeqLen, 0);
  151 + }
143 152
144 - assert(tokens.size() == kMaxSeqLen);  
145 - assert(valids.size() == kMaxSeqLen); 153 + assert(tokens.size() == kMaxSeqLen);
  154 + assert(valids.size() == kMaxSeqLen);
146 155
147 - tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());  
148 - valids_list.insert(valids_list.end(), valids.begin(), valids.end());  
149 - label_len_list.push_back(label_len); 156 + tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
  157 + valids_list.insert(valids_list.end(), valids.begin(), valids.end());
  158 + label_len_list.push_back(label_len);
150 159
151 - std::vector<int32_t>().swap(tokens);  
152 - std::vector<int32_t>().swap(valids);  
153 - label_len = 0;  
154 - tokens.push_back(1); // hardcode 1 now, 1 - <s>  
155 - valids.push_back(1);  
156 - } 160 + std::vector<int32_t>().swap(tokens);
  161 + std::vector<int32_t>().swap(valids);
  162 + label_len = 0;
  163 + tokens.push_back(1); // hardcode 1 now, 1 - <s>
  164 + valids.push_back(1);
  165 + }
157 166
158 - tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end());  
159 - valids.push_back(1); // only the first sub word is valid  
160 - int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1;  
161 - if (remaining_size > 0) {  
162 - int32_t valids_cur_size = static_cast<int32_t>(valids.size());  
163 - valids.resize(valids_cur_size + remaining_size, 0);  
164 - } 167 + tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end());
  168 + valids.push_back(1); // only the first sub word is valid
  169 + int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1;
  170 + if (remaining_size > 0) {
  171 + int32_t valids_cur_size = static_cast<int32_t>(valids.size());
  172 + valids.resize(valids_cur_size + remaining_size, 0);
  173 + }
165 } 174 }
166 175
167 if (tokens.size() > 0) { 176 if (tokens.size() > 0) {
168 - tokens.push_back(2); // hardcode 2 now, 2 - </s> 177 + tokens.push_back(2); // hardcode 2 now, 2 - </s>
169 valids.push_back(1); 178 valids.push_back(1);
170 179
171 label_len = std::count(valids.begin(), valids.end(), 1); 180 label_len = std::count(valids.begin(), valids.end(), 1);
@@ -176,17 +185,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { @@ -176,17 +185,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
176 } 185 }
177 186
178 assert(tokens.size() == kMaxSeqLen); 187 assert(tokens.size() == kMaxSeqLen);
179 - assert(valids.size() == kMaxSeqLen); 188 + assert(valids.size() == kMaxSeqLen);
180 189
181 tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); 190 tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
182 valids_list.insert(valids_list.end(), valids.begin(), valids.end()); 191 valids_list.insert(valids_list.end(), valids.begin(), valids.end());
183 label_len_list.push_back(label_len); 192 label_len_list.push_back(label_len);
184 - } 193 + }
185 } 194 }
186 195
187 - std::string DecodeSentences(const std::string& raw_text,  
188 - const std::vector<int32_t>& case_pred,  
189 - const std::vector<int32_t>& punct_pred) const { 196 + std::string DecodeSentences(const std::string &raw_text,
  197 + const std::vector<int32_t> &case_pred,
  198 + const std::vector<int32_t> &punct_pred) const {
190 std::string result_text; 199 std::string result_text;
191 std::istringstream iss(raw_text); 200 std::istringstream iss(raw_text);
192 std::vector<std::string> words; 201 std::vector<std::string> words;
@@ -203,28 +212,29 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { @@ -203,28 +212,29 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
203 std::string prefix = ((i != 0) ? " " : ""); 212 std::string prefix = ((i != 0) ? " " : "");
204 result_text += prefix; 213 result_text += prefix;
205 switch (case_pred[i]) { 214 switch (case_pred[i]) {
206 - case 1: // upper 215 + case 1: // upper
207 { 216 {
208 - std::transform(words[i].begin(), words[i].end(), words[i].begin(), [](auto c){ return std::toupper(c); }); 217 + std::transform(words[i].begin(), words[i].end(), words[i].begin(),
  218 + [](auto c) { return std::toupper(c); });
209 result_text += words[i]; 219 result_text += words[i];
210 break; 220 break;
211 } 221 }
212 - case 2: // cap 222 + case 2: // cap
213 { 223 {
214 words[i][0] = std::toupper(words[i][0]); 224 words[i][0] = std::toupper(words[i][0]);
215 result_text += words[i]; 225 result_text += words[i];
216 break; 226 break;
217 } 227 }
218 - case 3: // mix case 228 + case 3: // mix case
219 { 229 {
220 - // TODO:  
221 - // Need to add a map containing supported mix case words so that we can fetch the predicted word from the map  
222 - // e.g. mcdonald's -> McDonald's 230 + // TODO(frankyoujian):
  231 + // Need to add a map containing supported mix case words so that we
  232 + // can fetch the predicted word from the map e.g. mcdonald's ->
  233 + // McDonald's
223 result_text += words[i]; 234 result_text += words[i];
224 break; 235 break;
225 } 236 }
226 - default:  
227 - { 237 + default: {
228 result_text += words[i]; 238 result_text += words[i];
229 break; 239 break;
230 } 240 }
@@ -232,17 +242,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { @@ -232,17 +242,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
232 242
233 std::string suffix; 243 std::string suffix;
234 switch (punct_pred[i]) { 244 switch (punct_pred[i]) {
235 - case 1: // comma 245 + case 1: // comma
236 { 246 {
237 suffix = ","; 247 suffix = ",";
238 break; 248 break;
239 } 249 }
240 - case 2: // period 250 + case 2: // period
241 { 251 {
242 suffix = "."; 252 suffix = ".";
243 break; 253 break;
244 } 254 }
245 - case 3: // question 255 + case 3: // question
246 { 256 {
247 suffix = "?"; 257 suffix = "?";
248 break; 258 break;
@@ -252,9 +262,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { @@ -252,9 +262,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
252 } 262 }
253 263
254 result_text += suffix; 264 result_text += suffix;
255 - } 265 + }
256 266
257 - return result_text; 267 + return result_text;
258 } 268 }
259 269
260 private: 270 private:
@@ -20,7 +20,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create( @@ -20,7 +20,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
20 return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(config); 20 return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(config);
21 } 21 }
22 22
23 - SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer"); 23 + SHERPA_ONNX_LOGE(
  24 + "Please specify a punctuation model and bpe vocab! Return a null "
  25 + "pointer");
24 return nullptr; 26 return nullptr;
25 } 27 }
26 28
@@ -31,7 +33,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create( @@ -31,7 +33,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
31 return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(mgr, config); 33 return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(mgr, config);
32 } 34 }
33 35
34 - SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer"); 36 + SHERPA_ONNX_LOGE(
  37 + "Please specify a punctuation model and bpe vocab! Return a null "
  38 + "pointer");
35 return nullptr; 39 return nullptr;
36 } 40 }
37 #endif 41 #endif
@@ -13,8 +13,7 @@ void OnlinePunctuationModelConfig::Register(ParseOptions *po) { @@ -13,8 +13,7 @@ void OnlinePunctuationModelConfig::Register(ParseOptions *po) {
13 po->Register("cnn-bilstm", &cnn_bilstm, 13 po->Register("cnn-bilstm", &cnn_bilstm,
14 "Path to the light-weight CNN-BiLSTM model"); 14 "Path to the light-weight CNN-BiLSTM model");
15 15
16 - po->Register("bpe-vocab", &bpe_vocab,  
17 - "Path to the bpe vocab file"); 16 + po->Register("bpe-vocab", &bpe_vocab, "Path to the bpe vocab file");
18 17
19 po->Register("num-threads", &num_threads, 18 po->Register("num-threads", &num_threads,
20 "Number of threads to run the neural network"); 19 "Number of threads to run the neural network");
@@ -33,8 +32,7 @@ bool OnlinePunctuationModelConfig::Validate() const { @@ -33,8 +32,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
33 } 32 }
34 33
35 if (!FileExists(cnn_bilstm)) { 34 if (!FileExists(cnn_bilstm)) {
36 - SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist",  
37 - cnn_bilstm.c_str()); 35 + SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", cnn_bilstm.c_str());
38 return false; 36 return false;
39 } 37 }
40 38
@@ -44,8 +42,7 @@ bool OnlinePunctuationModelConfig::Validate() const { @@ -44,8 +42,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
44 } 42 }
45 43
46 if (!FileExists(bpe_vocab)) { 44 if (!FileExists(bpe_vocab)) {
47 - SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist",  
48 - bpe_vocab.c_str()); 45 + SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", bpe_vocab.c_str());
49 return false; 46 return false;
50 } 47 }
51 48
@@ -22,9 +22,9 @@ struct OnlinePunctuationModelConfig { @@ -22,9 +22,9 @@ struct OnlinePunctuationModelConfig {
22 OnlinePunctuationModelConfig() = default; 22 OnlinePunctuationModelConfig() = default;
23 23
24 OnlinePunctuationModelConfig(const std::string &cnn_bilstm, 24 OnlinePunctuationModelConfig(const std::string &cnn_bilstm,
25 - const std::string &bpe_vocab,  
26 - int32_t num_threads, bool debug,  
27 - const std::string &provider) 25 + const std::string &bpe_vocab,
  26 + int32_t num_threads, bool debug,
  27 + const std::string &provider)
28 : cnn_bilstm(cnn_bilstm), 28 : cnn_bilstm(cnn_bilstm),
29 bpe_vocab(bpe_vocab), 29 bpe_vocab(bpe_vocab),
30 num_threads(num_threads), 30 num_threads(num_threads),
@@ -14,9 +14,7 @@ @@ -14,9 +14,7 @@
14 14
15 namespace sherpa_onnx { 15 namespace sherpa_onnx {
16 16
17 -void OnlinePunctuationConfig::Register(ParseOptions *po) {  
18 - model.Register(po);  
19 -} 17 +void OnlinePunctuationConfig::Register(ParseOptions *po) { model.Register(po); }
20 18
21 bool OnlinePunctuationConfig::Validate() const { 19 bool OnlinePunctuationConfig::Validate() const {
22 if (!model.Validate()) { 20 if (!model.Validate()) {
@@ -40,13 +38,14 @@ OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config) @@ -40,13 +38,14 @@ OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config)
40 38
41 #if __ANDROID_API__ >= 9 39 #if __ANDROID_API__ >= 9
42 OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr, 40 OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr,
43 - const OnlinePunctuationConfig &config) 41 + const OnlinePunctuationConfig &config)
44 : impl_(OnlinePunctuationImpl::Create(mgr, config)) {} 42 : impl_(OnlinePunctuationImpl::Create(mgr, config)) {}
45 #endif 43 #endif
46 44
47 OnlinePunctuation::~OnlinePunctuation() = default; 45 OnlinePunctuation::~OnlinePunctuation() = default;
48 46
49 -std::string OnlinePunctuation::AddPunctuationWithCase(const std::string &text) const { 47 +std::string OnlinePunctuation::AddPunctuationWithCase(
  48 + const std::string &text) const {
50 return impl_->AddPunctuationWithCase(text); 49 return impl_->AddPunctuationWithCase(text);
51 } 50 }
52 51
@@ -40,8 +40,7 @@ class OnlinePunctuation { @@ -40,8 +40,7 @@ class OnlinePunctuation {
40 explicit OnlinePunctuation(const OnlinePunctuationConfig &config); 40 explicit OnlinePunctuation(const OnlinePunctuationConfig &config);
41 41
42 #if __ANDROID_API__ >= 9 42 #if __ANDROID_API__ >= 9
43 - OnlinePunctuation(AAssetManager *mgr,  
44 - const OnlinePunctuationConfig &config); 43 + OnlinePunctuation(AAssetManager *mgr, const OnlinePunctuationConfig &config);
45 #endif 44 #endif
46 45
47 ~OnlinePunctuation(); 46 ~OnlinePunctuation();
@@ -3,9 +3,9 @@ @@ -3,9 +3,9 @@
3 // Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) 3 // Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
4 4
5 #include <stdio.h> 5 #include <stdio.h>
6 -#include <iostream>  
7 6
8 #include <chrono> // NOLINT 7 #include <chrono> // NOLINT
  8 +#include <iostream>
9 9
10 #include "sherpa-onnx/csrc/online-punctuation.h" 10 #include "sherpa-onnx/csrc/online-punctuation.h"
11 #include "sherpa-onnx/csrc/parse-options.h" 11 #include "sherpa-onnx/csrc/parse-options.h"
@@ -57,7 +57,7 @@ The output text should look like below: @@ -57,7 +57,7 @@ The output text should look like below:
57 std::string text = po.GetArg(1); 57 std::string text = po.GetArg(1);
58 58
59 std::string text_with_punct_case = punct.AddPunctuationWithCase(text); 59 std::string text_with_punct_case = punct.AddPunctuationWithCase(text);
60 - 60 +
61 const auto end = std::chrono::steady_clock::now(); 61 const auto end = std::chrono::steady_clock::now();
62 fprintf(stderr, "Done\n"); 62 fprintf(stderr, "Done\n");
63 63