Fangjun Kuang
Committed by GitHub

Kotlin API for speaker diarization (#1415)

  1 +../sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt
@@ -285,6 +285,37 @@ function testPunctuation() { @@ -285,6 +285,37 @@ function testPunctuation() {
285 java -Djava.library.path=../build/lib -jar $out_filename 285 java -Djava.library.path=../build/lib -jar $out_filename
286 } 286 }
287 287
  288 +function testOfflineSpeakerDiarization() {
  289 + if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then
  290 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  291 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  292 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  293 + fi
  294 +
  295 + if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then
  296 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  297 + fi
  298 +
  299 + if [ ! -f ./0-four-speakers-zh.wav ]; then
  300 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
  301 + fi
  302 +
  303 + out_filename=test_offline_speaker_diarization.jar
  304 + kotlinc-jvm -include-runtime -d $out_filename \
  305 + test_offline_speaker_diarization.kt \
  306 + OfflineSpeakerDiarization.kt \
  307 + Speaker.kt \
  308 + OnlineStream.kt \
  309 + WaveReader.kt \
  310 + faked-asset-manager.kt \
  311 + faked-log.kt
  312 +
  313 + ls -lh $out_filename
  314 +
  315 + java -Djava.library.path=../build/lib -jar $out_filename
  316 +}
  317 +
  318 +testOfflineSpeakerDiarization
288 testSpeakerEmbeddingExtractor 319 testSpeakerEmbeddingExtractor
289 testOnlineAsr 320 testOnlineAsr
290 testTts 321 testTts
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +fun main() {
  4 + testOfflineSpeakerDiarization()
  5 +}
  6 +
  7 +fun callback(numProcessedChunks: Int, numTotalChunks: Int, arg: Long): Int {
  8 + val progress = numProcessedChunks.toFloat() / numTotalChunks * 100
  9 + val s = "%.2f".format(progress)
  10 + println("Progress: ${s}%");
  11 +
  12 + return 0
  13 +}
  14 +
  15 +fun testOfflineSpeakerDiarization() {
  16 + var config = OfflineSpeakerDiarizationConfig(
  17 + segmentation=OfflineSpeakerSegmentationModelConfig(
  18 + pyannote=OfflineSpeakerSegmentationPyannoteModelConfig("./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"),
  19 + ),
  20 + embedding=SpeakerEmbeddingExtractorConfig(
  21 + model="./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx",
  22 + ),
  23 +
  24 + // The test wave file ./0-four-speakers-zh.wav contains four speakers, so
  25 + // we use numClusters=4 here. If you don't know the number of speakers
  26 + // in the test wave file, please set the threshold like below.
  27 + //
  28 + // clustering=FastClusteringConfig(threshold=0.5),
  29 + //
  30 + // WARNING: You need to tune threshold by yourself.
  31 + // A larger threshold leads to fewer clusters, i.e., few speakers.
  32 + // A smaller threshold leads to more clusters, i.e., more speakers.
  33 + //
  34 + clustering=FastClusteringConfig(numClusters=4),
  35 + )
  36 +
  37 + val sd = OfflineSpeakerDiarization(config=config)
  38 +
  39 + val waveData = WaveReader.readWave(
  40 + filename = "./0-four-speakers-zh.wav",
  41 + )
  42 +
  43 + if (sd.sampleRate() != waveData.sampleRate) {
  44 + println("Expected sample rate: ${sd.sampleRate()}, given: ${waveData.sampleRate}")
  45 + return
  46 + }
  47 +
  48 + // val segments = sd.process(waveData.samples) // this one is also ok
  49 + val segments = sd.processWithCallback(waveData.samples, callback=::callback)
  50 + for (segment in segments) {
  51 + println("${segment.start} -- ${segment.end} speaker_${segment.speaker}")
  52 + }
  53 +}
@@ -58,7 +58,7 @@ class OfflineSpeakerDiarizationResult { @@ -58,7 +58,7 @@ class OfflineSpeakerDiarizationResult {
58 std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker() 58 std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
59 const; 59 const;
60 60
61 - public: 61 + private:
62 std::vector<OfflineSpeakerDiarizationSegment> segments_; 62 std::vector<OfflineSpeakerDiarizationSegment> segments_;
63 }; 63 };
64 64
@@ -33,6 +33,12 @@ if(SHERPA_ONNX_ENABLE_TTS) @@ -33,6 +33,12 @@ if(SHERPA_ONNX_ENABLE_TTS)
33 ) 33 )
34 endif() 34 endif()
35 35
  36 +if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  37 + list(APPEND sources
  38 + offline-speaker-diarization.cc
  39 + )
  40 +endif()
  41 +
36 add_library(sherpa-onnx-jni ${sources}) 42 add_library(sherpa-onnx-jni ${sources})
37 43
38 target_compile_definitions(sherpa-onnx-jni PRIVATE SHERPA_ONNX_BUILD_SHARED_LIBS=1) 44 target_compile_definitions(sherpa-onnx-jni PRIVATE SHERPA_ONNX_BUILD_SHARED_LIBS=1)
  1 +// sherpa-onnx/jni/offline-speaker-diarization.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/jni/common.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +static OfflineSpeakerDiarizationConfig GetOfflineSpeakerDiarizationConfig(
  13 + JNIEnv *env, jobject config) {
  14 + OfflineSpeakerDiarizationConfig ans;
  15 +
  16 + jclass cls = env->GetObjectClass(config);
  17 + jfieldID fid;
  18 +
  19 + //---------- segmentation ----------
  20 + fid = env->GetFieldID(
  21 + cls, "segmentation",
  22 + "Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;");
  23 + jobject segmentation_config = env->GetObjectField(config, fid);
  24 + jclass segmentation_config_cls = env->GetObjectClass(segmentation_config);
  25 +
  26 + fid = env->GetFieldID(
  27 + segmentation_config_cls, "pyannote",
  28 + "Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;");
  29 + jobject pyannote_config = env->GetObjectField(segmentation_config, fid);
  30 + jclass pyannote_config_cls = env->GetObjectClass(pyannote_config);
  31 +
  32 + fid = env->GetFieldID(pyannote_config_cls, "model", "Ljava/lang/String;");
  33 + jstring s = (jstring)env->GetObjectField(pyannote_config, fid);
  34 + const char *p = env->GetStringUTFChars(s, nullptr);
  35 + ans.segmentation.pyannote.model = p;
  36 + env->ReleaseStringUTFChars(s, p);
  37 +
  38 + fid = env->GetFieldID(segmentation_config_cls, "numThreads", "I");
  39 + ans.segmentation.num_threads = env->GetIntField(segmentation_config, fid);
  40 +
  41 + fid = env->GetFieldID(segmentation_config_cls, "debug", "Z");
  42 + ans.segmentation.debug = env->GetBooleanField(segmentation_config, fid);
  43 +
  44 + fid = env->GetFieldID(segmentation_config_cls, "provider",
  45 + "Ljava/lang/String;");
  46 + s = (jstring)env->GetObjectField(segmentation_config, fid);
  47 + p = env->GetStringUTFChars(s, nullptr);
  48 + ans.segmentation.provider = p;
  49 + env->ReleaseStringUTFChars(s, p);
  50 +
  51 + //---------- embedding ----------
  52 + fid = env->GetFieldID(
  53 + cls, "embedding",
  54 + "Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;");
  55 + jobject embedding_config = env->GetObjectField(config, fid);
  56 + jclass embedding_config_cls = env->GetObjectClass(embedding_config);
  57 +
  58 + fid = env->GetFieldID(embedding_config_cls, "model", "Ljava/lang/String;");
  59 + s = (jstring)env->GetObjectField(embedding_config, fid);
  60 + p = env->GetStringUTFChars(s, nullptr);
  61 + ans.embedding.model = p;
  62 + env->ReleaseStringUTFChars(s, p);
  63 +
  64 + fid = env->GetFieldID(embedding_config_cls, "numThreads", "I");
  65 + ans.embedding.num_threads = env->GetIntField(embedding_config, fid);
  66 +
  67 + fid = env->GetFieldID(embedding_config_cls, "debug", "Z");
  68 + ans.embedding.debug = env->GetBooleanField(embedding_config, fid);
  69 +
  70 + fid = env->GetFieldID(embedding_config_cls, "provider", "Ljava/lang/String;");
  71 + s = (jstring)env->GetObjectField(embedding_config, fid);
  72 + p = env->GetStringUTFChars(s, nullptr);
  73 + ans.embedding.provider = p;
  74 + env->ReleaseStringUTFChars(s, p);
  75 +
  76 + //---------- clustering ----------
  77 + fid = env->GetFieldID(cls, "clustering",
  78 + "Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;");
  79 + jobject clustering_config = env->GetObjectField(config, fid);
  80 + jclass clustering_config_cls = env->GetObjectClass(clustering_config);
  81 +
  82 + fid = env->GetFieldID(clustering_config_cls, "numClusters", "I");
  83 + ans.clustering.num_clusters = env->GetIntField(clustering_config, fid);
  84 +
  85 + fid = env->GetFieldID(clustering_config_cls, "threshold", "F");
  86 + ans.clustering.threshold = env->GetFloatField(clustering_config, fid);
  87 +
  88 + // its own fields
  89 + fid = env->GetFieldID(cls, "minDurationOn", "F");
  90 + ans.min_duration_on = env->GetFloatField(config, fid);
  91 +
  92 + fid = env->GetFieldID(cls, "minDurationOff", "F");
  93 + ans.min_duration_off = env->GetFloatField(config, fid);
  94 +
  95 + return ans;
  96 +}
  97 +
  98 +} // namespace sherpa_onnx
  99 +
  100 +SHERPA_ONNX_EXTERN_C
  101 +JNIEXPORT jlong JNICALL
  102 +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
  103 + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  104 + return 0;
  105 +}
  106 +
  107 +SHERPA_ONNX_EXTERN_C
  108 +JNIEXPORT jlong JNICALL
  109 +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromFile(
  110 + JNIEnv *env, jobject /*obj*/, jobject _config) {
  111 + auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
  112 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  113 +
  114 + if (!config.Validate()) {
  115 + SHERPA_ONNX_LOGE("Errors found in config!");
  116 + return 0;
  117 + }
  118 +
  119 + auto sd = new sherpa_onnx::OfflineSpeakerDiarization(config);
  120 +
  121 + return (jlong)sd;
  122 +}
  123 +
  124 +SHERPA_ONNX_EXTERN_C
  125 +JNIEXPORT void JNICALL
  126 +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_setConfig(
  127 + JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
  128 + auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
  129 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  130 +
  131 + auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
  132 + sd->SetConfig(config);
  133 +}
  134 +
  135 +SHERPA_ONNX_EXTERN_C
  136 +JNIEXPORT void JNICALL
  137 +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/,
  138 + jobject /*obj*/,
  139 + jlong ptr) {
  140 + delete reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
  141 +}
  142 +
  143 +static jobjectArray ProcessImpl(
  144 + JNIEnv *env,
  145 + const std::vector<sherpa_onnx::OfflineSpeakerDiarizationSegment>
  146 + &segments) {
  147 + jclass cls =
  148 + env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment");
  149 +
  150 + jobjectArray obj_arr =
  151 + (jobjectArray)env->NewObjectArray(segments.size(), cls, nullptr);
  152 +
  153 + jmethodID constructor = env->GetMethodID(cls, "<init>", "(FFI)V");
  154 +
  155 + for (int32_t i = 0; i != segments.size(); ++i) {
  156 + const auto &s = segments[i];
  157 + jobject segment =
  158 + env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker());
  159 + env->SetObjectArrayElement(obj_arr, i, segment);
  160 + }
  161 +
  162 + return obj_arr;
  163 +}
  164 +
  165 +SHERPA_ONNX_EXTERN_C
  166 +JNIEXPORT jobjectArray JNICALL
  167 +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_process(
  168 + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
  169 + auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
  170 +
  171 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  172 + jsize n = env->GetArrayLength(samples);
  173 + auto segments = sd->Process(p, n).SortByStartTime();
  174 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  175 +
  176 + return ProcessImpl(env, segments);
  177 +}
  178 +
  179 +SHERPA_ONNX_EXTERN_C
  180 +JNIEXPORT jobjectArray JNICALL
  181 +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_processWithCallback(
  182 + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
  183 + jobject callback, jlong arg) {
  184 + std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
  185 + [env, callback](int32_t num_processed_chunks, int32_t num_total_chunks,
  186 + void *data) -> int {
  187 + jclass cls = env->GetObjectClass(callback);
  188 +
  189 + jmethodID mid = env->GetMethodID(cls, "invoke", "(IIJ)Ljava/lang/Integer;");
  190 + if (mid == nullptr) {
  191 + SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
  192 + return 0;
  193 + }
  194 +
  195 + jobject ret = env->CallObjectMethod(callback, mid, num_processed_chunks,
  196 + num_total_chunks, (jlong)data);
  197 + jclass jklass = env->GetObjectClass(ret);
  198 + jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I");
  199 + return env->CallIntMethod(ret, int_value_mid);
  200 + };
  201 +
  202 + auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
  203 +
  204 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  205 + jsize n = env->GetArrayLength(samples);
  206 + auto segments =
  207 + sd->Process(p, n, callback_wrapper, (void *)arg).SortByStartTime();
  208 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  209 +
  210 + return ProcessImpl(env, segments);
  211 +}
  212 +
  213 +SHERPA_ONNX_EXTERN_C
  214 +JNIEXPORT jint JNICALL
  215 +Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_getSampleRate(
  216 + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  217 + return reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr)
  218 + ->SampleRate();
  219 +}
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import android.content.res.AssetManager
  4 +
  5 +data class OfflineSpeakerSegmentationPyannoteModelConfig(
  6 + var model: String,
  7 +)
  8 +
  9 +data class OfflineSpeakerSegmentationModelConfig(
  10 + var pyannote: OfflineSpeakerSegmentationPyannoteModelConfig,
  11 + var numThreads: Int = 1,
  12 + var debug: Boolean = false,
  13 + var provider: String = "cpu",
  14 +)
  15 +
  16 +data class FastClusteringConfig(
  17 + var numClusters: Int = -1,
  18 + var threshold: Float = 0.5f,
  19 +)
  20 +
  21 +data class OfflineSpeakerDiarizationConfig(
  22 + var segmentation: OfflineSpeakerSegmentationModelConfig,
  23 + var embedding: SpeakerEmbeddingExtractorConfig,
  24 + var clustering: FastClusteringConfig,
  25 + var minDurationOn: Float = 0.2f,
  26 + var minDurationOff: Float = 0.5f,
  27 +)
  28 +
  29 +data class OfflineSpeakerDiarizationSegment(
  30 + val start: Float, // in seconds
  31 + val end: Float, // in seconds
  32 + val speaker: Int, // ID of the speaker; count from 0
  33 +)
  34 +
  35 +class OfflineSpeakerDiarization(
  36 + assetManager: AssetManager? = null,
  37 + config: OfflineSpeakerDiarizationConfig,
  38 +) {
  39 + private var ptr: Long
  40 +
  41 + init {
  42 + ptr = if (assetManager != null) {
  43 + newFromAsset(assetManager, config)
  44 + } else {
  45 + newFromFile(config)
  46 + }
  47 + }
  48 +
  49 + protected fun finalize() {
  50 + if (ptr != 0L) {
  51 + delete(ptr)
  52 + ptr = 0
  53 + }
  54 + }
  55 +
  56 + fun release() = finalize()
  57 +
  58 + // Only config.clustering is used. All other fields in config
  59 + // are ignored
  60 + fun setConfig(config: OfflineSpeakerDiarizationConfig) = setConfig(ptr, config)
  61 +
  62 + fun sampleRate() = getSampleRate(ptr)
  63 +
  64 + fun process(samples: FloatArray) = process(ptr, samples)
  65 +
  66 + fun processWithCallback(
  67 + samples: FloatArray,
  68 + callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int,
  69 + arg: Long = 0,
  70 + ) = processWithCallback(ptr, samples, callback, arg)
  71 +
  72 + private external fun delete(ptr: Long)
  73 +
  74 + private external fun newFromAsset(
  75 + assetManager: AssetManager,
  76 + config: OfflineSpeakerDiarizationConfig,
  77 + ): Long
  78 +
  79 + private external fun newFromFile(
  80 + config: OfflineSpeakerDiarizationConfig,
  81 + ): Long
  82 +
  83 + private external fun setConfig(ptr: Long, config: OfflineSpeakerDiarizationConfig)
  84 +
  85 + private external fun getSampleRate(ptr: Long): Int
  86 +
  87 + private external fun process(ptr: Long, samples: FloatArray): Array<OfflineSpeakerDiarizationSegment>
  88 +
  89 + private external fun processWithCallback(
  90 + ptr: Long,
  91 + samples: FloatArray,
  92 + callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int,
  93 + arg: Long,
  94 + ): Array<OfflineSpeakerDiarizationSegment>
  95 +
  96 + companion object {
  97 + init {
  98 + System.loadLibrary("sherpa-onnx-jni")
  99 + }
  100 + }
  101 +}