Fangjun Kuang
Committed by GitHub

Add Kotlin API for speech enhancement GTCRN models (#2008)

  1 +../sherpa-onnx/kotlin-api/OfflineSpeechDenoiser.kt
@@ -371,6 +371,31 @@ function testOfflineSpeakerDiarization() { @@ -371,6 +371,31 @@ function testOfflineSpeakerDiarization() {
371 java -Djava.library.path=../build/lib -jar $out_filename 371 java -Djava.library.path=../build/lib -jar $out_filename
372 } 372 }
373 373
# Build and run the Kotlin speech-enhancement (GTCRN) example.
#
# Steps:
#   1. Download the model and the test wave into the current directory,
#      unless a file with that name is already there.
#   2. Compile the example and the API sources into a runnable jar.
#   3. Run the jar against the JNI library in ../build/lib.
#   4. List the *.wav files so the file written by the example shows up in
#      the log.
function testOfflineSpeechDenoiser() {
  local release_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models
  local asset

  for asset in gtcrn_simple.onnx inp_16k.wav; do
    if [ ! -f ./$asset ]; then
      curl -SL -O $release_url/$asset
    fi
  done

  # Plain (non-local) assignment on purpose: the variable stays visible after
  # the function returns, exactly as before.
  out_filename=test_offline_speech_denoiser.jar
  kotlinc-jvm -include-runtime -d $out_filename \
    test_offline_speech_denoiser.kt \
    OfflineSpeechDenoiser.kt \
    WaveReader.kt \
    faked-asset-manager.kt \
    faked-log.kt

  ls -lh $out_filename

  java -Djava.library.path=../build/lib -jar $out_filename

  ls -lh *.wav
}
  397 +
  398 +testOfflineSpeechDenoiser
374 testOfflineSpeakerDiarization 399 testOfflineSpeakerDiarization
375 testSpeakerEmbeddingExtractor 400 testSpeakerEmbeddingExtractor
376 testOnlineAsr 401 testOnlineAsr
  1 +package com.k2fsa.sherpa.onnx
  2 +// Please download test files in this script from
  3 +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
  4 +
/** Program entry point: runs the speech-denoiser example once. */
fun main() = test()
  8 +
/**
 * Denoises `./inp_16k.wav` with the denoiser built by
 * [createOfflineSpeechDenoiser] and writes the result to `./enhanced-16k.wav`.
 *
 * The native denoiser is released when the run finishes or fails, and a
 * failed save raises an error instead of being reported as success.
 *
 * @throws IllegalStateException if the denoised audio cannot be saved.
 */
fun test() {
    val denoiser = createOfflineSpeechDenoiser()

    try {
        val waveFilename = "./inp_16k.wav"

        // readWaveFromFile hands back an untyped array: [0] = samples, [1] = sample rate.
        val objArray = WaveReader.readWaveFromFile(
            filename = waveFilename,
        )
        val samples: FloatArray = objArray[0] as FloatArray
        val sampleRate: Int = objArray[1] as Int

        val denoised = denoiser.run(samples, sampleRate)

        val outFilename = "./enhanced-16k.wav"
        // save() returns a Boolean; do not claim success when writing failed.
        check(denoised.save(filename = outFilename)) { "Failed to save $outFilename" }
        println("saved to $outFilename")
    } finally {
        // Free the native object deterministically instead of waiting for finalization.
        denoiser.release()
    }
}
  24 +
/**
 * Builds an [OfflineSpeechDenoiser] for the GTCRN model stored at
 * `./gtcrn_simple.onnx`, using the "cpu" provider and a single thread.
 *
 * The assembled configuration is printed before the denoiser is created.
 */
fun createOfflineSpeechDenoiser(): OfflineSpeechDenoiser {
    val gtcrnConfig = OfflineSpeechDenoiserGtcrnModelConfig(
        model = "./gtcrn_simple.onnx",
    )
    val modelConfig = OfflineSpeechDenoiserModelConfig(
        gtcrn = gtcrnConfig,
        provider = "cpu",
        numThreads = 1,
    )
    val denoiserConfig = OfflineSpeechDenoiserConfig(model = modelConfig)

    println(denoiserConfig)

    return OfflineSpeechDenoiser(config = denoiserConfig)
}
  40 +
  41 +
@@ -16,6 +16,7 @@ set(sources @@ -16,6 +16,7 @@ set(sources
16 keyword-spotter.cc 16 keyword-spotter.cc
17 offline-punctuation.cc 17 offline-punctuation.cc
18 offline-recognizer.cc 18 offline-recognizer.cc
  19 + offline-speech-denoiser.cc
19 offline-stream.cc 20 offline-stream.cc
20 online-punctuation.cc 21 online-punctuation.cc
21 online-recognizer.cc 22 online-recognizer.cc
@@ -25,23 +25,6 @@ jobject NewFloat(JNIEnv *env, float value) { @@ -25,23 +25,6 @@ jobject NewFloat(JNIEnv *env, float value) {
25 return env->NewObject(cls, constructor, value); 25 return env->NewObject(cls, constructor, value);
26 } 26 }
27 27
28 -SHERPA_ONNX_EXTERN_C  
29 -JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(  
30 - JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,  
31 - jint sample_rate) {  
32 - const char *p_filename = env->GetStringUTFChars(filename, nullptr);  
33 -  
34 - jfloat *p = env->GetFloatArrayElements(samples, nullptr);  
35 - jsize n = env->GetArrayLength(samples);  
36 -  
37 - bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);  
38 -  
39 - env->ReleaseStringUTFChars(filename, p_filename);  
40 - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);  
41 -  
42 - return ok;  
43 -}  
44 -  
45 #if 0 28 #if 0
46 SHERPA_ONNX_EXTERN_C 29 SHERPA_ONNX_EXTERN_C
47 JNIEXPORT void JNICALL 30 JNIEXPORT void JNICALL
  1 +// sherpa-onnx/jni/offline-speech-denoiser.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
  5 +
  6 +#include "sherpa-onnx/csrc/macros.h"
  7 +#include "sherpa-onnx/csrc/wave-writer.h"
  8 +#include "sherpa-onnx/jni/common.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +static OfflineSpeechDenoiserConfig GetOfflineSpeechDenoiserConfig(
  12 + JNIEnv *env, jobject config) {
  13 + OfflineSpeechDenoiserConfig ans;
  14 +
  15 + jclass cls = env->GetObjectClass(config);
  16 + jfieldID fid;
  17 +
  18 + fid = env->GetFieldID(
  19 + cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserModelConfig;");
  20 + jobject model = env->GetObjectField(config, fid);
  21 + jclass model_config_cls = env->GetObjectClass(model);
  22 +
  23 + fid = env->GetFieldID(
  24 + model_config_cls, "gtcrn",
  25 + "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserGtcrnModelConfig;");
  26 + jobject gtcrn = env->GetObjectField(model, fid);
  27 + jclass gtcrn_cls = env->GetObjectClass(gtcrn);
  28 +
  29 + fid = env->GetFieldID(gtcrn_cls, "model", "Ljava/lang/String;");
  30 + jstring s = (jstring)env->GetObjectField(gtcrn, fid);
  31 + const char *p = env->GetStringUTFChars(s, nullptr);
  32 + ans.model.gtcrn.model = p;
  33 + env->ReleaseStringUTFChars(s, p);
  34 +
  35 + fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  36 + ans.model.num_threads = env->GetIntField(model, fid);
  37 +
  38 + fid = env->GetFieldID(model_config_cls, "debug", "Z");
  39 + ans.model.debug = env->GetBooleanField(model, fid);
  40 +
  41 + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
  42 + s = (jstring)env->GetObjectField(model, fid);
  43 + p = env->GetStringUTFChars(s, nullptr);
  44 + ans.model.provider = p;
  45 + env->ReleaseStringUTFChars(s, p);
  46 +
  47 + return ans;
  48 +}
  49 +
  50 +} // namespace sherpa_onnx
  51 +
  52 +SHERPA_ONNX_EXTERN_C
  53 +JNIEXPORT jlong JNICALL
  54 +Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromAsset(
  55 + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  56 +#if __ANDROID_API__ >= 9
  57 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  58 + if (!mgr) {
  59 + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  60 + return 0;
  61 + }
  62 +#endif
  63 + auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config);
  64 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  65 +
  66 + auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser(
  67 +#if __ANDROID_API__ >= 9
  68 + mgr,
  69 +#endif
  70 + config);
  71 +
  72 + return (jlong)speech_denoiser;
  73 +}
  74 +
  75 +SHERPA_ONNX_EXTERN_C
  76 +JNIEXPORT jlong JNICALL
  77 +Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromFile(JNIEnv *env,
  78 + jobject /*obj*/,
  79 + jobject _config) {
  80 + return SafeJNI(
  81 + env, "OfflineSpeechDenoiser_newFromFile",
  82 + [&]() -> jlong {
  83 + auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config);
  84 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
  85 +
  86 + if (!config.Validate()) {
  87 + SHERPA_ONNX_LOGE("Errors found in config!");
  88 + }
  89 +
  90 + auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser(config);
  91 + return reinterpret_cast<jlong>(speech_denoiser);
  92 + },
  93 + 0L);
  94 +}
  95 +
  96 +SHERPA_ONNX_EXTERN_C
  97 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_delete(
  98 + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  99 + delete reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);
  100 +}
  101 +
  102 +SHERPA_ONNX_EXTERN_C
  103 +JNIEXPORT jint JNICALL
  104 +Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_getSampleRate(JNIEnv * /*env*/,
  105 + jobject /*obj*/,
  106 + jlong ptr) {
  107 + return reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr)
  108 + ->GetSampleRate();
  109 +}
  110 +
  111 +SHERPA_ONNX_EXTERN_C
  112 +JNIEXPORT jobject JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_run(
  113 + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
  114 + jint sample_rate) {
  115 + auto speech_denoiser =
  116 + reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);
  117 +
  118 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  119 + jsize n = env->GetArrayLength(samples);
  120 + auto denoised = speech_denoiser->Run(p, n, sample_rate);
  121 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  122 +
  123 + jclass cls = env->FindClass("com/k2fsa/sherpa/onnx/DenoisedAudio");
  124 + if (cls == nullptr) {
  125 + SHERPA_ONNX_LOGE("Failed to get class for DenoisedAudio");
  126 + return nullptr;
  127 + }
  128 +
  129 + // https://javap.yawk.at/
  130 + jmethodID constructor = env->GetMethodID(cls, "<init>", "([FI)V");
  131 + if (constructor == nullptr) {
  132 + SHERPA_ONNX_LOGE("Failed to get constructor for DenoisedAudio");
  133 + return nullptr;
  134 + }
  135 +
  136 + jfloatArray samples_arr = env->NewFloatArray(denoised.samples.size());
  137 + env->SetFloatArrayRegion(samples_arr, 0, denoised.samples.size(),
  138 + denoised.samples.data());
  139 +
  140 + return env->NewObject(cls, constructor, samples_arr, denoised.sample_rate);
  141 +}
  142 +
  143 +SHERPA_ONNX_EXTERN_C
  144 +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_DenoisedAudio_saveImpl(
  145 + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
  146 + jint sample_rate) {
  147 + const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  148 +
  149 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  150 + jsize n = env->GetArrayLength(samples);
  151 +
  152 + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);
  153 +
  154 + env->ReleaseStringUTFChars(filename, p_filename);
  155 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  156 +
  157 + return ok;
  158 +}
@@ -5,6 +5,7 @@ @@ -5,6 +5,7 @@
5 #include "sherpa-onnx/csrc/offline-tts.h" 5 #include "sherpa-onnx/csrc/offline-tts.h"
6 6
7 #include "sherpa-onnx/csrc/macros.h" 7 #include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/csrc/wave-writer.h"
8 #include "sherpa-onnx/jni/common.h" 9 #include "sherpa-onnx/jni/common.h"
9 10
10 namespace sherpa_onnx { 11 namespace sherpa_onnx {
@@ -340,3 +341,20 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( @@ -340,3 +341,20 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl(
340 341
341 return obj_arr; 342 return obj_arr;
342 } 343 }
  344 +
  345 +SHERPA_ONNX_EXTERN_C
  346 +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl(
  347 + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
  348 + jint sample_rate) {
  349 + const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  350 +
  351 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  352 + jsize n = env->GetArrayLength(samples);
  353 +
  354 + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);
  355 +
  356 + env->ReleaseStringUTFChars(filename, p_filename);
  357 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  358 +
  359 + return ok;
  360 +}
  1 +package com.k2fsa.sherpa.onnx
  2 +
  3 +import android.content.res.AssetManager
  4 +
/**
 * Settings specific to the GTCRN speech-enhancement model.
 *
 * @property model Location of the GTCRN model file; empty by default.
 */
data class OfflineSpeechDenoiserGtcrnModelConfig(
    var model: String = "",
)
  8 +
/**
 * Model-level settings for the speech denoiser. The native layer reads these
 * fields by name, so the property names must not change.
 *
 * @property gtcrn GTCRN model settings.
 * @property numThreads Passed to the native `num_threads` setting; defaults to 1.
 * @property debug Passed to the native `debug` flag; defaults to false.
 * @property provider Passed to the native `provider` setting; defaults to "cpu".
 */
data class OfflineSpeechDenoiserModelConfig(
    var gtcrn: OfflineSpeechDenoiserGtcrnModelConfig = OfflineSpeechDenoiserGtcrnModelConfig(),
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
  15 +
/**
 * Top-level configuration handed to [OfflineSpeechDenoiser].
 *
 * @property model Model-level settings; defaults to an all-default
 * [OfflineSpeechDenoiserModelConfig].
 */
data class OfflineSpeechDenoiserConfig(
    var model: OfflineSpeechDenoiserModelConfig = OfflineSpeechDenoiserModelConfig(),
)
  19 +
/**
 * Audio samples together with their sample rate, as returned by
 * [OfflineSpeechDenoiser.run].
 *
 * The native layer constructs this class through its `(FloatArray, Int)`
 * constructor, so the constructor signature must stay as it is.
 *
 * @property samples The audio samples.
 * @property sampleRate Sample rate of [samples].
 */
class DenoisedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    /**
     * Writes [samples] to [filename] as a wave file.
     *
     * @return the Boolean reported by the native writer.
     */
    fun save(filename: String): Boolean {
        return saveImpl(
            filename = filename,
            samples = samples,
            sampleRate = sampleRate,
        )
    }

    private external fun saveImpl(
        filename: String,
        samples: FloatArray,
        sampleRate: Int
    ): Boolean
}
  33 +
/**
 * Kotlin wrapper around the native offline speech denoiser.
 *
 * The native object is created from assets when [assetManager] is non-null,
 * otherwise from files. Call [release] when done; after that, [run] and
 * [sampleRate] throw [IllegalStateException] instead of passing a null handle
 * to native code.
 *
 * @param assetManager Asset manager to load from, or null to load from files.
 * @param config Denoiser configuration.
 */
class OfflineSpeechDenoiser(
    assetManager: AssetManager? = null,
    config: OfflineSpeechDenoiserConfig,
) {
    // Address of the native object; 0 means "no native object".
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    /** Frees the native object if it is still alive; safe to call repeatedly. */
    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native object now instead of waiting for finalization. */
    fun release() = finalize()

    /**
     * Denoises [samples] recorded at [sampleRate].
     *
     * @throws IllegalStateException if the native object is not available.
     */
    fun run(samples: FloatArray, sampleRate: Int): DenoisedAudio =
        run(requireHandle(), samples, sampleRate)

    /**
     * Sample rate reported by the native denoiser.
     *
     * @throws IllegalStateException if the native object is not available.
     */
    val sampleRate: Int
        get() = getSampleRate(requireHandle())

    // The native side dereferences the handle without checking it, so a zero
    // handle (failed creation or already released) must be rejected here.
    private fun requireHandle(): Long {
        check(ptr != 0L) {
            "OfflineSpeechDenoiser has no native object (creation failed or release() was called)"
        }
        return ptr
    }

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflineSpeechDenoiserConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineSpeechDenoiserConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun run(ptr: Long, samples: FloatArray, sampleRate: Int): DenoisedAudio

    private external fun getSampleRate(ptr: Long): Int

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}