Fangjun Kuang
Committed by GitHub

Add JNI support for spoken language identification (#782)

... ... @@ -161,10 +161,12 @@ jobs:
./run-vits-vctk.sh
rm -rf vits-vctk
echo "Test vits-zh-aishell3"
git clone https://huggingface.co/csukuangfj/vits-zh-aishell3
echo "Test vits-icefall-zh-aishell3"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-icefall-zh-aishell3.tar.bz2
tar xvf vits-icefall-zh-aishell3.tar.bz2
rm vits-icefall-zh-aishell3.tar.bz2
./run-vits-zh-aishell3.sh
rm -rf vits-zh-aishell3
rm -rf vits-icefall-zh-aishell3*
echo "Test vits-piper-en_US-lessac-medium"
git clone https://huggingface.co/csukuangfj/vits-piper-en_US-lessac-medium
... ...
... ... @@ -92,3 +92,4 @@ sr-data
vits-icefall-*
sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
spoken-language-identification-test-wavs
... ...
... ... @@ -6,7 +6,7 @@ import android.util.Log
private val TAG = "sherpa-onnx"
data class OfflineZipformerAudioTaggingModelConfig(
val model: String,
var model: String,
)
data class AudioTaggingModelConfig(
... ... @@ -134,4 +134,4 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
}
return null
}
\ No newline at end of file
}
... ...
... ... @@ -7,6 +7,7 @@ fun callback(samples: FloatArray): Unit {
}
fun main() {
testSpokenLanguageIdentifcation()
testAudioTagging()
testSpeakerRecognition()
testTts()
... ... @@ -14,6 +15,41 @@ fun main() {
testAsr("zipformer2-ctc")
}
fun testSpokenLanguageIdentifcation() {
val config = SpokenLanguageIdentificationConfig(
whisper = SpokenLanguageIdentificationWhisperConfig(
encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
tailPaddings = 33,
),
numThreads=1,
debug=true,
provider="cpu",
)
val slid = SpokenLanguageIdentification(assetManager=null, config=config)
val testFiles = arrayOf(
"./spoken-language-identification-test-wavs/ar-arabic.wav",
"./spoken-language-identification-test-wavs/bg-bulgarian.wav",
"./spoken-language-identification-test-wavs/de-german.wav",
)
for (waveFilename in testFiles) {
val objArray = WaveReader.readWaveFromFile(
filename = waveFilename,
)
val samples: FloatArray = objArray[0] as FloatArray
val sampleRate: Int = objArray[1] as Int
val stream = slid.createStream()
stream.acceptWaveform(samples, sampleRate = sampleRate)
val lang = slid.compute(stream)
stream.release()
println(waveFilename)
println(lang)
}
}
fun testAudioTagging() {
val config = AudioTaggingConfig(
model=AudioTaggingModelConfig(
... ...
... ... @@ -5,32 +5,22 @@ import android.util.Log
private val TAG = "sherpa-onnx"
data class OfflineZipformerAudioTaggingModelConfig (
val model: String,
data class SpokenLanguageIdentificationWhisperConfig (
var encoder: String,
var decoder: String,
var tailPaddings: Int = -1,
)
data class AudioTaggingModelConfig (
var zipformer: OfflineZipformerAudioTaggingModelConfig,
data class SpokenLanguageIdentificationConfig (
var whisper: SpokenLanguageIdentificationWhisperConfig,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)
data class AudioTaggingConfig (
var model: AudioTaggingModelConfig,
var labels: String,
var topK: Int = 5,
)
data class AudioEvent (
val name: String,
val index: Int,
val prob: Float,
)
class AudioTagging(
class SpokenLanguageIdentification (
assetManager: AssetManager? = null,
config: AudioTaggingConfig,
config: SpokenLanguageIdentificationConfig,
) {
private var ptr: Long
... ... @@ -43,10 +33,10 @@ class AudioTagging(
}
protected fun finalize() {
if(ptr != 0) {
delete(ptr)
ptr = 0
}
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
fun release() = finalize()
... ... @@ -56,25 +46,22 @@ class AudioTagging(
return OfflineStream(p)
}
// fun compute(stream: OfflineStream, topK: Int=-1): Array<AudioEvent> {
fun compute(stream: OfflineStream, topK: Int=-1): Array<Any> {
var events :Array<Any> = compute(ptr, stream.ptr, topK)
}
fun compute(stream: OfflineStream) = compute(ptr, stream.ptr)
private external fun newFromAsset(
assetManager: AssetManager,
config: AudioTaggingConfig,
config: SpokenLanguageIdentificationConfig,
): Long
private external fun newFromFile(
config: AudioTaggingConfig,
config: SpokenLanguageIdentificationConfig,
): Long
private external fun delete(ptr: Long)
private external fun createStream(ptr: Long): Long
private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>
private external fun compute(ptr: Long, streamPtr: Long): String
companion object {
init {
... ...
... ... @@ -30,19 +30,19 @@ cd ../kotlin-api-examples
function testSpeakerEmbeddingExtractor() {
if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
fi
if [ ! -f ./speaker1_a_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
fi
if [ ! -f ./speaker1_b_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
fi
if [ ! -f ./speaker2_a_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
fi
}
... ... @@ -53,7 +53,7 @@ function testAsr() {
fi
if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
... ... @@ -61,7 +61,7 @@ function testAsr() {
function testTts() {
if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
tar xf vits-piper-en_US-amy-low.tar.bz2
rm vits-piper-en_US-amy-low.tar.bz2
fi
... ... @@ -75,7 +75,22 @@ function testAudioTagging() {
fi
}
function testSpokenLanguageIdentification() {
if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
rm sherpa-onnx-whisper-tiny.tar.bz2
fi
if [ ! -f ./spoken-language-identification-test-wavs/ar-arabic.wav ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/spoken-language-identification-test-wavs.tar.bz2
tar xvf spoken-language-identification-test-wavs.tar.bz2
rm spoken-language-identification-test-wavs.tar.bz2
fi
}
function test() {
testSpokenLanguageIdentification
testAudioTagging
testSpeakerEmbeddingExtractor
testAsr
... ... @@ -90,6 +105,7 @@ kotlinc-jvm -include-runtime -d main.jar \
OfflineStream.kt \
SherpaOnnx.kt \
Speaker.kt \
SpokenLanguageIdentification.kt \
Tts.kt \
WaveReader.kt \
faked-asset-manager.kt \
... ... @@ -101,13 +117,13 @@ java -Djava.library.path=../build/lib -jar main.jar
function testTwoPass() {
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
fi
if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
rm sherpa-onnx-whisper-tiny.en.tar.bz2
fi
... ...
... ... @@ -13,6 +13,7 @@ add_library(sherpa-onnx-jni
audio-tagging.cc
jni.cc
offline-stream.cc
spoken-language-identification.cc
)
target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
install(TARGETS sherpa-onnx-jni DESTINATION lib)
... ...
// sherpa-onnx/jni/spoken-language-identification.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static SpokenLanguageIdentificationConfig GetSpokenLanguageIdentificationConfig(
JNIEnv *env, jobject config) {
SpokenLanguageIdentificationConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid = env->GetFieldID(
cls, "whisper",
"Lcom/k2fsa/sherpa/onnx/SpokenLanguageIdentificationWhisperConfig;");
jobject whisper = env->GetObjectField(config, fid);
jclass whisper_cls = env->GetObjectClass(whisper);
fid = env->GetFieldID(whisper_cls, "encoder", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(whisper, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.whisper.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(whisper, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.whisper.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(whisper_cls, "tailPaddings", "I");
ans.whisper.tail_paddings = env->GetIntField(whisper, fid);
fid = env->GetFieldID(cls, "numThreads", "I");
ans.num_threads = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "debug", "Z");
ans.debug = env->GetBooleanField(config, fid);
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.provider = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config =
sherpa_onnx::GetSpokenLanguageIdentificationConfig(env, _config);
SHERPA_ONNX_LOGE("SpokenLanguageIdentification newFromFile config:\n%s",
config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
return 0;
}
auto tagger = new sherpa_onnx::SpokenLanguageIdentification(config);
return (jlong)tagger;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_createStream(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto slid =
reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
std::unique_ptr<sherpa_onnx::OfflineStream> s = slid->CreateStream();
// The user is responsible to free the returned pointer.
//
// See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
// ./offline-stream.cc
sherpa_onnx::OfflineStream *p = s.release();
return (jlong)p;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_compute(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong s_ptr) {
sherpa_onnx::SpokenLanguageIdentification *slid =
reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
sherpa_onnx::OfflineStream *s =
reinterpret_cast<sherpa_onnx::OfflineStream *>(s_ptr);
std::string lang = slid->Compute(s);
return env->NewStringUTF(lang.c_str());
}
... ...