Committed by
GitHub
Fix style issues for online punctuation source files (#1225)
正在显示
10 个修改的文件
包含
135 行增加
和
121 行删除
| @@ -58,6 +58,7 @@ def get_binaries(): | @@ -58,6 +58,7 @@ def get_binaries(): | ||
| 58 | "sherpa-onnx-offline-tts", | 58 | "sherpa-onnx-offline-tts", |
| 59 | "sherpa-onnx-offline-tts-play", | 59 | "sherpa-onnx-offline-tts-play", |
| 60 | "sherpa-onnx-offline-websocket-server", | 60 | "sherpa-onnx-offline-websocket-server", |
| 61 | + "sherpa-onnx-online-punctuation", | ||
| 61 | "sherpa-onnx-online-websocket-client", | 62 | "sherpa-onnx-online-websocket-client", |
| 62 | "sherpa-onnx-online-websocket-server", | 63 | "sherpa-onnx-online-websocket-server", |
| 63 | "sherpa-onnx-vad-microphone", | 64 | "sherpa-onnx-vad-microphone", |
| @@ -35,8 +35,11 @@ class OnlineCNNBiLSTMModel::Impl { | @@ -35,8 +35,11 @@ class OnlineCNNBiLSTMModel::Impl { | ||
| 35 | } | 35 | } |
| 36 | #endif | 36 | #endif |
| 37 | 37 | ||
| 38 | - std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) { | ||
| 39 | - std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)}; | 38 | + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, |
| 39 | + Ort::Value valid_ids, | ||
| 40 | + Ort::Value label_lens) { | ||
| 41 | + std::array<Ort::Value, 3> inputs = { | ||
| 42 | + std::move(token_ids), std::move(valid_ids), std::move(label_lens)}; | ||
| 40 | 43 | ||
| 41 | auto ans = | 44 | auto ans = |
| 42 | sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | 45 | sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), |
| @@ -117,18 +120,18 @@ OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( | @@ -117,18 +120,18 @@ OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( | ||
| 117 | 120 | ||
| 118 | OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default; | 121 | OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default; |
| 119 | 122 | ||
| 120 | -std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids, | ||
| 121 | - Ort::Value valid_ids, | ||
| 122 | - Ort::Value label_lens) const { | ||
| 123 | - return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens)); | 123 | +std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward( |
| 124 | + Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const { | ||
| 125 | + return impl_->Forward(std::move(token_ids), std::move(valid_ids), | ||
| 126 | + std::move(label_lens)); | ||
| 124 | } | 127 | } |
| 125 | 128 | ||
| 126 | OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const { | 129 | OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const { |
| 127 | return impl_->Allocator(); | 130 | return impl_->Allocator(); |
| 128 | } | 131 | } |
| 129 | 132 | ||
| 130 | -const OnlineCNNBiLSTMModelMetaData & | ||
| 131 | -OnlineCNNBiLSTMModel::GetModelMetadata() const { | 133 | +const OnlineCNNBiLSTMModelMetaData &OnlineCNNBiLSTMModel::GetModelMetadata() |
| 134 | + const { | ||
| 132 | return impl_->GetModelMetadata(); | 135 | return impl_->GetModelMetadata(); |
| 133 | } | 136 | } |
| 134 | 137 |
| @@ -23,12 +23,11 @@ namespace sherpa_onnx { | @@ -23,12 +23,11 @@ namespace sherpa_onnx { | ||
| 23 | */ | 23 | */ |
| 24 | class OnlineCNNBiLSTMModel { | 24 | class OnlineCNNBiLSTMModel { |
| 25 | public: | 25 | public: |
| 26 | - explicit OnlineCNNBiLSTMModel( | ||
| 27 | - const OnlinePunctuationModelConfig &config); | 26 | + explicit OnlineCNNBiLSTMModel(const OnlinePunctuationModelConfig &config); |
| 28 | 27 | ||
| 29 | #if __ANDROID_API__ >= 9 | 28 | #if __ANDROID_API__ >= 9 |
| 30 | OnlineCNNBiLSTMModel(AAssetManager *mgr, | 29 | OnlineCNNBiLSTMModel(AAssetManager *mgr, |
| 31 | - const OnlinePunctuationModelConfig &config); | 30 | + const OnlinePunctuationModelConfig &config); |
| 32 | #endif | 31 | #endif |
| 33 | 32 | ||
| 34 | ~OnlineCNNBiLSTMModel(); | 33 | ~OnlineCNNBiLSTMModel(); |
| @@ -43,7 +42,9 @@ class OnlineCNNBiLSTMModel { | @@ -43,7 +42,9 @@ class OnlineCNNBiLSTMModel { | ||
| 43 | * - case_logits: A 2-D tensor of shape (T', num_cases). | 42 | * - case_logits: A 2-D tensor of shape (T', num_cases). |
| 44 | * - punct_logits: A 2-D tensor of shape (T', num_puncts). | 43 | * - punct_logits: A 2-D tensor of shape (T', num_puncts). |
| 45 | */ | 44 | */ |
| 46 | - std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const; | 45 | + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, |
| 46 | + Ort::Value valid_ids, | ||
| 47 | + Ort::Value label_lens) const; | ||
| 47 | 48 | ||
| 48 | /** Return an allocator for allocating memory | 49 | /** Return an allocator for allocating memory |
| 49 | */ | 50 | */ |
| @@ -7,27 +7,28 @@ | @@ -7,27 +7,28 @@ | ||
| 7 | 7 | ||
| 8 | #include <math.h> | 8 | #include <math.h> |
| 9 | 9 | ||
| 10 | +#include <algorithm> | ||
| 10 | #include <memory> | 11 | #include <memory> |
| 11 | #include <string> | 12 | #include <string> |
| 12 | #include <utility> | 13 | #include <utility> |
| 13 | #include <vector> | 14 | #include <vector> |
| 14 | -#include <algorithm> | ||
| 15 | 15 | ||
| 16 | #if __ANDROID_API__ >= 9 | 16 | #if __ANDROID_API__ >= 9 |
| 17 | #include "android/asset_manager.h" | 17 | #include "android/asset_manager.h" |
| 18 | #include "android/asset_manager_jni.h" | 18 | #include "android/asset_manager_jni.h" |
| 19 | #endif | 19 | #endif |
| 20 | 20 | ||
| 21 | +#include <chrono> // NOLINT | ||
| 22 | + | ||
| 21 | #include "sherpa-onnx/csrc/macros.h" | 23 | #include "sherpa-onnx/csrc/macros.h" |
| 22 | #include "sherpa-onnx/csrc/math.h" | 24 | #include "sherpa-onnx/csrc/math.h" |
| 25 | +#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h" | ||
| 23 | #include "sherpa-onnx/csrc/online-cnn-bilstm-model.h" | 26 | #include "sherpa-onnx/csrc/online-cnn-bilstm-model.h" |
| 24 | #include "sherpa-onnx/csrc/online-punctuation-impl.h" | 27 | #include "sherpa-onnx/csrc/online-punctuation-impl.h" |
| 25 | #include "sherpa-onnx/csrc/online-punctuation.h" | 28 | #include "sherpa-onnx/csrc/online-punctuation.h" |
| 26 | -#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h" | ||
| 27 | -#include "sherpa-onnx/csrc/text-utils.h" | ||
| 28 | #include "sherpa-onnx/csrc/onnx-utils.h" | 29 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 30 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 29 | #include "ssentencepiece/csrc/ssentencepiece.h" | 31 | #include "ssentencepiece/csrc/ssentencepiece.h" |
| 30 | -#include <chrono> // NOLINT | ||
| 31 | 32 | ||
| 32 | namespace sherpa_onnx { | 33 | namespace sherpa_onnx { |
| 33 | 34 | ||
| @@ -35,25 +36,24 @@ static const int32_t kMaxSeqLen = 200; | @@ -35,25 +36,24 @@ static const int32_t kMaxSeqLen = 200; | ||
| 35 | 36 | ||
| 36 | class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | 37 | class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { |
| 37 | public: | 38 | public: |
| 38 | - explicit OnlinePunctuationCNNBiLSTMImpl( | ||
| 39 | - const OnlinePunctuationConfig &config) | 39 | + explicit OnlinePunctuationCNNBiLSTMImpl(const OnlinePunctuationConfig &config) |
| 40 | : config_(config), model_(config.model) { | 40 | : config_(config), model_(config.model) { |
| 41 | - if (!config_.model.bpe_vocab.empty()) { | ||
| 42 | - bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>( | ||
| 43 | - config_.model.bpe_vocab); | ||
| 44 | - } | ||
| 45 | - } | 41 | + if (!config_.model.bpe_vocab.empty()) { |
| 42 | + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>( | ||
| 43 | + config_.model.bpe_vocab); | ||
| 44 | + } | ||
| 45 | + } | ||
| 46 | 46 | ||
| 47 | #if __ANDROID_API__ >= 9 | 47 | #if __ANDROID_API__ >= 9 |
| 48 | OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr, | 48 | OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr, |
| 49 | - const OnlinePunctuationConfig &config) | 49 | + const OnlinePunctuationConfig &config) |
| 50 | : config_(config), model_(mgr, config.model) { | 50 | : config_(config), model_(mgr, config.model) { |
| 51 | - if (!config_.model.bpe_vocab.empty()) { | ||
| 52 | - auto buf = ReadFile(mgr, config_.model.bpe_vocab); | ||
| 53 | - std::istringstream iss(std::string(buf.begin(), buf.end())); | ||
| 54 | - bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss); | ||
| 55 | - } | ||
| 56 | - } | 51 | + if (!config_.model.bpe_vocab.empty()) { |
| 52 | + auto buf = ReadFile(mgr, config_.model.bpe_vocab); | ||
| 53 | + std::istringstream iss(std::string(buf.begin(), buf.end())); | ||
| 54 | + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss); | ||
| 55 | + } | ||
| 56 | + } | ||
| 57 | #endif | 57 | #endif |
| 58 | 58 | ||
| 59 | std::string AddPunctuationWithCase(const std::string &text) const override { | 59 | std::string AddPunctuationWithCase(const std::string &text) const override { |
| @@ -61,9 +61,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | @@ -61,9 +61,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | ||
| 61 | return {}; | 61 | return {}; |
| 62 | } | 62 | } |
| 63 | 63 | ||
| 64 | - std::vector<int32_t> tokens_list; // N * kMaxSeqLen | ||
| 65 | - std::vector<int32_t> valids_list; // N * kMaxSeqLen | ||
| 66 | - std::vector<int32_t> label_len_list; // N | 64 | + std::vector<int32_t> tokens_list; // N * kMaxSeqLen |
| 65 | + std::vector<int32_t> valids_list; // N * kMaxSeqLen | ||
| 66 | + std::vector<int32_t> label_len_list; // N | ||
| 67 | 67 | ||
| 68 | EncodeSentences(text, tokens_list, valids_list, label_len_list); | 68 | EncodeSentences(text, tokens_list, valids_list, label_len_list); |
| 69 | 69 | ||
| @@ -75,34 +75,43 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | @@ -75,34 +75,43 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | ||
| 75 | int32_t n = label_len_list.size(); | 75 | int32_t n = label_len_list.size(); |
| 76 | 76 | ||
| 77 | std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen}; | 77 | std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen}; |
| 78 | - Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(), | ||
| 79 | - token_ids_shape.data(), token_ids_shape.size()); | 78 | + Ort::Value token_ids = Ort::Value::CreateTensor( |
| 79 | + memory_info, tokens_list.data(), tokens_list.size(), | ||
| 80 | + token_ids_shape.data(), token_ids_shape.size()); | ||
| 80 | 81 | ||
| 81 | std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen}; | 82 | std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen}; |
| 82 | - Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(), | ||
| 83 | - valid_ids_shape.data(), valid_ids_shape.size()); | 83 | + Ort::Value valid_ids = Ort::Value::CreateTensor( |
| 84 | + memory_info, valids_list.data(), valids_list.size(), | ||
| 85 | + valid_ids_shape.data(), valid_ids_shape.size()); | ||
| 84 | 86 | ||
| 85 | std::array<int64_t, 1> label_len_shape = {n}; | 87 | std::array<int64_t, 1> label_len_shape = {n}; |
| 86 | - Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(), | ||
| 87 | - label_len_shape.data(), label_len_shape.size()); | 88 | + Ort::Value label_len = Ort::Value::CreateTensor( |
| 89 | + memory_info, label_len_list.data(), label_len_list.size(), | ||
| 90 | + label_len_shape.data(), label_len_shape.size()); | ||
| 88 | 91 | ||
| 89 | - auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len)); | 92 | + auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), |
| 93 | + std::move(label_len)); | ||
| 90 | 94 | ||
| 91 | std::vector<int32_t> case_pred; | 95 | std::vector<int32_t> case_pred; |
| 92 | std::vector<int32_t> punct_pred; | 96 | std::vector<int32_t> punct_pred; |
| 93 | - const float* active_case_logits = pair.first.GetTensorData<float>(); | ||
| 94 | - const float* active_punct_logits = pair.second.GetTensorData<float>(); | ||
| 95 | - std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape(); | 97 | + const float *active_case_logits = pair.first.GetTensorData<float>(); |
| 98 | + const float *active_punct_logits = pair.second.GetTensorData<float>(); | ||
| 99 | + std::vector<int64_t> case_logits_shape = | ||
| 100 | + pair.first.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 96 | 101 | ||
| 97 | for (int32_t i = 0; i < case_logits_shape[0]; ++i) { | 102 | for (int32_t i = 0; i < case_logits_shape[0]; ++i) { |
| 98 | - const float* p_cur_case = active_case_logits + i * meta_data.num_cases; | 103 | + const float *p_cur_case = active_case_logits + i * meta_data.num_cases; |
| 99 | auto index_case = static_cast<int32_t>(std::distance( | 104 | auto index_case = static_cast<int32_t>(std::distance( |
| 100 | - p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); | 105 | + p_cur_case, |
| 106 | + std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); | ||
| 101 | case_pred.push_back(index_case); | 107 | case_pred.push_back(index_case); |
| 102 | 108 | ||
| 103 | - const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations; | 109 | + const float *p_cur_punct = |
| 110 | + active_punct_logits + i * meta_data.num_punctuations; | ||
| 104 | auto index_punct = static_cast<int32_t>(std::distance( | 111 | auto index_punct = static_cast<int32_t>(std::distance( |
| 105 | - p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations))); | 112 | + p_cur_punct, |
| 113 | + std::max_element(p_cur_punct, | ||
| 114 | + p_cur_punct + meta_data.num_punctuations))); | ||
| 106 | punct_pred.push_back(index_punct); | 115 | punct_pred.push_back(index_punct); |
| 107 | } | 116 | } |
| 108 | 117 | ||
| @@ -112,60 +121,60 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | @@ -112,60 +121,60 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | ||
| 112 | } | 121 | } |
| 113 | 122 | ||
| 114 | private: | 123 | private: |
| 115 | - void EncodeSentences(const std::string& text, | ||
| 116 | - std::vector<int32_t>& tokens_list, | ||
| 117 | - std::vector<int32_t>& valids_list, | ||
| 118 | - std::vector<int32_t>& label_len_list) const { | 124 | + void EncodeSentences(const std::string &text, |
| 125 | + std::vector<int32_t> &tokens_list, // NOLINT | ||
| 126 | + std::vector<int32_t> &valids_list, // NOLINT | ||
| 127 | + std::vector<int32_t> &label_len_list) const { // NOLINT | ||
| 119 | std::vector<int32_t> tokens; | 128 | std::vector<int32_t> tokens; |
| 120 | std::vector<int32_t> valids; | 129 | std::vector<int32_t> valids; |
| 121 | int32_t label_len = 0; | 130 | int32_t label_len = 0; |
| 122 | 131 | ||
| 123 | - tokens.push_back(1); // hardcode 1 now, 1 - <s> | 132 | + tokens.push_back(1); // hardcode 1 now, 1 - <s> |
| 124 | valids.push_back(1); | 133 | valids.push_back(1); |
| 125 | 134 | ||
| 126 | std::stringstream ss(text); | 135 | std::stringstream ss(text); |
| 127 | std::string word; | 136 | std::string word; |
| 128 | while (ss >> word) { | 137 | while (ss >> word) { |
| 129 | - std::vector<int32_t> word_tokens; | ||
| 130 | - bpe_encoder_->Encode(word, &word_tokens); | 138 | + std::vector<int32_t> word_tokens; |
| 139 | + bpe_encoder_->Encode(word, &word_tokens); | ||
| 131 | 140 | ||
| 132 | - int32_t seq_len = tokens.size() + word_tokens.size(); | ||
| 133 | - if (seq_len > kMaxSeqLen - 1) { | ||
| 134 | - tokens.push_back(2); // hardcode 2 now, 2 - </s> | ||
| 135 | - valids.push_back(1); | 141 | + int32_t seq_len = tokens.size() + word_tokens.size(); |
| 142 | + if (seq_len > kMaxSeqLen - 1) { | ||
| 143 | + tokens.push_back(2); // hardcode 2 now, 2 - </s> | ||
| 144 | + valids.push_back(1); | ||
| 136 | 145 | ||
| 137 | - label_len = std::count(valids.begin(), valids.end(), 1); | 146 | + label_len = std::count(valids.begin(), valids.end(), 1); |
| 138 | 147 | ||
| 139 | - if (tokens.size() < kMaxSeqLen) { | ||
| 140 | - tokens.resize(kMaxSeqLen, 0); | ||
| 141 | - valids.resize(kMaxSeqLen, 0); | ||
| 142 | - } | 148 | + if (tokens.size() < kMaxSeqLen) { |
| 149 | + tokens.resize(kMaxSeqLen, 0); | ||
| 150 | + valids.resize(kMaxSeqLen, 0); | ||
| 151 | + } | ||
| 143 | 152 | ||
| 144 | - assert(tokens.size() == kMaxSeqLen); | ||
| 145 | - assert(valids.size() == kMaxSeqLen); | 153 | + assert(tokens.size() == kMaxSeqLen); |
| 154 | + assert(valids.size() == kMaxSeqLen); | ||
| 146 | 155 | ||
| 147 | - tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); | ||
| 148 | - valids_list.insert(valids_list.end(), valids.begin(), valids.end()); | ||
| 149 | - label_len_list.push_back(label_len); | 156 | + tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); |
| 157 | + valids_list.insert(valids_list.end(), valids.begin(), valids.end()); | ||
| 158 | + label_len_list.push_back(label_len); | ||
| 150 | 159 | ||
| 151 | - std::vector<int32_t>().swap(tokens); | ||
| 152 | - std::vector<int32_t>().swap(valids); | ||
| 153 | - label_len = 0; | ||
| 154 | - tokens.push_back(1); // hardcode 1 now, 1 - <s> | ||
| 155 | - valids.push_back(1); | ||
| 156 | - } | 160 | + std::vector<int32_t>().swap(tokens); |
| 161 | + std::vector<int32_t>().swap(valids); | ||
| 162 | + label_len = 0; | ||
| 163 | + tokens.push_back(1); // hardcode 1 now, 1 - <s> | ||
| 164 | + valids.push_back(1); | ||
| 165 | + } | ||
| 157 | 166 | ||
| 158 | - tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end()); | ||
| 159 | - valids.push_back(1); // only the first sub word is valid | ||
| 160 | - int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1; | ||
| 161 | - if (remaining_size > 0) { | ||
| 162 | - int32_t valids_cur_size = static_cast<int32_t>(valids.size()); | ||
| 163 | - valids.resize(valids_cur_size + remaining_size, 0); | ||
| 164 | - } | 167 | + tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end()); |
| 168 | + valids.push_back(1); // only the first sub word is valid | ||
| 169 | + int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1; | ||
| 170 | + if (remaining_size > 0) { | ||
| 171 | + int32_t valids_cur_size = static_cast<int32_t>(valids.size()); | ||
| 172 | + valids.resize(valids_cur_size + remaining_size, 0); | ||
| 173 | + } | ||
| 165 | } | 174 | } |
| 166 | 175 | ||
| 167 | if (tokens.size() > 0) { | 176 | if (tokens.size() > 0) { |
| 168 | - tokens.push_back(2); // hardcode 2 now, 2 - </s> | 177 | + tokens.push_back(2); // hardcode 2 now, 2 - </s> |
| 169 | valids.push_back(1); | 178 | valids.push_back(1); |
| 170 | 179 | ||
| 171 | label_len = std::count(valids.begin(), valids.end(), 1); | 180 | label_len = std::count(valids.begin(), valids.end(), 1); |
| @@ -176,17 +185,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | @@ -176,17 +185,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | ||
| 176 | } | 185 | } |
| 177 | 186 | ||
| 178 | assert(tokens.size() == kMaxSeqLen); | 187 | assert(tokens.size() == kMaxSeqLen); |
| 179 | - assert(valids.size() == kMaxSeqLen); | 188 | + assert(valids.size() == kMaxSeqLen); |
| 180 | 189 | ||
| 181 | tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); | 190 | tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); |
| 182 | valids_list.insert(valids_list.end(), valids.begin(), valids.end()); | 191 | valids_list.insert(valids_list.end(), valids.begin(), valids.end()); |
| 183 | label_len_list.push_back(label_len); | 192 | label_len_list.push_back(label_len); |
| 184 | - } | 193 | + } |
| 185 | } | 194 | } |
| 186 | 195 | ||
| 187 | - std::string DecodeSentences(const std::string& raw_text, | ||
| 188 | - const std::vector<int32_t>& case_pred, | ||
| 189 | - const std::vector<int32_t>& punct_pred) const { | 196 | + std::string DecodeSentences(const std::string &raw_text, |
| 197 | + const std::vector<int32_t> &case_pred, | ||
| 198 | + const std::vector<int32_t> &punct_pred) const { | ||
| 190 | std::string result_text; | 199 | std::string result_text; |
| 191 | std::istringstream iss(raw_text); | 200 | std::istringstream iss(raw_text); |
| 192 | std::vector<std::string> words; | 201 | std::vector<std::string> words; |
| @@ -203,28 +212,29 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | @@ -203,28 +212,29 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | ||
| 203 | std::string prefix = ((i != 0) ? " " : ""); | 212 | std::string prefix = ((i != 0) ? " " : ""); |
| 204 | result_text += prefix; | 213 | result_text += prefix; |
| 205 | switch (case_pred[i]) { | 214 | switch (case_pred[i]) { |
| 206 | - case 1: // upper | 215 | + case 1: // upper |
| 207 | { | 216 | { |
| 208 | - std::transform(words[i].begin(), words[i].end(), words[i].begin(), [](auto c){ return std::toupper(c); }); | 217 | + std::transform(words[i].begin(), words[i].end(), words[i].begin(), |
| 218 | + [](auto c) { return std::toupper(c); }); | ||
| 209 | result_text += words[i]; | 219 | result_text += words[i]; |
| 210 | break; | 220 | break; |
| 211 | } | 221 | } |
| 212 | - case 2: // cap | 222 | + case 2: // cap |
| 213 | { | 223 | { |
| 214 | words[i][0] = std::toupper(words[i][0]); | 224 | words[i][0] = std::toupper(words[i][0]); |
| 215 | result_text += words[i]; | 225 | result_text += words[i]; |
| 216 | break; | 226 | break; |
| 217 | } | 227 | } |
| 218 | - case 3: // mix case | 228 | + case 3: // mix case |
| 219 | { | 229 | { |
| 220 | - // TODO: | ||
| 221 | - // Need to add a map containing supported mix case words so that we can fetch the predicted word from the map | ||
| 222 | - // e.g. mcdonald's -> McDonald's | 230 | + // TODO(frankyoujian): |
| 231 | + // Need to add a map containing supported mix case words so that we | ||
| 232 | + // can fetch the predicted word from the map e.g. mcdonald's -> | ||
| 233 | + // McDonald's | ||
| 223 | result_text += words[i]; | 234 | result_text += words[i]; |
| 224 | break; | 235 | break; |
| 225 | } | 236 | } |
| 226 | - default: | ||
| 227 | - { | 237 | + default: { |
| 228 | result_text += words[i]; | 238 | result_text += words[i]; |
| 229 | break; | 239 | break; |
| 230 | } | 240 | } |
| @@ -232,17 +242,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | @@ -232,17 +242,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | ||
| 232 | 242 | ||
| 233 | std::string suffix; | 243 | std::string suffix; |
| 234 | switch (punct_pred[i]) { | 244 | switch (punct_pred[i]) { |
| 235 | - case 1: // comma | 245 | + case 1: // comma |
| 236 | { | 246 | { |
| 237 | suffix = ","; | 247 | suffix = ","; |
| 238 | break; | 248 | break; |
| 239 | } | 249 | } |
| 240 | - case 2: // period | 250 | + case 2: // period |
| 241 | { | 251 | { |
| 242 | suffix = "."; | 252 | suffix = "."; |
| 243 | break; | 253 | break; |
| 244 | } | 254 | } |
| 245 | - case 3: // question | 255 | + case 3: // question |
| 246 | { | 256 | { |
| 247 | suffix = "?"; | 257 | suffix = "?"; |
| 248 | break; | 258 | break; |
| @@ -252,9 +262,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | @@ -252,9 +262,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | ||
| 252 | } | 262 | } |
| 253 | 263 | ||
| 254 | result_text += suffix; | 264 | result_text += suffix; |
| 255 | - } | 265 | + } |
| 256 | 266 | ||
| 257 | - return result_text; | 267 | + return result_text; |
| 258 | } | 268 | } |
| 259 | 269 | ||
| 260 | private: | 270 | private: |
| @@ -20,7 +20,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create( | @@ -20,7 +20,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create( | ||
| 20 | return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(config); | 20 | return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(config); |
| 21 | } | 21 | } |
| 22 | 22 | ||
| 23 | - SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer"); | 23 | + SHERPA_ONNX_LOGE( |
| 24 | + "Please specify a punctuation model and bpe vocab! Return a null " | ||
| 25 | + "pointer"); | ||
| 24 | return nullptr; | 26 | return nullptr; |
| 25 | } | 27 | } |
| 26 | 28 | ||
| @@ -31,7 +33,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create( | @@ -31,7 +33,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create( | ||
| 31 | return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(mgr, config); | 33 | return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(mgr, config); |
| 32 | } | 34 | } |
| 33 | 35 | ||
| 34 | - SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer"); | 36 | + SHERPA_ONNX_LOGE( |
| 37 | + "Please specify a punctuation model and bpe vocab! Return a null " | ||
| 38 | + "pointer"); | ||
| 35 | return nullptr; | 39 | return nullptr; |
| 36 | } | 40 | } |
| 37 | #endif | 41 | #endif |
| @@ -13,8 +13,7 @@ void OnlinePunctuationModelConfig::Register(ParseOptions *po) { | @@ -13,8 +13,7 @@ void OnlinePunctuationModelConfig::Register(ParseOptions *po) { | ||
| 13 | po->Register("cnn-bilstm", &cnn_bilstm, | 13 | po->Register("cnn-bilstm", &cnn_bilstm, |
| 14 | "Path to the light-weight CNN-BiLSTM model"); | 14 | "Path to the light-weight CNN-BiLSTM model"); |
| 15 | 15 | ||
| 16 | - po->Register("bpe-vocab", &bpe_vocab, | ||
| 17 | - "Path to the bpe vocab file"); | 16 | + po->Register("bpe-vocab", &bpe_vocab, "Path to the bpe vocab file"); |
| 18 | 17 | ||
| 19 | po->Register("num-threads", &num_threads, | 18 | po->Register("num-threads", &num_threads, |
| 20 | "Number of threads to run the neural network"); | 19 | "Number of threads to run the neural network"); |
| @@ -33,8 +32,7 @@ bool OnlinePunctuationModelConfig::Validate() const { | @@ -33,8 +32,7 @@ bool OnlinePunctuationModelConfig::Validate() const { | ||
| 33 | } | 32 | } |
| 34 | 33 | ||
| 35 | if (!FileExists(cnn_bilstm)) { | 34 | if (!FileExists(cnn_bilstm)) { |
| 36 | - SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", | ||
| 37 | - cnn_bilstm.c_str()); | 35 | + SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", cnn_bilstm.c_str()); |
| 38 | return false; | 36 | return false; |
| 39 | } | 37 | } |
| 40 | 38 | ||
| @@ -44,8 +42,7 @@ bool OnlinePunctuationModelConfig::Validate() const { | @@ -44,8 +42,7 @@ bool OnlinePunctuationModelConfig::Validate() const { | ||
| 44 | } | 42 | } |
| 45 | 43 | ||
| 46 | if (!FileExists(bpe_vocab)) { | 44 | if (!FileExists(bpe_vocab)) { |
| 47 | - SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", | ||
| 48 | - bpe_vocab.c_str()); | 45 | + SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", bpe_vocab.c_str()); |
| 49 | return false; | 46 | return false; |
| 50 | } | 47 | } |
| 51 | 48 |
| @@ -22,9 +22,9 @@ struct OnlinePunctuationModelConfig { | @@ -22,9 +22,9 @@ struct OnlinePunctuationModelConfig { | ||
| 22 | OnlinePunctuationModelConfig() = default; | 22 | OnlinePunctuationModelConfig() = default; |
| 23 | 23 | ||
| 24 | OnlinePunctuationModelConfig(const std::string &cnn_bilstm, | 24 | OnlinePunctuationModelConfig(const std::string &cnn_bilstm, |
| 25 | - const std::string &bpe_vocab, | ||
| 26 | - int32_t num_threads, bool debug, | ||
| 27 | - const std::string &provider) | 25 | + const std::string &bpe_vocab, |
| 26 | + int32_t num_threads, bool debug, | ||
| 27 | + const std::string &provider) | ||
| 28 | : cnn_bilstm(cnn_bilstm), | 28 | : cnn_bilstm(cnn_bilstm), |
| 29 | bpe_vocab(bpe_vocab), | 29 | bpe_vocab(bpe_vocab), |
| 30 | num_threads(num_threads), | 30 | num_threads(num_threads), |
| @@ -14,9 +14,7 @@ | @@ -14,9 +14,7 @@ | ||
| 14 | 14 | ||
| 15 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| 16 | 16 | ||
| 17 | -void OnlinePunctuationConfig::Register(ParseOptions *po) { | ||
| 18 | - model.Register(po); | ||
| 19 | -} | 17 | +void OnlinePunctuationConfig::Register(ParseOptions *po) { model.Register(po); } |
| 20 | 18 | ||
| 21 | bool OnlinePunctuationConfig::Validate() const { | 19 | bool OnlinePunctuationConfig::Validate() const { |
| 22 | if (!model.Validate()) { | 20 | if (!model.Validate()) { |
| @@ -40,13 +38,14 @@ OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config) | @@ -40,13 +38,14 @@ OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config) | ||
| 40 | 38 | ||
| 41 | #if __ANDROID_API__ >= 9 | 39 | #if __ANDROID_API__ >= 9 |
| 42 | OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr, | 40 | OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr, |
| 43 | - const OnlinePunctuationConfig &config) | 41 | + const OnlinePunctuationConfig &config) |
| 44 | : impl_(OnlinePunctuationImpl::Create(mgr, config)) {} | 42 | : impl_(OnlinePunctuationImpl::Create(mgr, config)) {} |
| 45 | #endif | 43 | #endif |
| 46 | 44 | ||
| 47 | OnlinePunctuation::~OnlinePunctuation() = default; | 45 | OnlinePunctuation::~OnlinePunctuation() = default; |
| 48 | 46 | ||
| 49 | -std::string OnlinePunctuation::AddPunctuationWithCase(const std::string &text) const { | 47 | +std::string OnlinePunctuation::AddPunctuationWithCase( |
| 48 | + const std::string &text) const { | ||
| 50 | return impl_->AddPunctuationWithCase(text); | 49 | return impl_->AddPunctuationWithCase(text); |
| 51 | } | 50 | } |
| 52 | 51 |
| @@ -40,8 +40,7 @@ class OnlinePunctuation { | @@ -40,8 +40,7 @@ class OnlinePunctuation { | ||
| 40 | explicit OnlinePunctuation(const OnlinePunctuationConfig &config); | 40 | explicit OnlinePunctuation(const OnlinePunctuationConfig &config); |
| 41 | 41 | ||
| 42 | #if __ANDROID_API__ >= 9 | 42 | #if __ANDROID_API__ >= 9 |
| 43 | - OnlinePunctuation(AAssetManager *mgr, | ||
| 44 | - const OnlinePunctuationConfig &config); | 43 | + OnlinePunctuation(AAssetManager *mgr, const OnlinePunctuationConfig &config); |
| 45 | #endif | 44 | #endif |
| 46 | 45 | ||
| 47 | ~OnlinePunctuation(); | 46 | ~OnlinePunctuation(); |
| @@ -3,9 +3,9 @@ | @@ -3,9 +3,9 @@ | ||
| 3 | // Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) | 3 | // Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) |
| 4 | 4 | ||
| 5 | #include <stdio.h> | 5 | #include <stdio.h> |
| 6 | -#include <iostream> | ||
| 7 | 6 | ||
| 8 | #include <chrono> // NOLINT | 7 | #include <chrono> // NOLINT |
| 8 | +#include <iostream> | ||
| 9 | 9 | ||
| 10 | #include "sherpa-onnx/csrc/online-punctuation.h" | 10 | #include "sherpa-onnx/csrc/online-punctuation.h" |
| 11 | #include "sherpa-onnx/csrc/parse-options.h" | 11 | #include "sherpa-onnx/csrc/parse-options.h" |
| @@ -57,7 +57,7 @@ The output text should look like below: | @@ -57,7 +57,7 @@ The output text should look like below: | ||
| 57 | std::string text = po.GetArg(1); | 57 | std::string text = po.GetArg(1); |
| 58 | 58 | ||
| 59 | std::string text_with_punct_case = punct.AddPunctuationWithCase(text); | 59 | std::string text_with_punct_case = punct.AddPunctuationWithCase(text); |
| 60 | - | 60 | + |
| 61 | const auto end = std::chrono::steady_clock::now(); | 61 | const auto end = std::chrono::steady_clock::now(); |
| 62 | fprintf(stderr, "Done\n"); | 62 | fprintf(stderr, "Done\n"); |
| 63 | 63 |
-
请 注册 或 登录 后发表评论