jianyou
Committed by GitHub

Add online punctuation and casing prediction model for English language (#1224)

@@ -138,6 +138,10 @@ list(APPEND sources @@ -138,6 +138,10 @@ list(APPEND sources
138 offline-punctuation-impl.cc 138 offline-punctuation-impl.cc
139 offline-punctuation-model-config.cc 139 offline-punctuation-model-config.cc
140 offline-punctuation.cc 140 offline-punctuation.cc
  141 + online-cnn-bilstm-model.cc
  142 + online-punctuation-impl.cc
  143 + online-punctuation-model-config.cc
  144 + online-punctuation.cc
141 ) 145 )
142 146
143 if(SHERPA_ONNX_ENABLE_TTS) 147 if(SHERPA_ONNX_ENABLE_TTS)
@@ -243,6 +247,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -243,6 +247,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
243 add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) 247 add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
244 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) 248 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
245 add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) 249 add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
  250 + add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
246 251
247 if(SHERPA_ONNX_ENABLE_TTS) 252 if(SHERPA_ONNX_ENABLE_TTS)
248 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) 253 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
@@ -256,6 +261,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -256,6 +261,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
256 sherpa-onnx-offline-language-identification 261 sherpa-onnx-offline-language-identification
257 sherpa-onnx-offline-parallel 262 sherpa-onnx-offline-parallel
258 sherpa-onnx-offline-punctuation 263 sherpa-onnx-offline-punctuation
  264 + sherpa-onnx-online-punctuation
259 ) 265 )
260 if(SHERPA_ONNX_ENABLE_TTS) 266 if(SHERPA_ONNX_ENABLE_TTS)
261 list(APPEND main_exes 267 list(APPEND main_exes
  1 +// sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
  7 +
  8 +namespace sherpa_onnx {
  9 +
  10 +struct OnlineCNNBiLSTMModelMetaData {
  11 + int32_t comma_id;
  12 + int32_t period_id;
  13 + int32_t quest_id;
  14 +
  15 + int32_t upper_id;
  16 + int32_t cap_id;
  17 + int32_t mix_case_id;
  18 +
  19 + int32_t num_cases;
  20 + int32_t num_punctuations;
  21 +};
  22 +
  23 +} // namespace sherpa_onnx
  24 +
  25 +#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
  1 +// sherpa-onnx/csrc/online-cnn-bilstm-model.cc
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/onnx-utils.h"
  11 +#include "sherpa-onnx/csrc/session.h"
  12 +#include "sherpa-onnx/csrc/text-utils.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +class OnlineCNNBiLSTMModel::Impl {
  17 + public:
  18 + explicit Impl(const OnlinePunctuationModelConfig &config)
  19 + : config_(config),
  20 + env_(ORT_LOGGING_LEVEL_ERROR),
  21 + sess_opts_(GetSessionOptions(config)),
  22 + allocator_{} {
  23 + auto buf = ReadFile(config_.cnn_bilstm);
  24 + Init(buf.data(), buf.size());
  25 + }
  26 +
  27 +#if __ANDROID_API__ >= 9
  28 + Impl(AAssetManager *mgr, const OnlinePunctuationModelConfig &config)
  29 + : config_(config),
  30 + env_(ORT_LOGGING_LEVEL_ERROR),
  31 + sess_opts_(GetSessionOptions(config)),
  32 + allocator_{} {
  33 + auto buf = ReadFile(mgr, config_.cnn_bilstm);
  34 + Init(buf.data(), buf.size());
  35 + }
  36 +#endif
  37 +
  38 + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) {
  39 + std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)};
  40 +
  41 + auto ans =
  42 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  43 + output_names_ptr_.data(), output_names_ptr_.size());
  44 + return {std::move(ans[0]), std::move(ans[1])};
  45 + }
  46 +
  47 + OrtAllocator *Allocator() const { return allocator_; }
  48 +
  49 + const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
  50 + return meta_data_;
  51 + }
  52 +
  53 + private:
  54 + void Init(void *model_data, size_t model_data_length) {
  55 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  56 + sess_opts_);
  57 +
  58 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  59 +
  60 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  61 +
  62 + // get meta data
  63 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  64 +
  65 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  66 +
  67 + SHERPA_ONNX_READ_META_DATA(meta_data_.comma_id, "COMMA");
  68 + SHERPA_ONNX_READ_META_DATA(meta_data_.period_id, "PERIOD");
  69 + SHERPA_ONNX_READ_META_DATA(meta_data_.quest_id, "QUESTION");
  70 +
  71 + // assert here, because we will use the constant value
  72 + assert(meta_data_.comma_id == 1);
  73 + assert(meta_data_.period_id == 2);
  74 + assert(meta_data_.quest_id == 3);
  75 +
  76 + SHERPA_ONNX_READ_META_DATA(meta_data_.upper_id, "UPPER");
  77 + SHERPA_ONNX_READ_META_DATA(meta_data_.cap_id, "CAP");
  78 + SHERPA_ONNX_READ_META_DATA(meta_data_.mix_case_id, "MIX_CASE");
  79 +
  80 + assert(meta_data_.upper_id == 1);
  81 + assert(meta_data_.cap_id == 2);
  82 + assert(meta_data_.mix_case_id == 3);
  83 +
  84 + // output shape is (T', num_cases)
  85 + meta_data_.num_cases =
  86 + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
  87 + meta_data_.num_punctuations =
  88 + sess_->GetOutputTypeInfo(1).GetTensorTypeAndShapeInfo().GetShape()[1];
  89 + }
  90 +
  91 + private:
  92 + OnlinePunctuationModelConfig config_;
  93 + Ort::Env env_;
  94 + Ort::SessionOptions sess_opts_;
  95 + Ort::AllocatorWithDefaultOptions allocator_;
  96 +
  97 + std::unique_ptr<Ort::Session> sess_;
  98 +
  99 + std::vector<std::string> input_names_;
  100 + std::vector<const char *> input_names_ptr_;
  101 +
  102 + std::vector<std::string> output_names_;
  103 + std::vector<const char *> output_names_ptr_;
  104 +
  105 + OnlineCNNBiLSTMModelMetaData meta_data_;
  106 +};
  107 +
  108 +OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
  109 + const OnlinePunctuationModelConfig &config)
  110 + : impl_(std::make_unique<Impl>(config)) {}
  111 +
  112 +#if __ANDROID_API__ >= 9
  113 +OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
  114 + AAssetManager *mgr, const OnlinePunctuationModelConfig &config)
  115 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  116 +#endif
  117 +
  118 +OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default;
  119 +
  120 +std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids,
  121 + Ort::Value valid_ids,
  122 + Ort::Value label_lens) const {
  123 + return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens));
  124 +}
  125 +
  126 +OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const {
  127 + return impl_->Allocator();
  128 +}
  129 +
  130 +const OnlineCNNBiLSTMModelMetaData &
  131 +OnlineCNNBiLSTMModel::GetModelMetadata() const {
  132 + return impl_->GetModelMetadata();
  133 +}
  134 +
  135 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-cnn-bilstm-model.h
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
  7 +#include <memory>
  8 +#include <utility>
  9 +
  10 +#if __ANDROID_API__ >= 9
  11 +#include "android/asset_manager.h"
  12 +#include "android/asset_manager_jni.h"
  13 +#endif
  14 +
  15 +#include "onnxruntime_cxx_api.h" // NOLINT
  16 +#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
  17 +#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
  18 +
  19 +namespace sherpa_onnx {
  20 +
  21 +/** This class implements
  22 + * https://github.com/frankyoujian/Edge-Punct-Casing/blob/main/onnx_decode_sentence.py
  23 + */
  24 +class OnlineCNNBiLSTMModel {
  25 + public:
  26 + explicit OnlineCNNBiLSTMModel(
  27 + const OnlinePunctuationModelConfig &config);
  28 +
  29 +#if __ANDROID_API__ >= 9
  30 + OnlineCNNBiLSTMModel(AAssetManager *mgr,
  31 + const OnlinePunctuationModelConfig &config);
  32 +#endif
  33 +
  34 + ~OnlineCNNBiLSTMModel();
  35 +
  36 + /** Run the forward method of the model.
  37 + *
  38 + * @param token_ids A tensor of shape (N, T) of dtype int32.
  39 + * @param valid_ids A tensor of shape (N, T) of dtype int32.
  40 + * @param label_lens A tensor of shape (N) of dtype int32.
  41 + *
  42 + * @return Return a pair of tensors
  43 + * - case_logits: A 2-D tensor of shape (T', num_cases).
  44 + * - punct_logits: A 2-D tensor of shape (T', num_puncts).
  45 + */
  46 + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const;
  47 +
  48 + /** Return an allocator for allocating memory
  49 + */
  50 + OrtAllocator *Allocator() const;
  51 +
  52 + const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const;
  53 +
  54 + private:
  55 + class Impl;
  56 + std::unique_ptr<Impl> impl_;
  57 +};
  58 +
  59 +} // namespace sherpa_onnx
  60 +
  61 +#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
  1 +// sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_
  7 +
  8 +#include <math.h>
  9 +
  10 +#include <memory>
  11 +#include <string>
  12 +#include <utility>
  13 +#include <vector>
  14 +#include <algorithm>
  15 +
  16 +#if __ANDROID_API__ >= 9
  17 +#include "android/asset_manager.h"
  18 +#include "android/asset_manager_jni.h"
  19 +#endif
  20 +
  21 +#include "sherpa-onnx/csrc/macros.h"
  22 +#include "sherpa-onnx/csrc/math.h"
  23 +#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"
  24 +#include "sherpa-onnx/csrc/online-punctuation-impl.h"
  25 +#include "sherpa-onnx/csrc/online-punctuation.h"
  26 +#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
  27 +#include "sherpa-onnx/csrc/text-utils.h"
  28 +#include "sherpa-onnx/csrc/onnx-utils.h"
  29 +#include "ssentencepiece/csrc/ssentencepiece.h"
  30 +#include <chrono> // NOLINT
  31 +
  32 +namespace sherpa_onnx {
  33 +
  34 +static const int32_t kMaxSeqLen = 200;
  35 +
  36 +class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
  37 + public:
  38 + explicit OnlinePunctuationCNNBiLSTMImpl(
  39 + const OnlinePunctuationConfig &config)
  40 + : config_(config), model_(config.model) {
  41 + if (!config_.model.bpe_vocab.empty()) {
  42 + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
  43 + config_.model.bpe_vocab);
  44 + }
  45 + }
  46 +
  47 +#if __ANDROID_API__ >= 9
  48 + OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr,
  49 + const OnlinePunctuationConfig &config)
  50 + : config_(config), model_(mgr, config.model) {
  51 + if (!config_.model.bpe_vocab.empty()) {
  52 + auto buf = ReadFile(mgr, config_.model.bpe_vocab);
  53 + std::istringstream iss(std::string(buf.begin(), buf.end()));
  54 + bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
  55 + }
  56 + }
  57 +#endif
  58 +
  59 + std::string AddPunctuationWithCase(const std::string &text) const override {
  60 + if (text.empty()) {
  61 + return {};
  62 + }
  63 +
  64 + std::vector<int32_t> tokens_list; // N * kMaxSeqLen
  65 + std::vector<int32_t> valids_list; // N * kMaxSeqLen
  66 + std::vector<int32_t> label_len_list; // N
  67 +
  68 + EncodeSentences(text, tokens_list, valids_list, label_len_list);
  69 +
  70 + const auto &meta_data = model_.GetModelMetadata();
  71 +
  72 + auto memory_info =
  73 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  74 +
  75 + int32_t n = label_len_list.size();
  76 +
  77 + std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen};
  78 + Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(),
  79 + token_ids_shape.data(), token_ids_shape.size());
  80 +
  81 + std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen};
  82 + Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(),
  83 + valid_ids_shape.data(), valid_ids_shape.size());
  84 +
  85 + std::array<int64_t, 1> label_len_shape = {n};
  86 + Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(),
  87 + label_len_shape.data(), label_len_shape.size());
  88 +
  89 + auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len));
  90 +
  91 + std::vector<int32_t> case_pred;
  92 + std::vector<int32_t> punct_pred;
  93 + const float* active_case_logits = pair.first.GetTensorData<float>();
  94 + const float* active_punct_logits = pair.second.GetTensorData<float>();
  95 + std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape();
  96 +
  97 + for (int32_t i = 0; i < case_logits_shape[0]; ++i) {
  98 + const float* p_cur_case = active_case_logits + i * meta_data.num_cases;
  99 + auto index_case = static_cast<int32_t>(std::distance(
  100 + p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases)));
  101 + case_pred.push_back(index_case);
  102 +
  103 + const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations;
  104 + auto index_punct = static_cast<int32_t>(std::distance(
  105 + p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations)));
  106 + punct_pred.push_back(index_punct);
  107 + }
  108 +
  109 + std::string ans = DecodeSentences(text, case_pred, punct_pred);
  110 +
  111 + return ans;
  112 + }
  113 +
  114 + private:
  115 + void EncodeSentences(const std::string& text,
  116 + std::vector<int32_t>& tokens_list,
  117 + std::vector<int32_t>& valids_list,
  118 + std::vector<int32_t>& label_len_list) const {
  119 + std::vector<int32_t> tokens;
  120 + std::vector<int32_t> valids;
  121 + int32_t label_len = 0;
  122 +
  123 + tokens.push_back(1); // hardcode 1 now, 1 - <s>
  124 + valids.push_back(1);
  125 +
  126 + std::stringstream ss(text);
  127 + std::string word;
  128 + while (ss >> word) {
  129 + std::vector<int32_t> word_tokens;
  130 + bpe_encoder_->Encode(word, &word_tokens);
  131 +
  132 + int32_t seq_len = tokens.size() + word_tokens.size();
  133 + if (seq_len > kMaxSeqLen - 1) {
  134 + tokens.push_back(2); // hardcode 2 now, 2 - </s>
  135 + valids.push_back(1);
  136 +
  137 + label_len = std::count(valids.begin(), valids.end(), 1);
  138 +
  139 + if (tokens.size() < kMaxSeqLen) {
  140 + tokens.resize(kMaxSeqLen, 0);
  141 + valids.resize(kMaxSeqLen, 0);
  142 + }
  143 +
  144 + assert(tokens.size() == kMaxSeqLen);
  145 + assert(valids.size() == kMaxSeqLen);
  146 +
  147 + tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
  148 + valids_list.insert(valids_list.end(), valids.begin(), valids.end());
  149 + label_len_list.push_back(label_len);
  150 +
  151 + std::vector<int32_t>().swap(tokens);
  152 + std::vector<int32_t>().swap(valids);
  153 + label_len = 0;
  154 + tokens.push_back(1); // hardcode 1 now, 1 - <s>
  155 + valids.push_back(1);
  156 + }
  157 +
  158 + tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end());
  159 + valids.push_back(1); // only the first sub word is valid
  160 + int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1;
  161 + if (remaining_size > 0) {
  162 + int32_t valids_cur_size = static_cast<int32_t>(valids.size());
  163 + valids.resize(valids_cur_size + remaining_size, 0);
  164 + }
  165 + }
  166 +
  167 + if (tokens.size() > 0) {
  168 + tokens.push_back(2); // hardcode 2 now, 2 - </s>
  169 + valids.push_back(1);
  170 +
  171 + label_len = std::count(valids.begin(), valids.end(), 1);
  172 +
  173 + if (tokens.size() < kMaxSeqLen) {
  174 + tokens.resize(kMaxSeqLen, 0);
  175 + valids.resize(kMaxSeqLen, 0);
  176 + }
  177 +
  178 + assert(tokens.size() == kMaxSeqLen);
  179 + assert(valids.size() == kMaxSeqLen);
  180 +
  181 + tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
  182 + valids_list.insert(valids_list.end(), valids.begin(), valids.end());
  183 + label_len_list.push_back(label_len);
  184 + }
  185 + }
  186 +
  187 + std::string DecodeSentences(const std::string& raw_text,
  188 + const std::vector<int32_t>& case_pred,
  189 + const std::vector<int32_t>& punct_pred) const {
  190 + std::string result_text;
  191 + std::istringstream iss(raw_text);
  192 + std::vector<std::string> words;
  193 + std::string word;
  194 +
  195 + while (iss >> word) {
  196 + words.emplace_back(word);
  197 + }
  198 +
  199 + assert(words.size() == case_pred.size());
  200 + assert(words.size() == punct_pred.size());
  201 +
  202 + for (int32_t i = 0; i < words.size(); ++i) {
  203 + std::string prefix = ((i != 0) ? " " : "");
  204 + result_text += prefix;
  205 + switch (case_pred[i]) {
  206 + case 1: // upper
  207 + {
  208 + std::transform(words[i].begin(), words[i].end(), words[i].begin(), [](auto c){ return std::toupper(c); });
  209 + result_text += words[i];
  210 + break;
  211 + }
  212 + case 2: // cap
  213 + {
  214 + words[i][0] = std::toupper(words[i][0]);
  215 + result_text += words[i];
  216 + break;
  217 + }
  218 + case 3: // mix case
  219 + {
  220 + // TODO:
  221 + // Need to add a map containing supported mix case words so that we can fetch the predicted word from the map
  222 + // e.g. mcdonald's -> McDonald's
  223 + result_text += words[i];
  224 + break;
  225 + }
  226 + default:
  227 + {
  228 + result_text += words[i];
  229 + break;
  230 + }
  231 + }
  232 +
  233 + std::string suffix;
  234 + switch (punct_pred[i]) {
  235 + case 1: // comma
  236 + {
  237 + suffix = ",";
  238 + break;
  239 + }
  240 + case 2: // period
  241 + {
  242 + suffix = ".";
  243 + break;
  244 + }
  245 + case 3: // question
  246 + {
  247 + suffix = "?";
  248 + break;
  249 + }
  250 + default:
  251 + break;
  252 + }
  253 +
  254 + result_text += suffix;
  255 + }
  256 +
  257 + return result_text;
  258 + }
  259 +
  260 + private:
  261 + OnlinePunctuationConfig config_;
  262 + OnlineCNNBiLSTMModel model_;
  263 + std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
  264 +};
  265 +
  266 +} // namespace sherpa_onnx
  267 +
  268 +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_
  1 +// sherpa-onnx/csrc/online-punctuation-impl.cc
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#include "sherpa-onnx/csrc/online-punctuation-impl.h"
  6 +
  7 +#if __ANDROID_API__ >= 9
  8 +#include "android/asset_manager.h"
  9 +#include "android/asset_manager_jni.h"
  10 +#endif
  11 +
  12 +#include "sherpa-onnx/csrc/macros.h"
  13 +#include "sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
  18 + const OnlinePunctuationConfig &config) {
  19 + if (!config.model.cnn_bilstm.empty() && !config.model.bpe_vocab.empty()) {
  20 + return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(config);
  21 + }
  22 +
  23 + SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer");
  24 + return nullptr;
  25 +}
  26 +
  27 +#if __ANDROID_API__ >= 9
  28 +std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
  29 + AAssetManager *mgr, const OnlinePunctuationConfig &config) {
  30 + if (!config.model.cnn_bilstm.empty() && !config.model.bpe_vocab.empty()) {
  31 + return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(mgr, config);
  32 + }
  33 +
  34 + SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer");
  35 + return nullptr;
  36 +}
  37 +#endif
  38 +
  39 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-punctuation-impl.h
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <vector>
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
  16 +#include "sherpa-onnx/csrc/online-punctuation.h"
  17 +
  18 +namespace sherpa_onnx {
  19 +
  20 +class OnlinePunctuationImpl {
  21 + public:
  22 + virtual ~OnlinePunctuationImpl() = default;
  23 +
  24 + static std::unique_ptr<OnlinePunctuationImpl> Create(
  25 + const OnlinePunctuationConfig &config);
  26 +
  27 +#if __ANDROID_API__ >= 9
  28 + static std::unique_ptr<OnlinePunctuationImpl> Create(
  29 + AAssetManager *mgr, const OnlinePunctuationConfig &config);
  30 +#endif
  31 +
  32 + virtual std::string AddPunctuationWithCase(const std::string &text) const = 0;
  33 +};
  34 +
  35 +} // namespace sherpa_onnx
  36 +
  37 +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_
  1 +// sherpa-onnx/csrc/online-punctuation-model-config.cc
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OnlinePunctuationModelConfig::Register(ParseOptions *po) {
  13 + po->Register("cnn-bilstm", &cnn_bilstm,
  14 + "Path to the light-weight CNN-BiLSTM model");
  15 +
  16 + po->Register("bpe-vocab", &bpe_vocab,
  17 + "Path to the bpe vocab file");
  18 +
  19 + po->Register("num-threads", &num_threads,
  20 + "Number of threads to run the neural network");
  21 +
  22 + po->Register("debug", &debug,
  23 + "true to print model information while loading it.");
  24 +
  25 + po->Register("provider", &provider,
  26 + "Specify a provider to use: cpu, cuda, coreml");
  27 +}
  28 +
  29 +bool OnlinePunctuationModelConfig::Validate() const {
  30 + if (cnn_bilstm.empty()) {
  31 + SHERPA_ONNX_LOGE("Please provide --cnn-bilstm");
  32 + return false;
  33 + }
  34 +
  35 + if (!FileExists(cnn_bilstm)) {
  36 + SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist",
  37 + cnn_bilstm.c_str());
  38 + return false;
  39 + }
  40 +
  41 + if (bpe_vocab.empty()) {
  42 + SHERPA_ONNX_LOGE("Please provide --bpe-vocab");
  43 + return false;
  44 + }
  45 +
  46 + if (!FileExists(bpe_vocab)) {
  47 + SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist",
  48 + bpe_vocab.c_str());
  49 + return false;
  50 + }
  51 +
  52 + return true;
  53 +}
  54 +
  55 +std::string OnlinePunctuationModelConfig::ToString() const {
  56 + std::ostringstream os;
  57 +
  58 + os << "OnlinePunctuationModelConfig(";
  59 + os << "cnn_bilstm=\"" << cnn_bilstm << "\", ";
  60 + os << "bpe_vocab=\"" << bpe_vocab << "\", ";
  61 + os << "num_threads=" << num_threads << ", ";
  62 + os << "debug=" << (debug ? "True" : "False") << ", ";
  63 + os << "provider=\"" << provider << "\")";
  64 +
  65 + return os.str();
  66 +}
  67 +
  68 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-punctuation-model-config.h
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_
  7 +
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/parse-options.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OnlinePunctuationModelConfig {
  15 + std::string cnn_bilstm;
  16 + std::string bpe_vocab;
  17 +
  18 + int32_t num_threads = 1;
  19 + bool debug = false;
  20 + std::string provider = "cpu";
  21 +
  22 + OnlinePunctuationModelConfig() = default;
  23 +
  24 + OnlinePunctuationModelConfig(const std::string &cnn_bilstm,
  25 + const std::string &bpe_vocab,
  26 + int32_t num_threads, bool debug,
  27 + const std::string &provider)
  28 + : cnn_bilstm(cnn_bilstm),
  29 + bpe_vocab(bpe_vocab),
  30 + num_threads(num_threads),
  31 + debug(debug),
  32 + provider(provider) {}
  33 +
  34 + void Register(ParseOptions *po);
  35 + bool Validate() const;
  36 +
  37 + std::string ToString() const;
  38 +};
  39 +
  40 +} // namespace sherpa_onnx
  41 +
  42 +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/online-punctuation.cc
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#include "sherpa-onnx/csrc/online-punctuation.h"
  6 +
  7 +#if __ANDROID_API__ >= 9
  8 +#include "android/asset_manager.h"
  9 +#include "android/asset_manager_jni.h"
  10 +#endif
  11 +
  12 +#include "sherpa-onnx/csrc/macros.h"
  13 +#include "sherpa-onnx/csrc/online-punctuation-impl.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +void OnlinePunctuationConfig::Register(ParseOptions *po) {
  18 + model.Register(po);
  19 +}
  20 +
  21 +bool OnlinePunctuationConfig::Validate() const {
  22 + if (!model.Validate()) {
  23 + return false;
  24 + }
  25 +
  26 + return true;
  27 +}
  28 +
  29 +std::string OnlinePunctuationConfig::ToString() const {
  30 + std::ostringstream os;
  31 +
  32 + os << "OnlinePunctuationConfig(";
  33 + os << "model=" << model.ToString() << ")";
  34 +
  35 + return os.str();
  36 +}
  37 +
  38 +OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config)
  39 + : impl_(OnlinePunctuationImpl::Create(config)) {}
  40 +
  41 +#if __ANDROID_API__ >= 9
  42 +OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr,
  43 + const OnlinePunctuationConfig &config)
  44 + : impl_(OnlinePunctuationImpl::Create(mgr, config)) {}
  45 +#endif
  46 +
  47 +OnlinePunctuation::~OnlinePunctuation() = default;
  48 +
  49 +std::string OnlinePunctuation::AddPunctuationWithCase(const std::string &text) const {
  50 + return impl_->AddPunctuationWithCase(text);
  51 +}
  52 +
  53 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-punctuation.h
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
  17 +#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
  18 +#include "sherpa-onnx/csrc/parse-options.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +struct OnlinePunctuationConfig {
  23 + OnlinePunctuationModelConfig model;
  24 +
  25 + OnlinePunctuationConfig() = default;
  26 +
  27 + explicit OnlinePunctuationConfig(const OnlinePunctuationModelConfig &model)
  28 + : model(model) {}
  29 +
  30 + void Register(ParseOptions *po);
  31 + bool Validate() const;
  32 +
  33 + std::string ToString() const;
  34 +};
  35 +
  36 +class OnlinePunctuationImpl;
  37 +
  38 +class OnlinePunctuation {
  39 + public:
  40 + explicit OnlinePunctuation(const OnlinePunctuationConfig &config);
  41 +
  42 +#if __ANDROID_API__ >= 9
  43 + OnlinePunctuation(AAssetManager *mgr,
  44 + const OnlinePunctuationConfig &config);
  45 +#endif
  46 +
  47 + ~OnlinePunctuation();
  48 +
  49 + // Add punctuation and casing to the input text and return it.
  50 + std::string AddPunctuationWithCase(const std::string &text) const;
  51 +
  52 + private:
  53 + std::unique_ptr<OnlinePunctuationImpl> impl_;
  54 +};
  55 +
  56 +} // namespace sherpa_onnx
  57 +
  58 +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_
@@ -300,4 +300,9 @@ Ort::SessionOptions GetSessionOptions( @@ -300,4 +300,9 @@ Ort::SessionOptions GetSessionOptions(
300 return GetSessionOptionsImpl(config.num_threads, config.provider); 300 return GetSessionOptionsImpl(config.num_threads, config.provider);
301 } 301 }
302 302
  303 +Ort::SessionOptions GetSessionOptions(
  304 + const OnlinePunctuationModelConfig &config) {
  305 + return GetSessionOptionsImpl(config.num_threads, config.provider);
  306 +}
  307 +
303 } // namespace sherpa_onnx 308 } // namespace sherpa_onnx
@@ -12,6 +12,7 @@ @@ -12,6 +12,7 @@
12 #include "sherpa-onnx/csrc/offline-lm-config.h" 12 #include "sherpa-onnx/csrc/offline-lm-config.h"
13 #include "sherpa-onnx/csrc/offline-model-config.h" 13 #include "sherpa-onnx/csrc/offline-model-config.h"
14 #include "sherpa-onnx/csrc/offline-punctuation-model-config.h" 14 #include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
  15 +#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
15 #include "sherpa-onnx/csrc/online-lm-config.h" 16 #include "sherpa-onnx/csrc/online-lm-config.h"
16 #include "sherpa-onnx/csrc/online-model-config.h" 17 #include "sherpa-onnx/csrc/online-model-config.h"
17 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" 18 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
@@ -52,6 +53,9 @@ Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); @@ -52,6 +53,9 @@ Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config);
52 Ort::SessionOptions GetSessionOptions( 53 Ort::SessionOptions GetSessionOptions(
53 const OfflinePunctuationModelConfig &config); 54 const OfflinePunctuationModelConfig &config);
54 55
  56 +Ort::SessionOptions GetSessionOptions(
  57 + const OnlinePunctuationModelConfig &config);
  58 +
55 } // namespace sherpa_onnx 59 } // namespace sherpa_onnx
56 60
57 #endif // SHERPA_ONNX_CSRC_SESSION_H_ 61 #endif // SHERPA_ONNX_CSRC_SESSION_H_
  1 +// sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
  2 +//
  3 +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
  4 +
  5 +#include <stdio.h>
  6 +#include <iostream>
  7 +
  8 +#include <chrono> // NOLINT
  9 +
  10 +#include "sherpa-onnx/csrc/online-punctuation.h"
  11 +#include "sherpa-onnx/csrc/parse-options.h"
  12 +
  13 +int main(int32_t argc, char *argv[]) {
  14 + const char *kUsageMessage = R"usage(
  15 +Add punctuations to the input text.
  16 +
  17 +The input text can contain English words.
  18 +
  19 +Usage:
  20 +
  21 +Please download the model from:
  22 +https://huggingface.co/frankyoujian/Edge-Punct-Casing/resolve/main/sherpa-onnx-cnn-bilstm-unigram-bpe-en.7z
  23 +
  24 +./bin/Release/sherpa-onnx-online-punctuation \
  25 + --cnn-bilstm=/path/to/model.onnx \
  26 + --bpe-vocab=/path/to/bpe.vocab \
  27 + "how are you i am fine thank you"
  28 +
  29 +The output text should look like below:
  30 + "How are you? I am fine. Thank you."
  31 +)usage";
  32 +
  33 + sherpa_onnx::ParseOptions po(kUsageMessage);
  34 + sherpa_onnx::OnlinePunctuationConfig config;
  35 + config.Register(&po);
  36 + po.Read(argc, argv);
  37 + if (po.NumArgs() != 1) {
  38 + fprintf(stderr,
  39 + "Error: Please provide only 1 positional argument containing the "
  40 + "input text.\n\n");
  41 + po.PrintUsage();
  42 + exit(EXIT_FAILURE);
  43 + }
  44 +
  45 + fprintf(stderr, "%s\n", config.ToString().c_str());
  46 +
  47 + if (!config.Validate()) {
  48 + fprintf(stderr, "Errors in config!\n");
  49 + return -1;
  50 + }
  51 +
  52 + fprintf(stderr, "Creating OnlinePunctuation ...\n");
  53 + sherpa_onnx::OnlinePunctuation punct(config);
  54 + fprintf(stderr, "Started\n");
  55 + const auto begin = std::chrono::steady_clock::now();
  56 +
  57 + std::string text = po.GetArg(1);
  58 +
  59 + std::string text_with_punct_case = punct.AddPunctuationWithCase(text);
  60 +
  61 + const auto end = std::chrono::steady_clock::now();
  62 + fprintf(stderr, "Done\n");
  63 +
  64 + float elapsed_seconds =
  65 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  66 + .count() /
  67 + 1000.;
  68 +
  69 + fprintf(stderr, "Num threads: %d\n", config.model.num_threads);
  70 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  71 + fprintf(stderr, "Input text: %s\n", text.c_str());
  72 + fprintf(stderr, "Output text: %s\n", text_with_punct_case.c_str());
  73 +}