offline-speech-denoiser.cc
5.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// sherpa-onnx/jni/offline-speech-denoiser.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {

// Builds a native OfflineSpeechDenoiserConfig from the Java config object.
//
// Fields read from the Java side:
//   config.model                 (OfflineSpeechDenoiserModelConfig)
//   config.model.gtcrn           (OfflineSpeechDenoiserGtcrnModelConfig)
//   config.model.gtcrn.model     (String)
//   config.model.numThreads      (int)
//   config.model.debug           (boolean)
//   config.model.provider        (String)
//
// @param env     JNI environment of the calling thread.
// @param config  The Java config object to read from.
// @return The populated native config. Any field whose Java value is null
//         (or whose UTF-8 chars cannot be obtained) keeps the value it has
//         in a default-constructed OfflineSpeechDenoiserConfig.
static OfflineSpeechDenoiserConfig GetOfflineSpeechDenoiserConfig(
    JNIEnv *env, jobject config) {
  OfflineSpeechDenoiserConfig ans;

  // Copies the Java String stored in `field` of `owner` into *dst.
  // A null Java string, or a null result from GetStringUTFChars, leaves *dst
  // untouched; assigning a null `const char *` to a string is undefined.
  auto copy_string_field = [env](jobject owner, jfieldID field, auto *dst) {
    auto java_str = static_cast<jstring>(env->GetObjectField(owner, field));
    if (java_str == nullptr) {
      return;
    }

    const char *utf = env->GetStringUTFChars(java_str, nullptr);
    if (utf == nullptr) {
      return;
    }

    *dst = utf;
    env->ReleaseStringUTFChars(java_str, utf);
  };

  jclass cls = env->GetObjectClass(config);

  jfieldID fid = env->GetFieldID(
      cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserModelConfig;");
  jobject model = env->GetObjectField(config, fid);
  if (model == nullptr) {
    // Nothing else can be read without the nested model config.
    SHERPA_ONNX_LOGE("The 'model' field of the denoiser config is null");
    return ans;
  }
  jclass model_config_cls = env->GetObjectClass(model);

  fid = env->GetFieldID(
      model_config_cls, "gtcrn",
      "Lcom/k2fsa/sherpa/onnx/OfflineSpeechDenoiserGtcrnModelConfig;");
  jobject gtcrn = env->GetObjectField(model, fid);
  if (gtcrn != nullptr) {
    jclass gtcrn_cls = env->GetObjectClass(gtcrn);
    fid = env->GetFieldID(gtcrn_cls, "model", "Ljava/lang/String;");
    copy_string_field(gtcrn, fid, &ans.model.gtcrn.model);
  } else {
    SHERPA_ONNX_LOGE("The 'gtcrn' field of the denoiser model config is null");
  }

  fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  ans.model.num_threads = env->GetIntField(model, fid);

  fid = env->GetFieldID(model_config_cls, "debug", "Z");
  ans.model.debug = env->GetBooleanField(model, fid);

  fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
  copy_string_field(model, fid, &ans.model.provider);

  return ans;
}

}  // namespace sherpa_onnx
// JNI entry point: creates an OfflineSpeechDenoiser for the given Java config.
// When __ANDROID_API__ >= 9 the Java AssetManager is converted to a native
// AAssetManager and passed to the constructor; otherwise asset_manager is
// unused.
//
// @param env            JNI environment of the calling thread.
// @param asset_manager  Java AssetManager object.
// @param _config        Java denoiser config object.
// @return The address of the new OfflineSpeechDenoiser as a jlong, or 0 if
//         the native asset manager could not be obtained. The object is
//         released by the matching `delete` entry point.
//
// NOTE(review): config.Validate() is intentionally not called here, matching
// the original code; presumably validation targets filesystem paths rather
// than asset paths -- confirm.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  // The lambda is defined before the call so that no preprocessor directive
  // appears inside the argument list of SafeJNI.
  auto create = [&]() -> jlong {
#if __ANDROID_API__ >= 9
    AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
    if (!mgr) {
      SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
      return 0;
    }
#endif

    auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config);
    SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

    auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser(
#if __ANDROID_API__ >= 9
        mgr,
#endif
        config);

    return reinterpret_cast<jlong>(speech_denoiser);
  };

  // Same wrapper and fallback value as newFromFile, so both constructors
  // share one error-handling path. SafeJNI is declared outside this file;
  // presumably it maps failures inside the callable to the fallback -- confirm.
  return SafeJNI(env, "OfflineSpeechDenoiser_newFromAsset", create, 0L);
}
// JNI entry point: creates an OfflineSpeechDenoiser from the given Java
// config.
//
// @param env      JNI environment of the calling thread.
// @param _config  Java denoiser config object.
// @return The address of the new OfflineSpeechDenoiser as a jlong, or 0 when
//         the config fails validation (0 is also the fallback value handed
//         to SafeJNI). The object is released by the matching `delete` entry
//         point.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_newFromFile(JNIEnv *env,
                                                             jobject /*obj*/,
                                                             jobject _config) {
  return SafeJNI(
      env, "OfflineSpeechDenoiser_newFromFile",
      [&]() -> jlong {
        auto config = sherpa_onnx::GetOfflineSpeechDenoiserConfig(env, _config);
        SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

        if (!config.Validate()) {
          SHERPA_ONNX_LOGE("Errors found in config!");
          // Do not build a denoiser from a config that is known to be
          // invalid; report failure to the Java side instead.
          return 0;
        }

        auto speech_denoiser = new sherpa_onnx::OfflineSpeechDenoiser(config);
        return reinterpret_cast<jlong>(speech_denoiser);
      },
      0L);
}
// JNI entry point: destroys the OfflineSpeechDenoiser whose address is
// stored in `ptr`.
//
// @param ptr  Address of the denoiser, as produced by the newFromAsset /
//             newFromFile entry points in this file.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_delete(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *denoiser = reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);
  delete denoiser;
}
// JNI entry point: forwards to OfflineSpeechDenoiser::GetSampleRate().
//
// @param ptr  Address of the denoiser; it is dereferenced without a check.
// @return Whatever GetSampleRate() reports for that denoiser.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_getSampleRate(JNIEnv * /*env*/,
                                                               jobject /*obj*/,
                                                               jlong ptr) {
  auto *denoiser = reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);
  return denoiser->GetSampleRate();
}
// JNI entry point: runs the denoiser on `samples` and wraps the result in a
// new com.k2fsa.sherpa.onnx.DenoisedAudio Java object.
//
// @param env          JNI environment of the calling thread.
// @param ptr          Address of the OfflineSpeechDenoiser to use.
// @param samples      Input audio samples; read only, never written back.
// @param sample_rate  Sample rate passed through to Run().
// @return A new DenoisedAudio(float[], int) holding the denoised samples and
//         their sample rate, or nullptr if an argument is null or any JNI
//         lookup/allocation fails.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobject JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeechDenoiser_run(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jint sample_rate) {
  if (ptr == 0 || samples == nullptr) {
    SHERPA_ONNX_LOGE("Invalid arguments: null denoiser handle or samples");
    return nullptr;
  }

  auto speech_denoiser =
      reinterpret_cast<sherpa_onnx::OfflineSpeechDenoiser *>(ptr);

  jsize n = env->GetArrayLength(samples);
  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  if (p == nullptr) {
    // GetFloatArrayElements returns null when it cannot provide the elements.
    SHERPA_ONNX_LOGE("Failed to access the input samples");
    return nullptr;
  }

  auto denoised = speech_denoiser->Run(p, n, sample_rate);

  // JNI_ABORT: the input was only read, so nothing is copied back.
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);

  jclass cls = env->FindClass("com/k2fsa/sherpa/onnx/DenoisedAudio");
  if (cls == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get class for DenoisedAudio");
    return nullptr;
  }

  // https://javap.yawk.at/
  jmethodID constructor = env->GetMethodID(cls, "<init>", "([FI)V");
  if (constructor == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get constructor for DenoisedAudio");
    return nullptr;
  }

  const auto num_denoised = static_cast<jsize>(denoised.samples.size());
  jfloatArray samples_arr = env->NewFloatArray(num_denoised);
  if (samples_arr == nullptr) {
    // NewFloatArray returns null when the array cannot be constructed.
    SHERPA_ONNX_LOGE("Failed to allocate the output sample array");
    return nullptr;
  }

  env->SetFloatArrayRegion(samples_arr, 0, num_denoised,
                           denoised.samples.data());

  return env->NewObject(cls, constructor, samples_arr, denoised.sample_rate);
}
// JNI entry point: writes `samples` to `filename` via sherpa_onnx::WriteWave.
//
// @param env          JNI environment of the calling thread.
// @param filename     Destination path as a Java string.
// @param samples      Audio samples to write; read only.
// @param sample_rate  Sample rate passed through to WriteWave().
// @return JNI_TRUE if WriteWave() reports success; JNI_FALSE if it reports
//         failure, if an argument is null, or if the string/array contents
//         cannot be obtained from the JVM.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_DenoisedAudio_saveImpl(
    JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples,
    jint sample_rate) {
  if (filename == nullptr || samples == nullptr) {
    SHERPA_ONNX_LOGE("Invalid arguments: null filename or samples");
    return JNI_FALSE;
  }

  const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  if (p_filename == nullptr) {
    SHERPA_ONNX_LOGE("Failed to read the filename");
    return JNI_FALSE;
  }

  jsize n = env->GetArrayLength(samples);
  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  if (p == nullptr) {
    SHERPA_ONNX_LOGE("Failed to access the samples");
    // The filename chars were already acquired; release them before leaving.
    env->ReleaseStringUTFChars(filename, p_filename);
    return JNI_FALSE;
  }

  bool ok = sherpa_onnx::WriteWave(p_filename, sample_rate, p, n);

  env->ReleaseStringUTFChars(filename, p_filename);
  // JNI_ABORT: the samples were only read, so nothing is copied back.
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);

  return ok ? JNI_TRUE : JNI_FALSE;
}