spoken-language-identification.cc 3.5 KB
// sherpa-onnx/jni/spoken-language-identification.cc
//
// Copyright (c)  2024  Xiaomi Corporation

#include "sherpa-onnx/csrc/spoken-language-identification.h"

#include <string>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"

namespace sherpa_onnx {

static SpokenLanguageIdentificationConfig GetSpokenLanguageIdentificationConfig(
    JNIEnv *env, jobject config) {
  SpokenLanguageIdentificationConfig ans;

  jclass cls = env->GetObjectClass(config);
  jfieldID fid = env->GetFieldID(
      cls, "whisper",
      "Lcom/k2fsa/sherpa/onnx/SpokenLanguageIdentificationWhisperConfig;");

  jobject whisper = env->GetObjectField(config, fid);
  jclass whisper_cls = env->GetObjectClass(whisper);

  fid = env->GetFieldID(whisper_cls, "encoder", "Ljava/lang/String;");

  jstring s = (jstring)env->GetObjectField(whisper, fid);
  const char *p = env->GetStringUTFChars(s, nullptr);
  ans.whisper.encoder = p;
  env->ReleaseStringUTFChars(s, p);

  fid = env->GetFieldID(whisper_cls, "decoder", "Ljava/lang/String;");
  s = (jstring)env->GetObjectField(whisper, fid);
  p = env->GetStringUTFChars(s, nullptr);
  ans.whisper.decoder = p;
  env->ReleaseStringUTFChars(s, p);

  fid = env->GetFieldID(whisper_cls, "tailPaddings", "I");
  ans.whisper.tail_paddings = env->GetIntField(whisper, fid);

  fid = env->GetFieldID(cls, "numThreads", "I");
  ans.num_threads = env->GetIntField(config, fid);

  fid = env->GetFieldID(cls, "debug", "Z");
  ans.debug = env->GetBooleanField(config, fid);

  fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
  s = (jstring)env->GetObjectField(config, fid);
  p = env->GetStringUTFChars(s, nullptr);
  ans.provider = p;
  env->ReleaseStringUTFChars(s, p);

  return ans;
}

}  // namespace sherpa_onnx

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  // Translate the Java-side config object into its C++ counterpart and
  // print it so the effective settings appear in the log.
  auto config =
      sherpa_onnx::GetSpokenLanguageIdentificationConfig(env, _config);
  SHERPA_ONNX_LOGE("SpokenLanguageIdentification newFromFile config:\n%s",
                   config.ToString().c_str());

  // A handle of 0 signals to the caller that no instance was created.
  if (!config.Validate()) {
    SHERPA_ONNX_LOGE("Errors found in config!");
    return 0;
  }

  // The heap object is owned by the Java side from here on; the returned
  // handle is the raw pointer value.
  auto *slid = new sherpa_onnx::SpokenLanguageIdentification(config);
  return reinterpret_cast<jlong>(slid);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_createStream(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *identifier =
      reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);

  // Ownership of the new stream passes to the caller, who must free the
  // returned handle; see Java_com_k2fsa_sherpa_onnx_OfflineStream_delete()
  // in ./offline-stream.cc.
  sherpa_onnx::OfflineStream *stream = identifier->CreateStream().release();
  return reinterpret_cast<jlong>(stream);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_compute(JNIEnv *env,
                                                                jobject /*obj*/,
                                                                jlong ptr,
                                                                jlong s_ptr) {
  // Both handles are raw pointers previously handed out to the Java side.
  auto *identifier =
      reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
  auto *stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(s_ptr);

  // The temporary result string lives until the end of this full-expression,
  // so its c_str() stays valid while NewStringUTF copies it.
  return env->NewStringUTF(identifier->Compute(stream).c_str());
}