jni.cc
11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
// sherpa-onnx/jni/jni.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
// 2022 Pingfeng Luo
// TODO(fangjun): Add documentation to functions/methods in this file
// and also show how to use them with kotlin, possibly with java.
// If you use ndk, you can find "jni.h" inside
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
#include "jni.h" // NOLINT
#include <strstream>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#else
#include <fstream>
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#define SHERPA_ONNX_EXTERN_C extern "C"
namespace sherpa_onnx {
// Pairs one OnlineRecognizer with the single OnlineStream created from it so
// that the JNI functions in this file can drive both through one pointer.
class SherpaOnnx {
 public:
  // Constructs the recognizer from `config` and creates its stream.
  // When built with __ANDROID_API__ >= 9, `mgr` is forwarded to the
  // recognizer's constructor.
  SherpaOnnx(
#if __ANDROID_API__ >= 9
      AAssetManager *mgr,
#endif
      const sherpa_onnx::OnlineRecognizerConfig &config)
      : recognizer_(
#if __ANDROID_API__ >= 9
            mgr,
#endif
            config),
        stream_(recognizer_.CreateStream()) {
  }

  // Forwards `n` samples to the stream. The sample rate of the very first
  // call is remembered; InputFinished() uses it to size the tail padding.
  void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
    if (input_sample_rate_ == -1) {
      input_sample_rate_ = sample_rate;
    }

    stream_->AcceptWaveform(sample_rate, samples, n);
  }

  // Appends 0.32 * input_sample_rate_ zero samples to the stream and then
  // marks the stream's input as finished.
  //
  // If no waveform has been accepted yet, input_sample_rate_ still holds the
  // -1 sentinel; in that case the padding is skipped so that an invalid
  // sample rate is never passed to the stream.
  void InputFinished() const {
    if (input_sample_rate_ > 0) {
      std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0);
      stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
                              tail_padding.size());
    }

    stream_->InputFinished();
  }

  // Returns the `text` member of the recognizer's current result for the
  // stream.
  const std::string GetText() const {
    auto result = recognizer_.GetResult(stream_.get());
    return result.text;
  }

  // The methods below forward to the recognizer method of the matching name,
  // passing the wrapped stream.
  bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }

  bool IsReady() const { return recognizer_.IsReady(stream_.get()); }

  void Reset() const { recognizer_.Reset(stream_.get()); }

  void Decode() const { recognizer_.DecodeStream(stream_.get()); }

 private:
  sherpa_onnx::OnlineRecognizer recognizer_;
  std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
  // Sample rate passed to the first AcceptWaveform() call; -1 until then.
  int32_t input_sample_rate_ = -1;
};
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
//---------- decoding ----------
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.decoding_method = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
//---------- enable endpoint ----------
fid = env->GetFieldID(cls, "enableEndpoint", "Z");
ans.enable_endpoint = env->GetBooleanField(config, fid);
//---------- endpoint_config ----------
fid = env->GetFieldID(cls, "endpointConfig",
"Lcom/k2fsa/sherpa/onnx/EndpointConfig;");
jobject endpoint_config = env->GetObjectField(config, fid);
jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
fid = env->GetFieldID(endpoint_config_cls, "rule1",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule1 = env->GetObjectField(endpoint_config, fid);
jclass rule_class = env->GetObjectClass(rule1);
fid = env->GetFieldID(endpoint_config_cls, "rule2",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule2 = env->GetObjectField(endpoint_config, fid);
fid = env->GetFieldID(endpoint_config_cls, "rule3",
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
jobject rule3 = env->GetObjectField(endpoint_config, fid);
fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
ans.endpoint_config.rule1.must_contain_nonsilence =
env->GetBooleanField(rule1, fid);
ans.endpoint_config.rule2.must_contain_nonsilence =
env->GetBooleanField(rule2, fid);
ans.endpoint_config.rule3.must_contain_nonsilence =
env->GetBooleanField(rule3, fid);
fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
ans.endpoint_config.rule1.min_trailing_silence =
env->GetFloatField(rule1, fid);
ans.endpoint_config.rule2.min_trailing_silence =
env->GetFloatField(rule2, fid);
ans.endpoint_config.rule3.min_trailing_silence =
env->GetFloatField(rule3, fid);
fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
ans.endpoint_config.rule1.min_utterance_length =
env->GetFloatField(rule1, fid);
ans.endpoint_config.rule2.min_utterance_length =
env->GetFloatField(rule2, fid);
ans.endpoint_config.rule3.min_utterance_length =
env->GetFloatField(rule3, fid);
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.encoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.decoder_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.joiner_filename = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
return ans;
}
} // namespace sherpa_onnx
// JNI entry point for SherpaOnnx.new: converts the Java config object into a
// native config, logs it, allocates a sherpa_onnx::SherpaOnnx and returns its
// address as a jlong handle. The handle is released by the matching
// SherpaOnnx.delete entry point.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (mgr == nullptr) {
    // Only logged; construction below still proceeds with the null manager.
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  }
#endif

  auto recognizer_config = sherpa_onnx::GetConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", recognizer_config.ToString().c_str());

  auto *handle = new sherpa_onnx::SherpaOnnx(
#if __ANDROID_API__ >= 9
      mgr,
#endif
      recognizer_config);

  return reinterpret_cast<jlong>(handle);
}
// JNI entry point for SherpaOnnx.delete: destroys the native object whose
// address is stored in `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *handle = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  delete handle;
}
// JNI entry point for SherpaOnnx.reset: calls Reset() on the native object
// addressed by `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->Reset();
}
// JNI entry point for SherpaOnnx.isReady: reports the result of IsReady() on
// the native object addressed by `ptr`.
//
// The return type is jboolean, which is the JNI type corresponding to a Java
// `boolean` result, rather than the C++ `bool` type.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  return model->IsReady() ? JNI_TRUE : JNI_FALSE;
}
// JNI entry point for SherpaOnnx.isEndpoint: reports the result of
// IsEndpoint() on the native object addressed by `ptr`.
//
// The return type is jboolean, which is the JNI type corresponding to a Java
// `boolean` result, rather than the C++ `bool` type.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  return model->IsEndpoint() ? JNI_TRUE : JNI_FALSE;
}
// JNI entry point for SherpaOnnx.decode: calls Decode() on the native object
// addressed by `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->Decode();
}
// JNI entry point for SherpaOnnx.acceptWaveform: passes the contents of the
// Java float array `samples`, together with `sample_rate`, to
// AcceptWaveform() on the native object addressed by `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jint sample_rate) {
  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);

  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  if (p == nullptr) {
    // GetFloatArrayElements returns nullptr on failure; there is nothing to
    // feed and nothing to release.
    SHERPA_ONNX_LOGE("Failed to get float array elements");
    return;
  }

  jsize n = env->GetArrayLength(samples);
  model->AcceptWaveform(sample_rate, p, n);

  // JNI_ABORT: the samples were only read, so no copy-back is requested.
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
// JNI entry point for SherpaOnnx.inputFinished: calls InputFinished() on the
// native object addressed by `ptr`.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *handle = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  handle->InputFinished();
}
// JNI entry point for SherpaOnnx.getText: returns the text from GetText() on
// the native object addressed by `ptr` as a new Java string.
//
// see
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *handle = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
  auto recognized = handle->GetText();
  return env->NewStringUTF(recognized.c_str());
}
// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
// Boxes `value` into a new java.lang.Integer by invoking its (int)
// constructor, and returns the resulting object.
static jobject NewInteger(JNIEnv *env, int32_t value) {
  jclass integer_cls = env->FindClass("java/lang/Integer");
  jmethodID ctor = env->GetMethodID(integer_cls, "<init>", "(I)V");
  jobject boxed = env->NewObject(integer_cls, ctor, value);
  return boxed;
}
// JNI entry point for WaveReader.Companion.readWave: reads the wave file
// `filename` and returns a two-element Object[]:
//   [0] a float[] holding the samples
//   [1] a java.lang.Integer holding the sampling rate
//
// When built with __ANDROID_API__ >= 9 the file is read through
// `asset_manager`; otherwise it is opened with std::ifstream.
//
// The process exits with -1 if the asset manager cannot be obtained or the
// file cannot be read. nullptr is returned if the characters of `filename`
// cannot be obtained from the JVM.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
    JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) {
  const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  if (p_filename == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get the filename");
    return nullptr;
  }

  // Keep an owned copy so the JNI buffer can be released immediately and the
  // path stays valid for the error message below.
  const std::string path(p_filename);
  env->ReleaseStringUTFChars(filename, p_filename);

#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (!mgr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    exit(-1);
  }

  std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, path.c_str());

  std::istrstream is(buffer.data(), buffer.size());
#else
  std::ifstream is(path.c_str(), std::ios::binary);
#endif

  bool is_ok = false;
  int32_t sampling_rate = -1;
  std::vector<float> samples =
      sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok);

  if (!is_ok) {
    SHERPA_ONNX_LOGE("Failed to read %s", path.c_str());
    exit(-1);
  }

  jfloatArray ans = env->NewFloatArray(samples.size());
  env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data());

  jobjectArray obj_arr = static_cast<jobjectArray>(env->NewObjectArray(
      2, env->FindClass("java/lang/Object"), nullptr));

  env->SetObjectArrayElement(obj_arr, 0, ans);
  env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate));

  return obj_arr;
}