Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-10-11 14:41:53 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-10-11 14:41:53 +0800
Commit
2d412b1190778bc35f337ef1feeb12292b5c9f92
2d412b11
1 parent
eefc1720
Kotlin API for speaker diarization (#1415)
隐藏空白字符变更
内嵌
并排对比
正在显示
7 个修改的文件
包含
412 行增加
和
1 行删除
kotlin-api-examples/OfflineSpeakerDiarization.kt
kotlin-api-examples/run.sh
kotlin-api-examples/test_offline_speaker_diarization.kt
sherpa-onnx/csrc/offline-speaker-diarization-result.h
sherpa-onnx/jni/CMakeLists.txt
sherpa-onnx/jni/offline-speaker-diarization.cc
sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt
kotlin-api-examples/OfflineSpeakerDiarization.kt
0 → 120000
查看文件 @
2d412b1
../sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt
\ No newline at end of file
...
...
kotlin-api-examples/run.sh
查看文件 @
2d412b1
...
...
@@ -285,6 +285,37 @@ function testPunctuation() {
java -Djava.library.path
=
../build/lib -jar
$out_filename
}
# Builds and runs the Kotlin offline speaker diarization example.
#
# Any of the three required assets (segmentation model, speaker embedding
# model, test wave file) that is missing from the current directory is
# downloaded first. The example is then compiled into a runnable jar and
# executed against the JNI library found in ../build/lib.
function testOfflineSpeakerDiarization() {
  # Speaker segmentation model (shipped as a tarball; removed after unpacking).
  if [[ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]]; then
    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
    tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
    rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  fi

  # Speaker embedding model.
  if [[ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]]; then
    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  fi

  # Test recording.
  if [[ ! -f ./0-four-speakers-zh.wav ]]; then
    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
  fi

  out_filename=test_offline_speaker_diarization.jar

  # Compile the example together with the Kotlin API sources it depends on.
  kotlinc-jvm -include-runtime -d $out_filename \
    test_offline_speaker_diarization.kt \
    OfflineSpeakerDiarization.kt \
    Speaker.kt \
    OnlineStream.kt \
    WaveReader.kt \
    faked-asset-manager.kt \
    faked-log.kt

  ls -lh $out_filename

  java -Djava.library.path=../build/lib -jar $out_filename
}
# Run the example tests one after another; the offline speaker diarization
# example runs first.
testOfflineSpeakerDiarization
testSpeakerEmbeddingExtractor
testOnlineAsr
testTts
...
...
kotlin-api-examples/test_offline_speaker_diarization.kt
0 → 100644
查看文件 @
2d412b1
package com.k2fsa.sherpa.onnx
/** Entry point of the example: runs the offline speaker diarization test. */
fun main() = testOfflineSpeakerDiarization()
/**
 * Progress callback handed to `processWithCallback`.
 *
 * Prints the share of processed chunks as a percentage with two decimals.
 *
 * @param numProcessedChunks number of chunks processed so far
 * @param numTotalChunks total number of chunks
 * @param arg user supplied value; not used by this implementation
 * @return always 0
 */
fun callback(numProcessedChunks: Int, numTotalChunks: Int, arg: Long): Int {
    val percent = numProcessedChunks.toFloat() / numTotalChunks * 100
    val formatted = "%.2f".format(percent)
    println("Progress: ${formatted}%")
    return 0
}
/**
 * Runs speaker diarization on `./0-four-speakers-zh.wav` and prints one line
 * per detected segment (`start -- end speaker_N`).
 *
 * The native diarizer is released explicitly, also on the early-return path
 * taken when the sample rate of the wave file does not match the one the
 * diarizer reports.
 */
fun testOfflineSpeakerDiarization() {
    val config = OfflineSpeakerDiarizationConfig(
        segmentation = OfflineSpeakerSegmentationModelConfig(
            pyannote = OfflineSpeakerSegmentationPyannoteModelConfig("./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"),
        ),
        embedding = SpeakerEmbeddingExtractorConfig(
            model = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx",
        ),

        // The test wave file ./0-four-speakers-zh.wav contains four speakers, so
        // we use numClusters=4 here. If you don't know the number of speakers
        // in the test wave file, please set the threshold like below.
        //
        // clustering=FastClusteringConfig(threshold=0.5),
        //
        // WARNING: You need to tune threshold by yourself.
        // A larger threshold leads to fewer clusters, i.e., fewer speakers.
        // A smaller threshold leads to more clusters, i.e., more speakers.
        //
        clustering = FastClusteringConfig(numClusters = 4),
    )

    val sd = OfflineSpeakerDiarization(config = config)
    try {
        val waveData = WaveReader.readWave(
            filename = "./0-four-speakers-zh.wav",
        )

        if (sd.sampleRate() != waveData.sampleRate) {
            println("Expected sample rate: ${sd.sampleRate()}, given: ${waveData.sampleRate}")
            return
        }

        // val segments = sd.process(waveData.samples) // this one is also ok
        val segments = sd.processWithCallback(waveData.samples, callback = ::callback)
        for (segment in segments) {
            println("${segment.start} -- ${segment.end} speaker_${segment.speaker}")
        }
    } finally {
        // Free the native object deterministically instead of waiting for
        // the finalizer.
        sd.release()
    }
}
...
...
sherpa-onnx/csrc/offline-speaker-diarization-result.h
查看文件 @
2d412b1
...
...
@@ -58,7 +58,7 @@ class OfflineSpeakerDiarizationResult {
std
::
vector
<
std
::
vector
<
OfflineSpeakerDiarizationSegment
>>
SortBySpeaker
()
const
;
p
ublic
:
p
rivate
:
std
::
vector
<
OfflineSpeakerDiarizationSegment
>
segments_
;
};
...
...
sherpa-onnx/jni/CMakeLists.txt
查看文件 @
2d412b1
...
...
@@ -33,6 +33,12 @@ if(SHERPA_ONNX_ENABLE_TTS)
)
endif
()
# Compile the speaker diarization JNI binding only when the feature is enabled.
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  list(APPEND sources
    offline-speaker-diarization.cc
  )
endif()

# The JNI library is built from every source file collected above.
add_library(sherpa-onnx-jni ${sources})

target_compile_definitions(sherpa-onnx-jni PRIVATE SHERPA_ONNX_BUILD_SHARED_LIBS=1)
...
...
sherpa-onnx/jni/offline-speaker-diarization.cc
0 → 100644
查看文件 @
2d412b1
// sherpa-onnx/jni/offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"

#include <string>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace
sherpa_onnx
{
static
OfflineSpeakerDiarizationConfig
GetOfflineSpeakerDiarizationConfig
(
JNIEnv
*
env
,
jobject
config
)
{
OfflineSpeakerDiarizationConfig
ans
;
jclass
cls
=
env
->
GetObjectClass
(
config
);
jfieldID
fid
;
//---------- segmentation ----------
fid
=
env
->
GetFieldID
(
cls
,
"segmentation"
,
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;"
);
jobject
segmentation_config
=
env
->
GetObjectField
(
config
,
fid
);
jclass
segmentation_config_cls
=
env
->
GetObjectClass
(
segmentation_config
);
fid
=
env
->
GetFieldID
(
segmentation_config_cls
,
"pyannote"
,
"Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;"
);
jobject
pyannote_config
=
env
->
GetObjectField
(
segmentation_config
,
fid
);
jclass
pyannote_config_cls
=
env
->
GetObjectClass
(
pyannote_config
);
fid
=
env
->
GetFieldID
(
pyannote_config_cls
,
"model"
,
"Ljava/lang/String;"
);
jstring
s
=
(
jstring
)
env
->
GetObjectField
(
pyannote_config
,
fid
);
const
char
*
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
segmentation
.
pyannote
.
model
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
fid
=
env
->
GetFieldID
(
segmentation_config_cls
,
"numThreads"
,
"I"
);
ans
.
segmentation
.
num_threads
=
env
->
GetIntField
(
segmentation_config
,
fid
);
fid
=
env
->
GetFieldID
(
segmentation_config_cls
,
"debug"
,
"Z"
);
ans
.
segmentation
.
debug
=
env
->
GetBooleanField
(
segmentation_config
,
fid
);
fid
=
env
->
GetFieldID
(
segmentation_config_cls
,
"provider"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
segmentation_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
segmentation
.
provider
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
//---------- embedding ----------
fid
=
env
->
GetFieldID
(
cls
,
"embedding"
,
"Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;"
);
jobject
embedding_config
=
env
->
GetObjectField
(
config
,
fid
);
jclass
embedding_config_cls
=
env
->
GetObjectClass
(
embedding_config
);
fid
=
env
->
GetFieldID
(
embedding_config_cls
,
"model"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
embedding_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
embedding
.
model
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
fid
=
env
->
GetFieldID
(
embedding_config_cls
,
"numThreads"
,
"I"
);
ans
.
embedding
.
num_threads
=
env
->
GetIntField
(
embedding_config
,
fid
);
fid
=
env
->
GetFieldID
(
embedding_config_cls
,
"debug"
,
"Z"
);
ans
.
embedding
.
debug
=
env
->
GetBooleanField
(
embedding_config
,
fid
);
fid
=
env
->
GetFieldID
(
embedding_config_cls
,
"provider"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
embedding_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
embedding
.
provider
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
//---------- clustering ----------
fid
=
env
->
GetFieldID
(
cls
,
"clustering"
,
"Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;"
);
jobject
clustering_config
=
env
->
GetObjectField
(
config
,
fid
);
jclass
clustering_config_cls
=
env
->
GetObjectClass
(
clustering_config
);
fid
=
env
->
GetFieldID
(
clustering_config_cls
,
"numClusters"
,
"I"
);
ans
.
clustering
.
num_clusters
=
env
->
GetIntField
(
clustering_config
,
fid
);
fid
=
env
->
GetFieldID
(
clustering_config_cls
,
"threshold"
,
"F"
);
ans
.
clustering
.
threshold
=
env
->
GetFloatField
(
clustering_config
,
fid
);
// its own fields
fid
=
env
->
GetFieldID
(
cls
,
"minDurationOn"
,
"F"
);
ans
.
min_duration_on
=
env
->
GetFloatField
(
config
,
fid
);
fid
=
env
->
GetFieldID
(
cls
,
"minDurationOff"
,
"F"
);
ans
.
min_duration_off
=
env
->
GetFloatField
(
config
,
fid
);
return
ans
;
}
}
// namespace sherpa_onnx
// JNI implementation of OfflineSpeakerDiarization.newFromAsset().
//
// Creating the diarizer from an asset manager is not implemented here: no
// native object is created and 0 is always returned. The reason is logged so
// that the zero handle is not silent.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset(
    JNIEnv * /*env*/, jobject /*obj*/, jobject /*asset_manager*/,
    jobject /*_config*/) {
  SHERPA_ONNX_LOGE(
      "newFromAsset() is not implemented for OfflineSpeakerDiarization. "
      "No native object is created.");
  return 0;
}
// JNI implementation of OfflineSpeakerDiarization.newFromFile().
//
// Converts the Kotlin config, logs it and validates it.
//
// @return The address of a newly allocated native diarizer as a jlong, or 0
//         if the config does not validate.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  auto native_config =
      sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", native_config.ToString().c_str());

  if (native_config.Validate()) {
    auto *diarizer = new sherpa_onnx::OfflineSpeakerDiarization(native_config);
    return reinterpret_cast<jlong>(diarizer);
  }

  SHERPA_ONNX_LOGE("Errors found in config!");
  return 0;
}
// JNI implementation of OfflineSpeakerDiarization.setConfig().
//
// Converts the Kotlin config, logs it and hands it to SetConfig() of the
// native diarizer addressed by `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_setConfig(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) {
  auto new_config =
      sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", new_config.ToString().c_str());

  auto *diarizer =
      reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
  diarizer->SetConfig(new_config);
}
// JNI implementation of OfflineSpeakerDiarization.delete().
//
// Destroys the native diarizer whose address is stored in `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/,
                                                            jobject /*obj*/,
                                                            jlong ptr) {
  auto *diarizer =
      reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
  delete diarizer;
}
// Converts native diarization segments into a Java array of
// com.k2fsa.sherpa.onnx.OfflineSpeakerDiarizationSegment objects, built with
// the (float, float, int) constructor from Start(), End() and Speaker().
//
// @param env      JNI environment of the calling thread.
// @param segments Segments to convert.
// @return The new array, or nullptr if a JNI lookup or allocation failed. In
//         that case the exception raised by the failing JNI call is left
//         pending.
static jobjectArray ProcessImpl(
    JNIEnv *env,
    const std::vector<sherpa_onnx::OfflineSpeakerDiarizationSegment>
        &segments) {
  jclass cls =
      env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment");
  if (cls == nullptr) {
    return nullptr;
  }

  jmethodID constructor = env->GetMethodID(cls, "<init>", "(FFI)V");
  if (constructor == nullptr) {
    return nullptr;
  }

  const jsize num_segments = static_cast<jsize>(segments.size());

  jobjectArray obj_arr = env->NewObjectArray(num_segments, cls, nullptr);
  if (obj_arr == nullptr) {
    return nullptr;
  }

  for (jsize i = 0; i != num_segments; ++i) {
    const auto &s = segments[i];

    jobject segment =
        env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker());
    if (segment == nullptr) {
      return nullptr;
    }

    env->SetObjectArrayElement(obj_arr, i, segment);

    // The array now holds the object. Drop the local reference right away so
    // that a long result does not exhaust the local reference table.
    env->DeleteLocalRef(segment);
  }

  return obj_arr;
}
// JNI implementation of OfflineSpeakerDiarization.process().
//
// Runs the native diarizer addressed by `ptr` on `samples` and returns the
// resulting segments, sorted by start time, as a Java object array.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_process(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
  auto *diarizer =
      reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);

  jfloat *audio = env->GetFloatArrayElements(samples, nullptr);
  const jsize num_samples = env->GetArrayLength(samples);

  auto sorted_segments =
      diarizer->Process(audio, num_samples).SortByStartTime();

  // The samples are only read, so nothing has to be copied back (JNI_ABORT).
  env->ReleaseFloatArrayElements(samples, audio, JNI_ABORT);

  return ProcessImpl(env, sorted_segments);
}
// JNI implementation of OfflineSpeakerDiarization.processWithCallback().
//
// Same as process(), but additionally passes a progress callback to the
// native Process() call. The callback forwards
// (num_processed_chunks, num_total_chunks, arg) to the `invoke` method of the
// Kotlin `callback` object and returns the int value of its boxed result.
//
// The method IDs are resolved once, before processing starts, and every local
// reference created while serving a callback is released immediately, so the
// number of live local references does not grow with the number of chunks.
//
// If `invoke` cannot be resolved, or if the Kotlin callback throws or returns
// null, the pending exception is cleared and 0 is returned to the native
// caller.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_processWithCallback(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jobject callback, jlong arg) {
  jclass callback_cls = env->GetObjectClass(callback);
  jmethodID invoke_mid =
      env->GetMethodID(callback_cls, "invoke", "(IIJ)Ljava/lang/Integer;");
  env->DeleteLocalRef(callback_cls);

  jmethodID int_value_mid = nullptr;
  if (invoke_mid == nullptr) {
    // A failed lookup leaves an exception pending; clear it before making
    // any further JNI call.
    env->ExceptionClear();
    SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
  } else {
    jclass integer_cls = env->FindClass("java/lang/Integer");
    if (integer_cls != nullptr) {
      int_value_mid = env->GetMethodID(integer_cls, "intValue", "()I");
      env->DeleteLocalRef(integer_cls);
    }

    if (int_value_mid == nullptr) {
      env->ExceptionClear();
      SHERPA_ONNX_LOGE("Failed to get the callback. Ignore it.");
    }
  }

  std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
      [env, callback, invoke_mid, int_value_mid](
          int32_t num_processed_chunks, int32_t num_total_chunks,
          void *data) -> int32_t {
    if (invoke_mid == nullptr || int_value_mid == nullptr) {
      return 0;
    }

    jobject ret =
        env->CallObjectMethod(callback, invoke_mid, num_processed_chunks,
                              num_total_chunks, reinterpret_cast<jlong>(data));
    if (env->ExceptionCheck()) {
      // The Kotlin callback threw. Print and clear the exception so that the
      // JNI calls made afterwards do not run with an exception pending.
      env->ExceptionDescribe();
      env->ExceptionClear();
      return 0;
    }

    if (ret == nullptr) {
      return 0;
    }

    jint value = env->CallIntMethod(ret, int_value_mid);
    env->DeleteLocalRef(ret);

    return value;
  };

  auto sd = reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);

  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  jsize n = env->GetArrayLength(samples);

  auto segments =
      sd->Process(p, n, callback_wrapper, reinterpret_cast<void *>(arg))
          .SortByStartTime();

  // The samples are only read, so nothing has to be copied back (JNI_ABORT).
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);

  return ProcessImpl(env, segments);
}
// JNI implementation of OfflineSpeakerDiarization.getSampleRate().
//
// Returns SampleRate() of the native diarizer addressed by `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_getSampleRate(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *diarizer =
      reinterpret_cast<sherpa_onnx::OfflineSpeakerDiarization *>(ptr);
  return diarizer->SampleRate();
}
...
...
sherpa-onnx/kotlin-api/OfflineSpeakerDiarization.kt
0 → 100644
查看文件 @
2d412b1
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
/**
 * Configuration of the pyannote speaker segmentation model.
 *
 * The native code reads the field by name, so it must not be renamed.
 *
 * @property model Path to the segmentation model file, e.g.
 *   `./sherpa-onnx-pyannote-segmentation-3-0/model.onnx`.
 */
data class OfflineSpeakerSegmentationPyannoteModelConfig(
var model: String,
)
/**
 * Configuration of the speaker segmentation stage.
 *
 * The native code reads every field by name and copies it into the native
 * segmentation config, so the fields must not be renamed.
 *
 * @property pyannote Configuration of the pyannote segmentation model.
 * @property numThreads Copied to the native `num_threads` field; defaults to 1.
 * @property debug Copied to the native `debug` field; defaults to `false`.
 * @property provider Copied to the native `provider` field; defaults to `"cpu"`.
 */
data class OfflineSpeakerSegmentationModelConfig(
var pyannote: OfflineSpeakerSegmentationPyannoteModelConfig,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)
/**
 * Configuration of the clustering stage.
 *
 * Set [numClusters] when the number of speakers is known; otherwise set
 * [threshold], which has to be tuned by the user.
 *
 * @property numClusters Number of clusters, i.e. speakers. Defaults to -1
 *   (presumably meaning "not specified" - TODO confirm against the native code).
 * @property threshold A larger threshold leads to fewer clusters (speakers),
 *   a smaller one to more clusters. Defaults to 0.5.
 */
data class FastClusteringConfig(
var numClusters: Int = -1,
var threshold: Float = 0.5f,
)
/**
 * Top-level configuration of [OfflineSpeakerDiarization].
 *
 * The native code reads every field by name, so the fields must not be
 * renamed.
 *
 * @property segmentation Configuration of the speaker segmentation stage.
 * @property embedding Configuration of the speaker embedding extractor.
 * @property clustering Configuration of the clustering stage.
 * @property minDurationOn Copied to the native `min_duration_on` field;
 *   defaults to 0.2 (presumably seconds - TODO confirm).
 * @property minDurationOff Copied to the native `min_duration_off` field;
 *   defaults to 0.5 (presumably seconds - TODO confirm).
 */
data class OfflineSpeakerDiarizationConfig(
var segmentation: OfflineSpeakerSegmentationModelConfig,
var embedding: SpeakerEmbeddingExtractorConfig,
var clustering: FastClusteringConfig,
var minDurationOn: Float = 0.2f,
var minDurationOff: Float = 0.5f,
)
/**
 * One segment of a diarization result.
 *
 * Instances are created by the native code through the
 * `(Float, Float, Int)` constructor, so the parameter order and types must
 * stay as they are.
 */
data class OfflineSpeakerDiarizationSegment(
val start: Float, // in seconds
val end: Float, // in seconds
val speaker: Int, // ID of the speaker; count from 0
)
/**
 * Kotlin wrapper around the native offline speaker diarization object
 * provided by the `sherpa-onnx-jni` library.
 *
 * The native object is created through `newFromAsset` when [assetManager] is
 * not null and through `newFromFile` otherwise. Call [release] to free it.
 *
 * @param assetManager asset manager to create the native object from, or
 *   null to create it from files
 * @param config configuration used to create the native object
 */
class OfflineSpeakerDiarization(
    assetManager: AssetManager? = null,
    config: OfflineSpeakerDiarizationConfig,
) {
    // Address of the native object; 0 means there is nothing to free.
    private var ptr: Long =
        if (assetManager == null) newFromFile(config) else newFromAsset(assetManager, config)

    /** Frees the native object if it has not been freed yet. */
    protected fun finalize() {
        if (ptr == 0L) {
            return
        }
        delete(ptr)
        ptr = 0
    }

    /** Frees the native object; calling it more than once is harmless. */
    fun release(): Unit = finalize()

    // Only config.clustering is used. All other fields in config
    // are ignored
    fun setConfig(config: OfflineSpeakerDiarizationConfig): Unit = setConfig(ptr, config)

    /** Returns the sample rate reported by the native object. */
    fun sampleRate(): Int = getSampleRate(ptr)

    /** Runs diarization on [samples] and returns the detected segments. */
    fun process(samples: FloatArray): Array<OfflineSpeakerDiarizationSegment> =
        process(ptr, samples)

    /**
     * Same as [process], but [callback] is invoked by the native code with the
     * number of processed chunks, the total number of chunks and [arg].
     */
    fun processWithCallback(
        samples: FloatArray,
        callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int,
        arg: Long = 0,
    ): Array<OfflineSpeakerDiarizationSegment> = processWithCallback(ptr, samples, callback, arg)

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflineSpeakerDiarizationConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineSpeakerDiarizationConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun setConfig(ptr: Long, config: OfflineSpeakerDiarizationConfig)

    private external fun getSampleRate(ptr: Long): Int

    private external fun process(
        ptr: Long,
        samples: FloatArray,
    ): Array<OfflineSpeakerDiarizationSegment>

    private external fun processWithCallback(
        ptr: Long,
        samples: FloatArray,
        callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int,
        arg: Long,
    ): Array<OfflineSpeakerDiarizationSegment>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
...
...
请
注册
或
登录
后发表评论