// spoken-language-identification-impl.cc (2.5 KB)
// sherpa-onnx/csrc/spoken-language-identification-impl.cc
//
// Copyright (c)  2024  Xiaomi Corporation
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"

#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"

namespace sherpa_onnx {

namespace {

// Model families this file can dispatch on. The value is derived from the
// "model_type" entry in the ONNX model's custom metadata.
enum class ModelType {
  // The metadata "model_type" value starts with "whisper".
  kWhisper = 0,
  // The metadata entry is missing or names a model that is not supported.
  kUnknown = 1,
};

}  // namespace

// Determines which kind of model is stored in an in-memory ONNX file.
//
// A temporary ONNX Runtime session is created from the buffer solely to read
// the model's custom metadata; the "model_type" entry decides the result.
//
// @param model_data Pointer to the serialized ONNX model. The buffer is only
//                   read by this function.
// @param model_data_length Number of bytes available at model_data.
// @param debug If true, the model metadata is dumped with PrintModelMetadata
//              and emitted through SHERPA_ONNX_LOGE.
// @return ModelType::kWhisper if the "model_type" value starts with
//         "whisper". ModelType::kUnknown if the entry is absent or has any
//         other value; an error is logged in both of those cases.
static ModelType GetModelType(const char *model_data, size_t model_data_length,
                              bool debug) {
  static constexpr char kWhisperPrefix[] = "whisper";

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions sess_opts;

  // The session never leaves this function, so it lives on the stack instead
  // of behind a heap allocation.
  Ort::Session sess(env, model_data, model_data_length, sess_opts);

  Ort::ModelMetadata meta_data = sess.GetModelMetadata();
  if (debug) {
    std::ostringstream os;
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions allocator;
  auto model_type =
      meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  if (!model_type) {
    SHERPA_ONNX_LOGE(
        "No model_type in the metadata!\n"
        "Please make sure you have added metadata to the model.\n\n"
        "For instance, you can use\n"
        "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/"
        "export-onnx.py "
        "to add metadata to models from whisper\n");
    return ModelType::kUnknown;
  }

  const std::string model_type_str = model_type.get();

  // rfind() with pos == 0 only tests for a match at the very beginning of the
  // string, so a non-matching value is rejected without scanning all of it.
  if (model_type_str.rfind(kWhisperPrefix, 0) == 0) {
    return ModelType::kWhisper;
  }

  SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type_str.c_str());
  return ModelType::kUnknown;
}

std::unique_ptr<SpokenLanguageIdentificationImpl>
SpokenLanguageIdentificationImpl::Create(
    const SpokenLanguageIdentificationConfig &config) {
  ModelType model_type = ModelType::kUnknown;
  {
    if (config.whisper.encoder.empty()) {
      SHERPA_ONNX_LOGE("Only whisper models are supported at present");
      exit(-1);
    }
    auto buffer = ReadFile(config.whisper.encoder);

    model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  }

  switch (model_type) {
    case ModelType::kWhisper:
      return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(config);
    case ModelType::kUnknown:
      SHERPA_ONNX_LOGE(
          "Unknown model type for spoken language identification!");
      return nullptr;
  }

  // unreachable code
  return nullptr;
}

}  // namespace sherpa_onnx