Fangjun Kuang
Committed by GitHub

Android JNI support for speaker diarization (#1421)

... ... @@ -23,4 +23,18 @@ OfflineSpeakerDiarizationImpl::Create(
return nullptr;
}
#if __ANDROID_API__ >= 9
std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) {
if (!config.segmentation.pyannote.model.empty()) {
return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(mgr, config);
}
SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");
return nullptr;
}
#endif
} // namespace sherpa_onnx
... ...
... ... @@ -8,6 +8,11 @@
#include <functional>
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
namespace sherpa_onnx {
... ... @@ -16,6 +21,11 @@ class OfflineSpeakerDiarizationImpl {
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
const OfflineSpeakerDiarizationConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config);
#endif
virtual ~OfflineSpeakerDiarizationImpl() = default;
virtual int32_t SampleRate() const = 0;
... ...
... ... @@ -10,6 +10,11 @@
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "Eigen/Dense"
#include "sherpa-onnx/csrc/fast-clustering.h"
#include "sherpa-onnx/csrc/math.h"
... ... @@ -65,6 +70,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
Init();
}
#if __ANDROID_API__ >= 9
OfflineSpeakerDiarizationPyannoteImpl(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config)
: config_(config),
segmentation_model_(mgr, config_.segmentation),
embedding_extractor_(mgr, config_.embedding),
clustering_(std::make_unique<FastClustering>(config_.clustering)) {
Init();
}
#endif
int32_t SampleRate() const override {
const auto &meta_data = segmentation_model_.GetModelMetaData();
... ...
... ... @@ -73,6 +73,12 @@ OfflineSpeakerDiarization::OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}
#if __ANDROID_API__ >= 9
OfflineSpeakerDiarization::OfflineSpeakerDiarization(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(mgr, config)) {}
#endif
OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;
int32_t OfflineSpeakerDiarization::SampleRate() const {
... ...
... ... @@ -9,6 +9,11 @@
#include <memory>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/fast-clustering-config.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
... ... @@ -57,6 +62,11 @@ class OfflineSpeakerDiarization {
explicit OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config);
#if __ANDROID_API__ >= 9
OfflineSpeakerDiarization(AAssetManager *mgr,
const OfflineSpeakerDiarizationConfig &config);
#endif
~OfflineSpeakerDiarization();
// Expected sample rate of the input audio samples
... ...
... ... @@ -24,6 +24,17 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl {
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.pyannote.model);
Init(buf.data(), buf.size());
}
#endif
const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
const {
return meta_data_;
... ... @@ -92,6 +103,13 @@ OfflineSpeakerSegmentationPyannoteModel::
const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineSpeakerSegmentationPyannoteModel::
OfflineSpeakerSegmentationPyannoteModel(
AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineSpeakerSegmentationPyannoteModel::
~OfflineSpeakerSegmentationPyannoteModel() = default;
... ...
... ... @@ -6,6 +6,11 @@
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
... ... @@ -17,6 +22,11 @@ class OfflineSpeakerSegmentationPyannoteModel {
explicit OfflineSpeakerSegmentationPyannoteModel(
const OfflineSpeakerSegmentationModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineSpeakerSegmentationPyannoteModel(
AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config);
#endif
~OfflineSpeakerSegmentationPyannoteModel();
const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
... ...
... ... @@ -211,7 +211,7 @@ to download models for offline ASR.
}
while (!vad->Empty()) {
auto &segment = vad->Front();
const auto &segment = vad->Front();
auto s = recognizer.CreateStream();
s->AcceptWaveform(sample_rate, segment.samples.data(),
segment.samples.size());
... ...
... ... @@ -70,6 +70,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
... ...
... ... @@ -115,10 +115,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetKwsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto kws = new sherpa_onnx::KeywordSpotter(
#if __ANDROID_API__ >= 9
mgr,
... ...
... ... @@ -53,10 +53,12 @@ Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflinePunctuationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::OfflinePunctuation(
#if __ANDROID_API__ >= 9
mgr,
... ...
... ... @@ -233,10 +233,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env,
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::OfflineRecognizer(
#if __ANDROID_API__ >= 9
mgr,
... ...
... ... @@ -101,7 +101,24 @@ SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto sd = new sherpa_onnx::OfflineSpeakerDiarization(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)sd;
}
SHERPA_ONNX_EXTERN_C
... ...
... ... @@ -105,6 +105,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
... ...
... ... @@ -267,6 +267,7 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env,
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetConfig(env, _config);
... ...
... ... @@ -45,6 +45,7 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
... ...
... ... @@ -62,6 +62,7 @@ Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
... ...
... ... @@ -71,10 +71,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset(
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetVadModelConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::VoiceActivityDetector(
#if __ANDROID_API__ >= 9
mgr,
... ...