Committed by
GitHub
Add Kotlin API for speech enhancement GTCRN models (#2008)
正在显示
8 个修改的文件
包含
326 行增加
和
17 行删除
kotlin-api-examples/OfflineSpeechDenoiser.kt
0 → 120000
| 1 | +../sherpa-onnx/kotlin-api/OfflineSpeechDenoiser.kt |
| @@ -371,6 +371,31 @@ function testOfflineSpeakerDiarization() { | @@ -371,6 +371,31 @@ function testOfflineSpeakerDiarization() { | ||
| 371 | java -Djava.library.path=../build/lib -jar $out_filename | 371 | java -Djava.library.path=../build/lib -jar $out_filename |
| 372 | } | 372 | } |
| 373 | 373 | ||
# Build and run the Kotlin speech-denoiser example.
#
# Downloads the GTCRN model and the sample wave from the sherpa-onnx
# "speech-enhancement-models" release when they are not already present in
# the current directory, compiles the example into a runnable jar, runs it,
# and finally lists the wave files in the current directory.
function testOfflineSpeechDenoiser() {
  local asset
  for asset in gtcrn_simple.onnx inp_16k.wav; do
    if [ ! -f ./$asset ]; then
      curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/$asset
    fi
  done

  out_filename=test_offline_speech_denoiser.jar
  kotlinc-jvm -include-runtime -d $out_filename \
    test_offline_speech_denoiser.kt \
    OfflineSpeechDenoiser.kt \
    WaveReader.kt \
    faked-asset-manager.kt \
    faked-log.kt

  ls -lh $out_filename

  java -Djava.library.path=../build/lib -jar $out_filename

  ls -lh *.wav
}

testOfflineSpeechDenoiser
| 374 | testOfflineSpeakerDiarization | 399 | testOfflineSpeakerDiarization |
| 375 | testSpeakerEmbeddingExtractor | 400 | testSpeakerEmbeddingExtractor |
| 376 | testOnlineAsr | 401 | testOnlineAsr |
| 1 | +package com.k2fsa.sherpa.onnx | ||
| 2 | +// Please download test files in this script from | ||
| 3 | +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models | ||
| 4 | + | ||
/** Program entry point: runs the speech-denoising example in [test]. */
fun main() = test()
| 8 | + | ||
/**
 * Denoises `./inp_16k.wav` with the GTCRN model and writes the result to
 * `./enhanced-16k.wav`.
 *
 * The message printed at the end reflects the status returned by
 * [DenoisedAudio.save], and the denoiser's native handle is released once
 * processing is finished, whether or not it succeeded.
 */
fun test() {
    val waveFilename = "./inp_16k.wav"
    val outputFilename = "./enhanced-16k.wav"

    val denoiser = createOfflineSpeechDenoiser()
    try {
        // The reader hands back an untyped array: [0] is the sample buffer,
        // [1] is the sample rate.
        val objArray = WaveReader.readWaveFromFile(
            filename = waveFilename,
        )
        val samples: FloatArray = objArray[0] as FloatArray
        val sampleRate: Int = objArray[1] as Int

        val denoised = denoiser.run(samples, sampleRate)
        if (denoised.save(filename = outputFilename)) {
            println("saved to $outputFilename")
        } else {
            println("Failed to save $outputFilename")
        }
    } finally {
        // Free the native denoiser explicitly instead of waiting for finalize().
        denoiser.release()
    }
}
| 24 | + | ||
/**
 * Builds an [OfflineSpeechDenoiser] that loads `./gtcrn_simple.onnx` with the
 * `"cpu"` provider and one thread. The configuration is printed before the
 * denoiser is constructed.
 */
fun createOfflineSpeechDenoiser(): OfflineSpeechDenoiser {
    val gtcrnConfig = OfflineSpeechDenoiserGtcrnModelConfig(
        model = "./gtcrn_simple.onnx"
    )
    val modelConfig = OfflineSpeechDenoiserModelConfig(
        gtcrn = gtcrnConfig,
        provider = "cpu",
        numThreads = 1,
    )
    val denoiserConfig = OfflineSpeechDenoiserConfig(model = modelConfig)

    println(denoiserConfig)

    return OfflineSpeechDenoiser(config = denoiserConfig)
}
| 40 | + | ||
| 41 | + |
| @@ -16,6 +16,7 @@ set(sources | @@ -16,6 +16,7 @@ set(sources | ||
| 16 | keyword-spotter.cc | 16 | keyword-spotter.cc |
| 17 | offline-punctuation.cc | 17 | offline-punctuation.cc |
| 18 | offline-recognizer.cc | 18 | offline-recognizer.cc |
| 19 | + offline-speech-denoiser.cc | ||
| 19 | offline-stream.cc | 20 | offline-stream.cc |
| 20 | online-punctuation.cc | 21 | online-punctuation.cc |
| 21 | online-recognizer.cc | 22 | online-recognizer.cc |
| @@ -25,23 +25,6 @@ jobject NewFloat(JNIEnv *env, float value) { | @@ -25,23 +25,6 @@ jobject NewFloat(JNIEnv *env, float value) { | ||
| 25 | return env->NewObject(cls, constructor, value); | 25 | return env->NewObject(cls, constructor, value); |
| 26 | } | 26 | } |
| 27 | 27 | ||
| 28 | -SHERPA_ONNX_EXTERN_C | ||
| 29 | -JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl( | ||
| 30 | - JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, | ||
| 31 | - jint sample_rate) { | ||
| 32 | - const char *p_filename = env->GetStringUTFChars(filename, nullptr); | ||
| 33 | - | ||
| 34 | - jfloat *p = env->GetFloatArrayElements(samples, nullptr); | ||
| 35 | - jsize n = env->GetArrayLength(samples); | ||
| 36 | - | ||
| 37 | - bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n); | ||
| 38 | - | ||
| 39 | - env->ReleaseStringUTFChars(filename, p_filename); | ||
| 40 | - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); | ||
| 41 | - | ||
| 42 | - return ok; | ||
| 43 | -} | ||
| 44 | - | ||
| 45 | #if 0 | 28 | #if 0 |
| 46 | SHERPA_ONNX_EXTERN_C | 29 | SHERPA_ONNX_EXTERN_C |
| 47 | JNIEXPORT void JNICALL | 30 | JNIEXPORT void JNICALL |
sherpa-onnx/jni/offline-speech-denoiser.cc
0 → 100644
| 1 | +// sherpa-onnx/jni/offline-speech-denoiser.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" | ||
| 5 | + | ||
| 6 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 7 | +#include "sherpa-onnx/csrc/wave-writer.h" | ||
| 8 | +#include "sherpa-onnx/jni/common.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | +static OfflineSpeechDenoiserConfig GetOfflineSpeechDenoiserConfig( | ||
| 12 | + JNIEnv *env, jobject config) { | ||
| 13 | + OfflineSpeechDenoiserConfig ans; | ||
| 14 | + | ||
| 15 | + jclass cls = env->GetObjectClass(config); | ||
| 16 | + jfieldID fid; | ||
| 17 | + | ||
| 18 | + fid = env->GetFieldID( | ||
| 19 | + cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserModelConfig;"); | ||
| 20 | + jobject model = env->GetObjectField(config, fid); | ||
| 21 | + jclass model_config_cls = env->GetObjectClass(model); | ||
| 22 | + | ||
| 23 | + fid = env->GetFieldID( | ||
| 24 | + model_config_cls, "gtcrn", | ||
| 25 | + "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserGtcrnModelConfig;"); | ||
| 26 | + jobject gtcrn = env->GetObjectField(model, fid); | ||
| 27 | + jclass gtcrn_cls = env->GetObjectClass(gtcrn); | ||
| 28 | + | ||
| 29 | + fid = env->GetFieldID(gtcrn_cls, "model", "Ljava/lang/String;"); | ||
| 30 | + jstring s = (jstring)env->GetObjectField(gtcrn, fid); | ||
| 31 | + const char *p = env->GetStringUTFChars(s, nullptr); | ||
| 32 | + ans.model.gtcrn.model = p; | ||
| 33 | + env->ReleaseStringUTFChars(s, p); | ||
| 34 | + | ||
| 35 | + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); | ||
| 36 | + ans.model.num_threads = env->GetIntField(model, fid); | ||
| 37 | + | ||
| 38 | + fid = env->GetFieldID(model_config_cls, "debug", "Z"); | ||
| 39 | + ans.model.debug = env->GetBooleanField(model, fid); | ||
| 40 | + | ||
| 41 | + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); | ||
| 42 | + s = (jstring)env->GetObjectField(model, fid); | ||
| 43 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 44 | + ans.model.provider = p; | ||
| 45 | + env->ReleaseStringUTFChars(s, p); | ||
| 46 | + | ||
| 47 | + return ans; | ||
| 48 | +} | ||
| 49 | + | ||
| 50 | +} // namespace sherpa_onnx | ||
| 51 | + | ||
| 52 | +SHERPA_ONNX_EXTERN_C | ||
| 53 | +JNIEXPORT jlong JNICALL | ||
| 54 | +Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromAsset( | ||
| 55 | + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { | ||
| 56 | +#if __ANDROID_API__ >= 9 | ||
| 57 | + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); | ||
| 58 | + if (!mgr) { | ||
| 59 | + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); | ||
| 60 | + return 0; | ||
| 61 | + } | ||
| 62 | +#endif | ||
| 63 | + auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config); | ||
| 64 | + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
| 65 | + | ||
| 66 | + auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser( | ||
| 67 | +#if __ANDROID_API__ >= 9 | ||
| 68 | + mgr, | ||
| 69 | +#endif | ||
| 70 | + config); | ||
| 71 | + | ||
| 72 | + return (jlong)speech_denoiser; | ||
| 73 | +} | ||
| 74 | + | ||
| 75 | +SHERPA_ONNX_EXTERN_C | ||
| 76 | +JNIEXPORT jlong JNICALL | ||
| 77 | +Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromFile(JNIEnv *env, | ||
| 78 | + jobject /*obj*/, | ||
| 79 | + jobject _config) { | ||
| 80 | + return SafeJNI( | ||
| 81 | + env, "OfflineSpeechDenoiser_newFromFile", | ||
| 82 | + [&]() -> jlong { | ||
| 83 | + auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config); | ||
| 84 | + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); | ||
| 85 | + | ||
| 86 | + if (!config.Validate()) { | ||
| 87 | + SHERPA_ONNX_LOGE("Errors found in config!"); | ||
| 88 | + } | ||
| 89 | + | ||
| 90 | + auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser(config); | ||
| 91 | + return reinterpret_cast<jlong>(speech_denoiser); | ||
| 92 | + }, | ||
| 93 | + 0L); | ||
| 94 | +} | ||
| 95 | + | ||
| 96 | +SHERPA_ONNX_EXTERN_C | ||
| 97 | +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_delete( | ||
| 98 | + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { | ||
| 99 | + delete reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr); | ||
| 100 | +} | ||
| 101 | + | ||
| 102 | +SHERPA_ONNX_EXTERN_C | ||
| 103 | +JNIEXPORT jint JNICALL | ||
| 104 | +Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_getSampleRate(JNIEnv * /*env*/, | ||
| 105 | + jobject /*obj*/, | ||
| 106 | + jlong ptr) { | ||
| 107 | + return reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr) | ||
| 108 | + ->GetSampleRate(); | ||
| 109 | +} | ||
| 110 | + | ||
| 111 | +SHERPA_ONNX_EXTERN_C | ||
| 112 | +JNIEXPORT jobject JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_run( | ||
| 113 | + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, | ||
| 114 | + jint sample_rate) { | ||
| 115 | + auto speech_denoiser = | ||
| 116 | + reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr); | ||
| 117 | + | ||
| 118 | + jfloat *p = env->GetFloatArrayElements(samples, nullptr); | ||
| 119 | + jsize n = env->GetArrayLength(samples); | ||
| 120 | + auto denoised = speech_denoiser->Run(p, n, sample_rate); | ||
| 121 | + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); | ||
| 122 | + | ||
| 123 | + jclass cls = env->FindClass("com/k2fsa/sherpa/onnx/DenoisedAudio"); | ||
| 124 | + if (cls == nullptr) { | ||
| 125 | + SHERPA_ONNX_LOGE("Failed to get class for DenoisedAudio"); | ||
| 126 | + return nullptr; | ||
| 127 | + } | ||
| 128 | + | ||
| 129 | + // https://javap.yawk.at/ | ||
| 130 | + jmethodID constructor = env->GetMethodID(cls, "<init>", "([FI)V"); | ||
| 131 | + if (constructor == nullptr) { | ||
| 132 | + SHERPA_ONNX_LOGE("Failed to get constructor for DenoisedAudio"); | ||
| 133 | + return nullptr; | ||
| 134 | + } | ||
| 135 | + | ||
| 136 | + jfloatArray samples_arr = env->NewFloatArray(denoised.samples.size()); | ||
| 137 | + env->SetFloatArrayRegion(samples_arr, 0, denoised.samples.size(), | ||
| 138 | + denoised.samples.data()); | ||
| 139 | + | ||
| 140 | + return env->NewObject(cls, constructor, samples_arr, denoised.sample_rate); | ||
| 141 | +} | ||
| 142 | + | ||
| 143 | +SHERPA_ONNX_EXTERN_C | ||
| 144 | +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_DenoisedAudio_saveImpl( | ||
| 145 | + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, | ||
| 146 | + jint sample_rate) { | ||
| 147 | + const char *p_filename = env->GetStringUTFChars(filename, nullptr); | ||
| 148 | + | ||
| 149 | + jfloat *p = env->GetFloatArrayElements(samples, nullptr); | ||
| 150 | + jsize n = env->GetArrayLength(samples); | ||
| 151 | + | ||
| 152 | + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n); | ||
| 153 | + | ||
| 154 | + env->ReleaseStringUTFChars(filename, p_filename); | ||
| 155 | + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); | ||
| 156 | + | ||
| 157 | + return ok; | ||
| 158 | +} |
| @@ -5,6 +5,7 @@ | @@ -5,6 +5,7 @@ | ||
| 5 | #include "sherpa-onnx/csrc/offline-tts.h" | 5 | #include "sherpa-onnx/csrc/offline-tts.h" |
| 6 | 6 | ||
| 7 | #include "sherpa-onnx/csrc/macros.h" | 7 | #include "sherpa-onnx/csrc/macros.h" |
| 8 | +#include "sherpa-onnx/csrc/wave-writer.h" | ||
| 8 | #include "sherpa-onnx/jni/common.h" | 9 | #include "sherpa-onnx/jni/common.h" |
| 9 | 10 | ||
| 10 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| @@ -340,3 +341,20 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( | @@ -340,3 +341,20 @@ Java_com_k2fsa_sherpa_onnx_OfflineTts_generateWithCallbackImpl( | ||
| 340 | 341 | ||
| 341 | return obj_arr; | 342 | return obj_arr; |
| 342 | } | 343 | } |
| 344 | + | ||
| 345 | +SHERPA_ONNX_EXTERN_C | ||
| 346 | +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_GeneratedAudio_saveImpl( | ||
| 347 | + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, | ||
| 348 | + jint sample_rate) { | ||
| 349 | + const char *p_filename = env->GetStringUTFChars(filename, nullptr); | ||
| 350 | + | ||
| 351 | + jfloat *p = env->GetFloatArrayElements(samples, nullptr); | ||
| 352 | + jsize n = env->GetArrayLength(samples); | ||
| 353 | + | ||
| 354 | + bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n); | ||
| 355 | + | ||
| 356 | + env->ReleaseStringUTFChars(filename, p_filename); | ||
| 357 | + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); | ||
| 358 | + | ||
| 359 | + return ok; | ||
| 360 | +} |
| 1 | +package com.k2fsa.sherpa.onnx | ||
| 2 | + | ||
| 3 | +import android.content.res.AssetManager | ||
| 4 | + | ||
/**
 * Settings for a GTCRN speech-enhancement model.
 *
 * @property model Path of the GTCRN model file, e.g. `./gtcrn_simple.onnx`.
 *   Defaults to an empty string.
 */
data class OfflineSpeechDenoiserGtcrnModelConfig(var model: String = "")
| 8 | + | ||
/**
 * Model-level settings of the speech denoiser.
 *
 * @property gtcrn Settings of the GTCRN model.
 * @property numThreads Number of threads; defaults to 1.
 * @property debug Debug flag; defaults to `false`.
 * @property provider Name of the execution provider; defaults to `"cpu"`.
 */
data class OfflineSpeechDenoiserModelConfig(
    var gtcrn: OfflineSpeechDenoiserGtcrnModelConfig =
        OfflineSpeechDenoiserGtcrnModelConfig(),
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
| 15 | + | ||
/**
 * Top-level configuration passed to [OfflineSpeechDenoiser].
 *
 * @property model Model-level settings; defaults to a default-constructed
 *   [OfflineSpeechDenoiserModelConfig].
 */
data class OfflineSpeechDenoiserConfig(
    var model: OfflineSpeechDenoiserModelConfig =
        OfflineSpeechDenoiserModelConfig(),
)
| 19 | + | ||
/**
 * Audio produced by [OfflineSpeechDenoiser.run].
 *
 * @property samples The audio samples.
 * @property sampleRate Sample rate of [samples].
 */
class DenoisedAudio(
    val samples: FloatArray,
    val sampleRate: Int,
) {
    /**
     * Writes [samples] to [filename] as a wave file through native code.
     *
     * @return the status reported by the native writer.
     */
    fun save(filename: String): Boolean {
        return saveImpl(
            filename = filename,
            samples = samples,
            sampleRate = sampleRate,
        )
    }

    private external fun saveImpl(
        filename: String,
        samples: FloatArray,
        sampleRate: Int,
    ): Boolean
}
| 33 | + | ||
/**
 * Kotlin wrapper around the native sherpa-onnx speech denoiser.
 *
 * The native object is created in the constructor: through [assetManager]
 * when one is supplied, otherwise from [config] alone. Call [release] when the
 * denoiser is no longer needed; the native object is also freed from
 * `finalize()`.
 *
 * Using [run] or [sampleRate] on an instance without a native object (after
 * [release], or when creation returned a zero handle) throws
 * [IllegalStateException] instead of handing a zero handle to native code.
 *
 * @param assetManager Asset manager used to create the native object, or
 *   `null` to create it from [config] alone.
 * @param config Denoiser configuration.
 */
class OfflineSpeechDenoiser(
    assetManager: AssetManager? = null,
    config: OfflineSpeechDenoiserConfig,
) {
    // Address of the native denoiser; 0 means there is no native object.
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    protected fun finalize() {
        // Guarded so that an explicit release() followed by finalization frees
        // the native object only once.
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native denoiser. Calling it more than once is harmless. */
    fun release() = finalize()

    /**
     * Denoises [samples], which were recorded at [sampleRate].
     *
     * @throws IllegalStateException if there is no native denoiser.
     */
    fun run(samples: FloatArray, sampleRate: Int): DenoisedAudio =
        run(nativeHandle(), samples, sampleRate)

    /**
     * Sample rate reported by the native denoiser.
     *
     * @throws IllegalStateException if there is no native denoiser.
     */
    val sampleRate: Int
        get() = getSampleRate(nativeHandle())

    /** Returns the native handle after verifying that it is non-zero. */
    private fun nativeHandle(): Long {
        check(ptr != 0L) {
            "OfflineSpeechDenoiser has been released or failed to initialize"
        }
        return ptr
    }

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflineSpeechDenoiserConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineSpeechDenoiserConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun run(ptr: Long, samples: FloatArray, sampleRate: Int): DenoisedAudio

    private external fun getSampleRate(ptr: Long): Int

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
-
请 注册 或 登录 后发表评论