Fangjun Kuang
Committed by GitHub

Fix style issues for online punctuation source files (#1225)

... ... @@ -58,6 +58,7 @@ def get_binaries():
"sherpa-onnx-offline-tts",
"sherpa-onnx-offline-tts-play",
"sherpa-onnx-offline-websocket-server",
"sherpa-onnx-online-punctuation",
"sherpa-onnx-online-websocket-client",
"sherpa-onnx-online-websocket-server",
"sherpa-onnx-vad-microphone",
... ...
... ... @@ -35,8 +35,11 @@ class OnlineCNNBiLSTMModel::Impl {
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) {
std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)};
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids,
Ort::Value valid_ids,
Ort::Value label_lens) {
std::array<Ort::Value, 3> inputs = {
std::move(token_ids), std::move(valid_ids), std::move(label_lens)};
auto ans =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
... ... @@ -117,18 +120,18 @@ OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default;
std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids,
Ort::Value valid_ids,
Ort::Value label_lens) const {
return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens));
std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(
Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const {
return impl_->Forward(std::move(token_ids), std::move(valid_ids),
std::move(label_lens));
}
OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const {
return impl_->Allocator();
}
const OnlineCNNBiLSTMModelMetaData &
OnlineCNNBiLSTMModel::GetModelMetadata() const {
const OnlineCNNBiLSTMModelMetaData &OnlineCNNBiLSTMModel::GetModelMetadata()
const {
return impl_->GetModelMetadata();
}
... ...
... ... @@ -23,12 +23,11 @@ namespace sherpa_onnx {
*/
class OnlineCNNBiLSTMModel {
public:
explicit OnlineCNNBiLSTMModel(
const OnlinePunctuationModelConfig &config);
explicit OnlineCNNBiLSTMModel(const OnlinePunctuationModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineCNNBiLSTMModel(AAssetManager *mgr,
const OnlinePunctuationModelConfig &config);
const OnlinePunctuationModelConfig &config);
#endif
~OnlineCNNBiLSTMModel();
... ... @@ -43,7 +42,9 @@ class OnlineCNNBiLSTMModel {
* - case_logits: A 2-D tensor of shape (T', num_cases).
* - punct_logits: A 2-D tensor of shape (T', num_puncts).
*/
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const;
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids,
Ort::Value valid_ids,
Ort::Value label_lens) const;
/** Return an allocator for allocating memory
*/
... ...
... ... @@ -7,27 +7,28 @@
#include <math.h>
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <algorithm>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include <chrono> // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"
#include "sherpa-onnx/csrc/online-punctuation-impl.h"
#include "sherpa-onnx/csrc/online-punctuation.h"
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "ssentencepiece/csrc/ssentencepiece.h"
#include <chrono> // NOLINT
namespace sherpa_onnx {
... ... @@ -35,25 +36,24 @@ static const int32_t kMaxSeqLen = 200;
class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
public:
explicit OnlinePunctuationCNNBiLSTMImpl(
const OnlinePunctuationConfig &config)
explicit OnlinePunctuationCNNBiLSTMImpl(const OnlinePunctuationConfig &config)
: config_(config), model_(config.model) {
if (!config_.model.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model.bpe_vocab);
}
}
if (!config_.model.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model.bpe_vocab);
}
}
#if __ANDROID_API__ >= 9
OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr,
const OnlinePunctuationConfig &config)
const OnlinePunctuationConfig &config)
: config_(config), model_(mgr, config.model) {
if (!config_.model.bpe_vocab.empty()) {
auto buf = ReadFile(mgr, config_.model.bpe_vocab);
std::istringstream iss(std::string(buf.begin(), buf.end()));
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
}
}
if (!config_.model.bpe_vocab.empty()) {
auto buf = ReadFile(mgr, config_.model.bpe_vocab);
std::istringstream iss(std::string(buf.begin(), buf.end()));
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
}
}
#endif
std::string AddPunctuationWithCase(const std::string &text) const override {
... ... @@ -61,9 +61,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
return {};
}
std::vector<int32_t> tokens_list; // N * kMaxSeqLen
std::vector<int32_t> valids_list; // N * kMaxSeqLen
std::vector<int32_t> label_len_list; // N
std::vector<int32_t> tokens_list; // N * kMaxSeqLen
std::vector<int32_t> valids_list; // N * kMaxSeqLen
std::vector<int32_t> label_len_list; // N
EncodeSentences(text, tokens_list, valids_list, label_len_list);
... ... @@ -75,34 +75,43 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
int32_t n = label_len_list.size();
std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen};
Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(),
token_ids_shape.data(), token_ids_shape.size());
Ort::Value token_ids = Ort::Value::CreateTensor(
memory_info, tokens_list.data(), tokens_list.size(),
token_ids_shape.data(), token_ids_shape.size());
std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen};
Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(),
valid_ids_shape.data(), valid_ids_shape.size());
Ort::Value valid_ids = Ort::Value::CreateTensor(
memory_info, valids_list.data(), valids_list.size(),
valid_ids_shape.data(), valid_ids_shape.size());
std::array<int64_t, 1> label_len_shape = {n};
Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(),
label_len_shape.data(), label_len_shape.size());
Ort::Value label_len = Ort::Value::CreateTensor(
memory_info, label_len_list.data(), label_len_list.size(),
label_len_shape.data(), label_len_shape.size());
auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len));
auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids),
std::move(label_len));
std::vector<int32_t> case_pred;
std::vector<int32_t> punct_pred;
const float* active_case_logits = pair.first.GetTensorData<float>();
const float* active_punct_logits = pair.second.GetTensorData<float>();
std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape();
const float *active_case_logits = pair.first.GetTensorData<float>();
const float *active_punct_logits = pair.second.GetTensorData<float>();
std::vector<int64_t> case_logits_shape =
pair.first.GetTensorTypeAndShapeInfo().GetShape();
for (int32_t i = 0; i < case_logits_shape[0]; ++i) {
const float* p_cur_case = active_case_logits + i * meta_data.num_cases;
const float *p_cur_case = active_case_logits + i * meta_data.num_cases;
auto index_case = static_cast<int32_t>(std::distance(
p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases)));
p_cur_case,
std::max_element(p_cur_case, p_cur_case + meta_data.num_cases)));
case_pred.push_back(index_case);
const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations;
const float *p_cur_punct =
active_punct_logits + i * meta_data.num_punctuations;
auto index_punct = static_cast<int32_t>(std::distance(
p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations)));
p_cur_punct,
std::max_element(p_cur_punct,
p_cur_punct + meta_data.num_punctuations)));
punct_pred.push_back(index_punct);
}
... ... @@ -112,60 +121,60 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
}
private:
void EncodeSentences(const std::string& text,
std::vector<int32_t>& tokens_list,
std::vector<int32_t>& valids_list,
std::vector<int32_t>& label_len_list) const {
void EncodeSentences(const std::string &text,
std::vector<int32_t> &tokens_list, // NOLINT
std::vector<int32_t> &valids_list, // NOLINT
std::vector<int32_t> &label_len_list) const { // NOLINT
std::vector<int32_t> tokens;
std::vector<int32_t> valids;
int32_t label_len = 0;
tokens.push_back(1); // hardcode 1 now, 1 - <s>
tokens.push_back(1); // hardcode 1 now, 1 - <s>
valids.push_back(1);
std::stringstream ss(text);
std::string word;
while (ss >> word) {
std::vector<int32_t> word_tokens;
bpe_encoder_->Encode(word, &word_tokens);
std::vector<int32_t> word_tokens;
bpe_encoder_->Encode(word, &word_tokens);
int32_t seq_len = tokens.size() + word_tokens.size();
if (seq_len > kMaxSeqLen - 1) {
tokens.push_back(2); // hardcode 2 now, 2 - </s>
valids.push_back(1);
int32_t seq_len = tokens.size() + word_tokens.size();
if (seq_len > kMaxSeqLen - 1) {
tokens.push_back(2); // hardcode 2 now, 2 - </s>
valids.push_back(1);
label_len = std::count(valids.begin(), valids.end(), 1);
label_len = std::count(valids.begin(), valids.end(), 1);
if (tokens.size() < kMaxSeqLen) {
tokens.resize(kMaxSeqLen, 0);
valids.resize(kMaxSeqLen, 0);
}
if (tokens.size() < kMaxSeqLen) {
tokens.resize(kMaxSeqLen, 0);
valids.resize(kMaxSeqLen, 0);
}
assert(tokens.size() == kMaxSeqLen);
assert(valids.size() == kMaxSeqLen);
assert(tokens.size() == kMaxSeqLen);
assert(valids.size() == kMaxSeqLen);
tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
valids_list.insert(valids_list.end(), valids.begin(), valids.end());
label_len_list.push_back(label_len);
tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
valids_list.insert(valids_list.end(), valids.begin(), valids.end());
label_len_list.push_back(label_len);
std::vector<int32_t>().swap(tokens);
std::vector<int32_t>().swap(valids);
label_len = 0;
tokens.push_back(1); // hardcode 1 now, 1 - <s>
valids.push_back(1);
}
std::vector<int32_t>().swap(tokens);
std::vector<int32_t>().swap(valids);
label_len = 0;
tokens.push_back(1); // hardcode 1 now, 1 - <s>
valids.push_back(1);
}
tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end());
valids.push_back(1); // only the first sub word is valid
int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1;
if (remaining_size > 0) {
int32_t valids_cur_size = static_cast<int32_t>(valids.size());
valids.resize(valids_cur_size + remaining_size, 0);
}
tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end());
valids.push_back(1); // only the first sub word is valid
int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1;
if (remaining_size > 0) {
int32_t valids_cur_size = static_cast<int32_t>(valids.size());
valids.resize(valids_cur_size + remaining_size, 0);
}
}
if (tokens.size() > 0) {
tokens.push_back(2); // hardcode 2 now, 2 - </s>
tokens.push_back(2); // hardcode 2 now, 2 - </s>
valids.push_back(1);
label_len = std::count(valids.begin(), valids.end(), 1);
... ... @@ -176,17 +185,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
}
assert(tokens.size() == kMaxSeqLen);
assert(valids.size() == kMaxSeqLen);
assert(valids.size() == kMaxSeqLen);
tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
valids_list.insert(valids_list.end(), valids.begin(), valids.end());
label_len_list.push_back(label_len);
}
}
}
std::string DecodeSentences(const std::string& raw_text,
const std::vector<int32_t>& case_pred,
const std::vector<int32_t>& punct_pred) const {
std::string DecodeSentences(const std::string &raw_text,
const std::vector<int32_t> &case_pred,
const std::vector<int32_t> &punct_pred) const {
std::string result_text;
std::istringstream iss(raw_text);
std::vector<std::string> words;
... ... @@ -203,28 +212,29 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
std::string prefix = ((i != 0) ? " " : "");
result_text += prefix;
switch (case_pred[i]) {
case 1: // upper
case 1: // upper
{
std::transform(words[i].begin(), words[i].end(), words[i].begin(), [](auto c){ return std::toupper(c); });
std::transform(words[i].begin(), words[i].end(), words[i].begin(),
[](auto c) { return std::toupper(c); });
result_text += words[i];
break;
}
case 2: // cap
case 2: // cap
{
words[i][0] = std::toupper(words[i][0]);
result_text += words[i];
break;
}
case 3: // mix case
case 3: // mix case
{
// TODO:
// Need to add a map containing supported mix case words so that we can fetch the predicted word from the map
// e.g. mcdonald's -> McDonald's
// TODO(frankyoujian):
// Need to add a map containing supported mix case words so that we
// can fetch the predicted word from the map e.g. mcdonald's ->
// McDonald's
result_text += words[i];
break;
}
default:
{
default: {
result_text += words[i];
break;
}
... ... @@ -232,17 +242,17 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
std::string suffix;
switch (punct_pred[i]) {
case 1: // comma
case 1: // comma
{
suffix = ",";
break;
}
case 2: // period
case 2: // period
{
suffix = ".";
break;
}
case 3: // question
case 3: // question
{
suffix = "?";
break;
... ... @@ -252,9 +262,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
}
result_text += suffix;
}
}
return result_text;
return result_text;
}
private:
... ...
... ... @@ -20,7 +20,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(config);
}
SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer");
SHERPA_ONNX_LOGE(
"Please specify a punctuation model and bpe vocab! Return a null "
"pointer");
return nullptr;
}
... ... @@ -31,7 +33,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(mgr, config);
}
SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer");
SHERPA_ONNX_LOGE(
"Please specify a punctuation model and bpe vocab! Return a null "
"pointer");
return nullptr;
}
#endif
... ...
... ... @@ -13,8 +13,7 @@ void OnlinePunctuationModelConfig::Register(ParseOptions *po) {
po->Register("cnn-bilstm", &cnn_bilstm,
"Path to the light-weight CNN-BiLSTM model");
po->Register("bpe-vocab", &bpe_vocab,
"Path to the bpe vocab file");
po->Register("bpe-vocab", &bpe_vocab, "Path to the bpe vocab file");
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
... ... @@ -33,8 +32,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
}
if (!FileExists(cnn_bilstm)) {
SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist",
cnn_bilstm.c_str());
SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", cnn_bilstm.c_str());
return false;
}
... ... @@ -44,8 +42,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
}
if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist",
bpe_vocab.c_str());
SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", bpe_vocab.c_str());
return false;
}
... ...
... ... @@ -22,9 +22,9 @@ struct OnlinePunctuationModelConfig {
OnlinePunctuationModelConfig() = default;
OnlinePunctuationModelConfig(const std::string &cnn_bilstm,
const std::string &bpe_vocab,
int32_t num_threads, bool debug,
const std::string &provider)
const std::string &bpe_vocab,
int32_t num_threads, bool debug,
const std::string &provider)
: cnn_bilstm(cnn_bilstm),
bpe_vocab(bpe_vocab),
num_threads(num_threads),
... ...
... ... @@ -14,9 +14,7 @@
namespace sherpa_onnx {
void OnlinePunctuationConfig::Register(ParseOptions *po) {
model.Register(po);
}
void OnlinePunctuationConfig::Register(ParseOptions *po) { model.Register(po); }
bool OnlinePunctuationConfig::Validate() const {
if (!model.Validate()) {
... ... @@ -40,13 +38,14 @@ OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config)
#if __ANDROID_API__ >= 9
OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr,
const OnlinePunctuationConfig &config)
const OnlinePunctuationConfig &config)
: impl_(OnlinePunctuationImpl::Create(mgr, config)) {}
#endif
OnlinePunctuation::~OnlinePunctuation() = default;
std::string OnlinePunctuation::AddPunctuationWithCase(const std::string &text) const {
std::string OnlinePunctuation::AddPunctuationWithCase(
const std::string &text) const {
return impl_->AddPunctuationWithCase(text);
}
... ...
... ... @@ -40,8 +40,7 @@ class OnlinePunctuation {
explicit OnlinePunctuation(const OnlinePunctuationConfig &config);
#if __ANDROID_API__ >= 9
OnlinePunctuation(AAssetManager *mgr,
const OnlinePunctuationConfig &config);
OnlinePunctuation(AAssetManager *mgr, const OnlinePunctuationConfig &config);
#endif
~OnlinePunctuation();
... ...
... ... @@ -3,9 +3,9 @@
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
#include <stdio.h>
#include <iostream>
#include <chrono> // NOLINT
#include <iostream>
#include "sherpa-onnx/csrc/online-punctuation.h"
#include "sherpa-onnx/csrc/parse-options.h"
... ... @@ -57,7 +57,7 @@ The output text should look like below:
std::string text = po.GetArg(1);
std::string text_with_punct_case = punct.AddPunctuationWithCase(text);
const auto end = std::chrono::steady_clock::now();
fprintf(stderr, "Done\n");
... ...