// offline-punctuation-ct-transformer-impl.h 5.3 KB
// sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
//
// Copyright (c)  2024  Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-ct-transformer-model.h"
#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
#include "sherpa-onnx/csrc/offline-punctuation.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

// Adds punctuation to text using an OfflineCtTransformerModel.
//
// The text is tokenized with SplitUtf8(), every token is lower-cased and
// mapped to an id, and the ids are fed to the model in windows. For every
// token the model yields one score per punctuation class; the class with the
// highest score is taken as the prediction for that token.
class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
 public:
  // @param config  Configuration; config.model is used to construct the
  //                underlying model. A copy of config is kept.
  explicit OfflinePunctuationCtTransformerImpl(
      const OfflinePunctuationConfig &config)
      : config_(config), model_(config.model) {}

  // Returns a copy of `text` with predicted punctuation inserted.
  //
  // Windowing: the tokens are visited in steps of kSegmentSize. A window
  // starts right after the last committed sentence end and stops at the
  // current step boundary, so a window keeps growing until a sentence end
  // (dot_id or quest_id) is predicted inside it. Once a window is at least
  // kMaxLen tokens long and contains no sentence end, its last predicted
  // comma is turned into a sentence end so that the window can shrink.
  //
  // Only predictions up to and including the last sentence end of a window
  // are committed; the rest is re-evaluated together with the next step.
  // In the final window everything is committed, and the very last token is
  // given dot_id unless the model already ended the text there.
  //
  // @param text  Input text. An empty string yields an empty string.
  // @return The punctuated text. If the number of committed predictions
  //         does not match the number of tokens, an error is logged and
  //         `text` is returned unchanged.
  std::string AddPunctuation(const std::string &text) const override {
    if (text.empty()) {
      return {};
    }

    std::vector<std::string> tokens = SplitUtf8(text);
    std::vector<int32_t> token_ids;
    token_ids.reserve(tokens.size());

    const auto &meta_data = model_.GetModelMetadata();

    for (const auto &t : tokens) {
      const std::string token = ToLowerCase(t);
      // A single lookup; tokens missing from the vocabulary map to unk_id.
      auto it = meta_data.token2id.find(token);
      if (it != meta_data.token2id.end()) {
        token_ids.push_back(it->second);
      } else {
        token_ids.push_back(meta_data.unk_id);
      }
    }

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    constexpr int32_t kSegmentSize = 20;
    constexpr int32_t kMaxLen = 200;

    const auto num_tokens = static_cast<int32_t>(token_ids.size());
    const int32_t num_segments =
        (num_tokens + kSegmentSize - 1) / kSegmentSize;

    // punctuations[j] is the committed prediction for token_ids[j].
    std::vector<int32_t> punctuations;
    punctuations.reserve(token_ids.size());

    // Index of the first token without a committed prediction, or -1 before
    // the first window has been processed.
    int32_t last = -1;
    for (int32_t i = 0; i != num_segments; ++i) {
      const bool is_last_segment = (i == num_segments - 1);

      int32_t this_start = i * kSegmentSize;  // inclusive
      const int32_t this_end =
          std::min(this_start + kSegmentSize, num_tokens);  // exclusive

      if (last != -1) {
        // Resume right after the last committed token.
        this_start = last;
      }
      // token_ids[this_start:this_end] is sent to the model

      std::array<int64_t, 2> x_shape = {1, this_end - this_start};
      Ort::Value x =
          Ort::Value::CreateTensor(memory_info, token_ids.data() + this_start,
                                   x_shape[1], x_shape.data(), x_shape.size());

      int64_t len_shape = 1;
      int32_t len = x_shape[1];
      Ort::Value x_len =
          Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);

      Ort::Value out = model_.Forward(std::move(x), std::move(x_len));

      // [N, T, num_punctuations]
      std::vector<int64_t> out_shape =
          out.GetTensorTypeAndShapeInfo().GetShape();

      assert(out_shape[0] == 1);
      assert(out_shape[1] == len);
      assert(out_shape[2] == meta_data.num_punctuations);

      // Arg-max over the punctuation classes for every token of the window.
      std::vector<int32_t> this_punctuations;
      this_punctuations.reserve(len);

      const float *p = out.GetTensorData<float>();
      for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations) {
        auto index = static_cast<int32_t>(std::distance(
            p, std::max_element(p, p + meta_data.num_punctuations)));
        this_punctuations.push_back(index);
      }  // for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations)

      // Window-local positions of the last sentence end and the last comma.
      // Position 0 is deliberately not examined.
      int32_t dot_index = -1;
      int32_t comma_index = -1;

      for (int32_t m = len - 1; m >= 1; --m) {
        int32_t punct_id = this_punctuations[m];

        if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
          dot_index = m;
          break;
        }

        if (comma_index == -1 && punct_id == meta_data.comma_id) {
          comma_index = m;
        }
      }  // for (int32_t m = len - 1; m >= 1; --m)

      if (dot_index == -1 && len >= kMaxLen && comma_index != -1) {
        // The window got too long; break it at the last comma.
        dot_index = comma_index;
        this_punctuations[dot_index] = meta_data.dot_id;
      }

      // Number of leading predictions of this window that become final.
      int32_t num_committed = dot_index + 1;

      if (is_last_segment && num_committed < len) {
        // No later window will revisit the remaining tokens, so they have to
        // be committed now. The text did not end on a sentence end, so the
        // final token is closed with a period.
        this_punctuations.back() = meta_data.dot_id;
        num_committed = len;
      }

      if (num_committed > 0) {
        punctuations.insert(punctuations.end(), this_punctuations.begin(),
                            this_punctuations.begin() + num_committed);
        last = this_start + num_committed;
      } else if (last == -1) {
        last = this_start;
      }
    }  // for (int32_t i = 0; i != num_segments; ++i)

    // Defensive check; the loop above commits one prediction per token.
    if (punctuations.size() != token_ids.size()) {
      SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened",
                       text.c_str(), static_cast<int32_t>(punctuations.size()),
                       static_cast<int32_t>(token_ids.size()));
      return text;
    }

    std::string ans;

    for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
      const std::string &w = tokens[i];
      // Separate two tokens with a space only if the byte written last and
      // the first byte of the next token both have their high bit clear.
      if (!ans.empty() && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
        ans.push_back(' ');
      }
      ans.append(w);
      // underline_id produces no output after the token.
      if (punctuations[i] != meta_data.underline_id) {
        ans.append(meta_data.id2punct[punctuations[i]]);
      }
    }

    return ans;
  }

 private:
  OfflinePunctuationConfig config_;
  OfflineCtTransformerModel model_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_