// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
//
// Copyright (c)  2024  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_

#include <algorithm>
#include <array>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
#include "sherpa-onnx/csrc/transpose.h"

namespace sherpa_onnx {

class SpokenLanguageIdentificationWhisperImpl
    : public SpokenLanguageIdentificationImpl {
 public:
  explicit SpokenLanguageIdentificationWhisperImpl(
      const SpokenLanguageIdentificationConfig &config)
      : config_(config), model_(std::make_unique<OfflineWhisperModel>(config)) {
    Check();
  }

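  // WhisperTag{} makes the stream compute whisper-compatible log-mel
  // features instead of the default fbank features.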
  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(WhisperTag{});
  }

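  // Runs the whisper encoder on the features of the given stream and returns
  // the detected language (as stored in the model's id->language table), or
  // an empty string on failure.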
  std::string Compute(OfflineStream *s) const override {
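    // Whisper consumes a fixed 30-second window: 3000 frames at a 10 ms
    // frame shift.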
    int32_t max_num_frames = 3000;
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

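    // GetFrames() returns all feature frames in a flat, row-major buffer
    // holding feat_dim values per frame.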
    int32_t feat_dim = s->FeatureDim();
    std::vector<float> f = s->GetFrames();
    int32_t num_frames = f.size() / feat_dim;

    // We reserve 50 frames here so that there is always some zero tail
    // padding.
    if (num_frames >= max_num_frames - 50) {
      SHERPA_ONNX_LOGE(
          "Only waves shorter than 30 seconds are supported. We process only "
          "the first 30 seconds and discard the remaining data");
      num_frames = max_num_frames - 50;
    }

    model_->NormalizeFeatures(f.data(), num_frames, feat_dim);

    // Note: 1000 is an empirical value. You can replace it with another
    // value, say, 100.
    //
    // Since we have removed the 30-second input constraint, we need
    // tail_padding_frames so that whisper is able to detect the eot token.
    int32_t tail_padding_frames = 1000;

    if (config_.whisper.tail_paddings > 0) {
      tail_padding_frames = config_.whisper.tail_paddings;
    }

    int32_t actual_frames =
        std::min(num_frames + tail_padding_frames, max_num_frames);

    std::array<int64_t, 3> shape{1, actual_frames, feat_dim};

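    // Allocate a (1, actual_frames, feat_dim) tensor; copy the normalized
    // features into it and zero out the tail padding.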
    Ort::Value mel = Ort::Value::CreateTensor<float>(
        model_->Allocator(), shape.data(), shape.size());

    float *p_mel = mel.GetTensorMutableData<float>();
    std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);

    std::fill_n(p_mel + num_frames * feat_dim,
                (actual_frames - num_frames) * feat_dim, 0);

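    // The whisper encoder expects features of shape (N, feat_dim, T), but
    // the tensor above is (N, T, feat_dim), so swap axes 1 and 2.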
    mel = Transpose12(model_->Allocator(), &mel);

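    // ForwardEncoder() returns the cross-attention key/value caches computed
    // from the mel features; DetectLanguage() uses them to find the ID of
    // the most probable language token.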
    try {
      auto cross_kv = model_->ForwardEncoder(std::move(mel));
      int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second);
      const auto &id2lang = model_->GetID2Lang();
      if (id2lang.count(lang_id)) {
        return id2lang.at(lang_id);
      } else {
        SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.",
                         lang_id);
        return "";
      }
    } catch (const Ort::Exception &ex) {
      SHERPA_ONNX_LOGE(
          "\n\nCaught exception:\n\n%s\n\nReturning an empty result. Number "
          "of input frames: %d, current tail paddings: %d. If you see many "
          "such exceptions, please consider using a larger "
          "--whisper-tail-paddings",
          ex.what(), num_frames, tail_padding_frames);
      return "";
    }
  }

 private:
  void Check() const {
    if (!model_->IsMultiLingual()) {
      SHERPA_ONNX_LOGE(
          "Only multilingual whisper models can be used for spoken language "
          "identification. Given encoder: %s, decoder: %s",
          config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
      exit(-1);
    }
  }

 private:
  SpokenLanguageIdentificationConfig config_;
  std::unique_ptr<OfflineWhisperModel> model_;
};
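
// Usage sketch (hypothetical; the model filenames are placeholders, and in
// practice this impl is created for you by the public
// SpokenLanguageIdentification API):
//
//   SpokenLanguageIdentificationConfig config;
//   config.whisper.encoder = "whisper-encoder.onnx";  // multilingual model
//   config.whisper.decoder = "whisper-decoder.onnx";
//
//   SpokenLanguageIdentificationWhisperImpl slid(config);
//   auto stream = slid.CreateStream();
//   stream->AcceptWaveform(16000, samples.data(), samples.size());
//   std::string lang = slid.Compute(stream.get());  // e.g., "en"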

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_