Fangjun Kuang
Committed by GitHub

Android JNI support for speaker diarization (#1421)

@@ -23,4 +23,18 @@ OfflineSpeakerDiarizationImpl::Create( @@ -23,4 +23,18 @@ OfflineSpeakerDiarizationImpl::Create(
23 return nullptr; 23 return nullptr;
24 } 24 }
25 25
  26 +#if __ANDROID_API__ >= 9
  27 +std::unique_ptr<OfflineSpeakerDiarizationImpl>
  28 +OfflineSpeakerDiarizationImpl::Create(
  29 + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) {
  30 + if (!config.segmentation.pyannote.model.empty()) {
  31 + return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(mgr, config);
  32 + }
  33 +
  34 + SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");
  35 +
  36 + return nullptr;
  37 +}
  38 +#endif
  39 +
26 } // namespace sherpa_onnx 40 } // namespace sherpa_onnx
@@ -8,6 +8,11 @@ @@ -8,6 +8,11 @@
8 #include <functional> 8 #include <functional>
9 #include <memory> 9 #include <memory>
10 10
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
11 #include "sherpa-onnx/csrc/offline-speaker-diarization.h" 16 #include "sherpa-onnx/csrc/offline-speaker-diarization.h"
12 namespace sherpa_onnx { 17 namespace sherpa_onnx {
13 18
@@ -16,6 +21,11 @@ class OfflineSpeakerDiarizationImpl { @@ -16,6 +21,11 @@ class OfflineSpeakerDiarizationImpl {
16 static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create( 21 static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
17 const OfflineSpeakerDiarizationConfig &config); 22 const OfflineSpeakerDiarizationConfig &config);
18 23
  24 +#if __ANDROID_API__ >= 9
  25 + static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
  26 + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config);
  27 +#endif
  28 +
19 virtual ~OfflineSpeakerDiarizationImpl() = default; 29 virtual ~OfflineSpeakerDiarizationImpl() = default;
20 30
21 virtual int32_t SampleRate() const = 0; 31 virtual int32_t SampleRate() const = 0;
@@ -10,6 +10,11 @@ @@ -10,6 +10,11 @@
10 #include <utility> 10 #include <utility>
11 #include <vector> 11 #include <vector>
12 12
  13 +#if __ANDROID_API__ >= 9
  14 +#include "android/asset_manager.h"
  15 +#include "android/asset_manager_jni.h"
  16 +#endif
  17 +
13 #include "Eigen/Dense" 18 #include "Eigen/Dense"
14 #include "sherpa-onnx/csrc/fast-clustering.h" 19 #include "sherpa-onnx/csrc/fast-clustering.h"
15 #include "sherpa-onnx/csrc/math.h" 20 #include "sherpa-onnx/csrc/math.h"
@@ -65,6 +70,17 @@ class OfflineSpeakerDiarizationPyannoteImpl @@ -65,6 +70,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
65 Init(); 70 Init();
66 } 71 }
67 72
  73 +#if __ANDROID_API__ >= 9
  74 + OfflineSpeakerDiarizationPyannoteImpl(
  75 + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config)
  76 + : config_(config),
  77 + segmentation_model_(mgr, config_.segmentation),
  78 + embedding_extractor_(mgr, config_.embedding),
  79 + clustering_(std::make_unique<FastClustering>(config_.clustering)) {
  80 + Init();
  81 + }
  82 +#endif
  83 +
68 int32_t SampleRate() const override { 84 int32_t SampleRate() const override {
69 const auto &meta_data = segmentation_model_.GetModelMetaData(); 85 const auto &meta_data = segmentation_model_.GetModelMetaData();
70 86
@@ -73,6 +73,12 @@ OfflineSpeakerDiarization::OfflineSpeakerDiarization( @@ -73,6 +73,12 @@ OfflineSpeakerDiarization::OfflineSpeakerDiarization(
73 const OfflineSpeakerDiarizationConfig &config) 73 const OfflineSpeakerDiarizationConfig &config)
74 : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {} 74 : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}
75 75
  76 +#if __ANDROID_API__ >= 9
  77 +OfflineSpeakerDiarization::OfflineSpeakerDiarization(
  78 + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config)
  79 + : impl_(OfflineSpeakerDiarizationImpl::Create(mgr, config)) {}
  80 +#endif
  81 +
76 OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; 82 OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;
77 83
78 int32_t OfflineSpeakerDiarization::SampleRate() const { 84 int32_t OfflineSpeakerDiarization::SampleRate() const {
@@ -9,6 +9,11 @@ @@ -9,6 +9,11 @@
9 #include <memory> 9 #include <memory>
10 #include <string> 10 #include <string>
11 11
  12 +#if __ANDROID_API__ >= 9
  13 +#include "android/asset_manager.h"
  14 +#include "android/asset_manager_jni.h"
  15 +#endif
  16 +
12 #include "sherpa-onnx/csrc/fast-clustering-config.h" 17 #include "sherpa-onnx/csrc/fast-clustering-config.h"
13 #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" 18 #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
14 #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" 19 #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
@@ -57,6 +62,11 @@ class OfflineSpeakerDiarization { @@ -57,6 +62,11 @@ class OfflineSpeakerDiarization {
57 explicit OfflineSpeakerDiarization( 62 explicit OfflineSpeakerDiarization(
58 const OfflineSpeakerDiarizationConfig &config); 63 const OfflineSpeakerDiarizationConfig &config);
59 64
  65 +#if __ANDROID_API__ >= 9
  66 + OfflineSpeakerDiarization(AAssetManager *mgr,
  67 + const OfflineSpeakerDiarizationConfig &config);
  68 +#endif
  69 +
60 ~OfflineSpeakerDiarization(); 70 ~OfflineSpeakerDiarization();
61 71
62 // Expected sample rate of the input audio samples 72 // Expected sample rate of the input audio samples
@@ -24,6 +24,17 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl { @@ -24,6 +24,17 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl {
24 Init(buf.data(), buf.size()); 24 Init(buf.data(), buf.size());
25 } 25 }
26 26
  27 +#if __ANDROID_API__ >= 9
  28 + Impl(AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
  29 + : config_(config),
  30 + env_(ORT_LOGGING_LEVEL_ERROR),
  31 + sess_opts_(GetSessionOptions(config)),
  32 + allocator_{} {
  33 + auto buf = ReadFile(mgr, config_.pyannote.model);
  34 + Init(buf.data(), buf.size());
  35 + }
  36 +#endif
  37 +
27 const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() 38 const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
28 const { 39 const {
29 return meta_data_; 40 return meta_data_;
@@ -92,6 +103,13 @@ OfflineSpeakerSegmentationPyannoteModel:: @@ -92,6 +103,13 @@ OfflineSpeakerSegmentationPyannoteModel::
92 const OfflineSpeakerSegmentationModelConfig &config) 103 const OfflineSpeakerSegmentationModelConfig &config)
93 : impl_(std::make_unique<Impl>(config)) {} 104 : impl_(std::make_unique<Impl>(config)) {}
94 105
  106 +#if __ANDROID_API__ >= 9
  107 +OfflineSpeakerSegmentationPyannoteModel::
  108 + OfflineSpeakerSegmentationPyannoteModel(
  109 + AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
  110 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  111 +#endif
  112 +
95 OfflineSpeakerSegmentationPyannoteModel:: 113 OfflineSpeakerSegmentationPyannoteModel::
96 ~OfflineSpeakerSegmentationPyannoteModel() = default; 114 ~OfflineSpeakerSegmentationPyannoteModel() = default;
97 115
@@ -6,6 +6,11 @@ @@ -6,6 +6,11 @@
6 6
7 #include <memory> 7 #include <memory>
8 8
  9 +#if __ANDROID_API__ >= 9
  10 +#include "android/asset_manager.h"
  11 +#include "android/asset_manager_jni.h"
  12 +#endif
  13 +
9 #include "onnxruntime_cxx_api.h" // NOLINT 14 #include "onnxruntime_cxx_api.h" // NOLINT
10 #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" 15 #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
11 #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" 16 #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
@@ -17,6 +22,11 @@ class OfflineSpeakerSegmentationPyannoteModel { @@ -17,6 +22,11 @@ class OfflineSpeakerSegmentationPyannoteModel {
17 explicit OfflineSpeakerSegmentationPyannoteModel( 22 explicit OfflineSpeakerSegmentationPyannoteModel(
18 const OfflineSpeakerSegmentationModelConfig &config); 23 const OfflineSpeakerSegmentationModelConfig &config);
19 24
  25 +#if __ANDROID_API__ >= 9
  26 + OfflineSpeakerSegmentationPyannoteModel(
  27 + AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config);
  28 +#endif
  29 +
20 ~OfflineSpeakerSegmentationPyannoteModel(); 30 ~OfflineSpeakerSegmentationPyannoteModel();
21 31
22 const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() 32 const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
@@ -211,7 +211,7 @@ to download models for offline ASR. @@ -211,7 +211,7 @@ to download models for offline ASR.
211 } 211 }
212 212
213 while (!vad->Empty()) { 213 while (!vad->Empty()) {
214 - auto &segment = vad->Front(); 214 + const auto &segment = vad->Front();
215 auto s = recognizer.CreateStream(); 215 auto s = recognizer.CreateStream();
216 s->AcceptWaveform(sample_rate, segment.samples.data(), 216 s->AcceptWaveform(sample_rate, segment.samples.data(),
217 segment.samples.size()); 217 segment.samples.size());
@@ -70,6 +70,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromAsset( @@ -70,6 +70,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromAsset(
70 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 70 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
71 if (!mgr) { 71 if (!mgr) {
72 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 72 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  73 + return 0;
73 } 74 }
74 #endif 75 #endif
75 76
@@ -115,10 +115,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset( @@ -115,10 +115,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset(
115 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 115 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
116 if (!mgr) { 116 if (!mgr) {
117 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 117 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  118 + return 0;
118 } 119 }
119 #endif 120 #endif
120 auto config = sherpa_onnx::GetKwsConfig(env, _config); 121 auto config = sherpa_onnx::GetKwsConfig(env, _config);
121 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); 122 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  123 +
122 auto kws = new sherpa_onnx::KeywordSpotter( 124 auto kws = new sherpa_onnx::KeywordSpotter(
123 #if __ANDROID_API__ >= 9 125 #if __ANDROID_API__ >= 9
124 mgr, 126 mgr,
@@ -53,10 +53,12 @@ Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromAsset( @@ -53,10 +53,12 @@ Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromAsset(
53 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 53 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
54 if (!mgr) { 54 if (!mgr) {
55 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 55 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  56 + return 0;
56 } 57 }
57 #endif 58 #endif
58 auto config = sherpa_onnx::GetOfflinePunctuationConfig(env, _config); 59 auto config = sherpa_onnx::GetOfflinePunctuationConfig(env, _config);
59 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); 60 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  61 +
60 auto model = new sherpa_onnx::OfflinePunctuation( 62 auto model = new sherpa_onnx::OfflinePunctuation(
61 #if __ANDROID_API__ >= 9 63 #if __ANDROID_API__ >= 9
62 mgr, 64 mgr,
@@ -233,10 +233,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env, @@ -233,10 +233,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env,
233 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 233 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
234 if (!mgr) { 234 if (!mgr) {
235 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 235 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  236 + return 0;
236 } 237 }
237 #endif 238 #endif
238 auto config = sherpa_onnx::GetOfflineConfig(env, _config); 239 auto config = sherpa_onnx::GetOfflineConfig(env, _config);
239 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); 240 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  241 +
240 auto model = new sherpa_onnx::OfflineRecognizer( 242 auto model = new sherpa_onnx::OfflineRecognizer(
241 #if __ANDROID_API__ >= 9 243 #if __ANDROID_API__ >= 9
242 mgr, 244 mgr,
@@ -101,7 +101,24 @@ SHERPA_ONNX_EXTERN_C @@ -101,7 +101,24 @@ SHERPA_ONNX_EXTERN_C
101 JNIEXPORT jlong JNICALL 101 JNIEXPORT jlong JNICALL
102 Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset( 102 Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
103 JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { 103 JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
104 - return 0; 104 +#if __ANDROID_API__ >= 9
  105 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  106 + if (!mgr) {
  107 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  108 + return 0;
  109 + }
  110 +#endif
  111 +
  112 + auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
  113 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  114 +
  115 + auto sd = new sherpa_onnx::OfflineSpeakerDiarization(
  116 +#if __ANDROID_API__ >= 9
  117 + mgr,
  118 +#endif
  119 + config);
  120 +
  121 + return (jlong)sd;
105 } 122 }
106 123
107 SHERPA_ONNX_EXTERN_C 124 SHERPA_ONNX_EXTERN_C
@@ -105,6 +105,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset( @@ -105,6 +105,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset(
105 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 105 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
106 if (!mgr) { 106 if (!mgr) {
107 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 107 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  108 + return 0;
108 } 109 }
109 #endif 110 #endif
110 auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); 111 auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
@@ -267,6 +267,7 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env, @@ -267,6 +267,7 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env,
267 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 267 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
268 if (!mgr) { 268 if (!mgr) {
269 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 269 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  270 + return 0;
270 } 271 }
271 #endif 272 #endif
272 auto config = sherpa_onnx::GetConfig(env, _config); 273 auto config = sherpa_onnx::GetConfig(env, _config);
@@ -45,6 +45,7 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset( @@ -45,6 +45,7 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset(
45 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 45 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
46 if (!mgr) { 46 if (!mgr) {
47 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 47 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  48 + return 0;
48 } 49 }
49 #endif 50 #endif
50 auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config); 51 auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
@@ -62,6 +62,7 @@ Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromAsset( @@ -62,6 +62,7 @@ Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromAsset(
62 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 62 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
63 if (!mgr) { 63 if (!mgr) {
64 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 64 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  65 + return 0;
65 } 66 }
66 #endif 67 #endif
67 68
@@ -71,10 +71,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset( @@ -71,10 +71,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset(
71 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); 71 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
72 if (!mgr) { 72 if (!mgr) {
73 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); 73 SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  74 + return 0;
74 } 75 }
75 #endif 76 #endif
76 auto config = sherpa_onnx::GetVadModelConfig(env, _config); 77 auto config = sherpa_onnx::GetVadModelConfig(env, _config);
77 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); 78 SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  79 +
78 auto model = new sherpa_onnx::VoiceActivityDetector( 80 auto model = new sherpa_onnx::VoiceActivityDetector(
79 #if __ANDROID_API__ >= 9 81 #if __ANDROID_API__ >= 9
80 mgr, 82 mgr,