online-punctuation.cc
3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// sherpa-onnx/jni/online-punctuation.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-punctuation.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {

// Builds a C++ OnlinePunctuationConfig from the Java/Kotlin config object.
//
// The Java object is expected to expose a `model` field of type
// com.k2fsa.sherpa.onnx.OnlinePunctuationModelConfig, from which the
// following fields are copied:
//   cnnBilstm (String), bpeVocab (String), numThreads (int),
//   debug (boolean), provider (String)
//
// @param env     JNI environment of the calling thread.
// @param config  The Java config object; may be null.
// @return The populated config. If `config` or its `model` field is null,
//         the default-constructed config is returned unchanged. A String
//         field that is null (or whose characters cannot be obtained) leaves
//         the corresponding C++ member at its default value.
static OnlinePunctuationConfig GetOnlinePunctuationConfig(JNIEnv *env,
                                                          jobject config) {
  OnlinePunctuationConfig ans;

  // GetObjectClass() must not be called with a null reference.
  if (config == nullptr) {
    return ans;
  }

  jclass cls = env->GetObjectClass(config);

  jfieldID fid =
      env->GetFieldID(cls, "model",
                      "Lcom/k2fsa/sherpa/onnx/OnlinePunctuationModelConfig;");
  jobject model_config = env->GetObjectField(config, fid);
  if (model_config == nullptr) {
    return ans;
  }

  jclass model_config_cls = env->GetObjectClass(model_config);

  // Copies the Java String field `field_name` of the model config into
  // `*dst`. `*dst` is left untouched when the field is null or when
  // GetStringUTFChars() fails, since a std::string must never be assigned
  // from a null pointer.
  const auto read_string = [env, model_config, model_config_cls](
                               const char *field_name, std::string *dst) {
    jfieldID string_fid =
        env->GetFieldID(model_config_cls, field_name, "Ljava/lang/String;");
    auto value =
        static_cast<jstring>(env->GetObjectField(model_config, string_fid));
    if (value == nullptr) {
      return;
    }

    const char *chars = env->GetStringUTFChars(value, nullptr);
    if (chars == nullptr) {
      return;
    }

    *dst = chars;
    env->ReleaseStringUTFChars(value, chars);
  };

  read_string("cnnBilstm", &ans.model.cnn_bilstm);
  read_string("bpeVocab", &ans.model.bpe_vocab);

  fid = env->GetFieldID(model_config_cls, "numThreads", "I");
  ans.model.num_threads = env->GetIntField(model_config, fid);

  fid = env->GetFieldID(model_config_cls, "debug", "Z");
  ans.model.debug = env->GetBooleanField(model_config, fid);

  read_string("provider", &ans.model.provider);

  return ans;
}

}  // namespace sherpa_onnx
// JNI entry point: OnlinePunctuation.newFromAsset(assetManager, config).
//
// Converts the Java config, logs it, and allocates a native
// OnlinePunctuation. When __ANDROID_API__ >= 9 the native AAssetManager
// obtained from `asset_manager` is forwarded to the constructor; if it cannot
// be obtained, 0 is returned.
//
// @return The address of the newly allocated object as a jlong handle, or 0
//         on failure to get the asset manager. The handle is released by the
//         matching `delete` entry point.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_newFromAsset(JNIEnv *env,
                                                          jobject /*obj*/,
                                                          jobject asset_manager,
                                                          jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (mgr == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    return 0;
  }
#endif

  auto punct_config = sherpa_onnx::GetOnlinePunctuationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", punct_config.ToString().c_str());

#if __ANDROID_API__ >= 9
  auto *handle = new sherpa_onnx::OnlinePunctuation(mgr, punct_config);
#else
  auto *handle = new sherpa_onnx::OnlinePunctuation(punct_config);
#endif

  return reinterpret_cast<jlong>(handle);
}
// JNI entry point: OnlinePunctuation.newFromFile(config).
//
// Converts the Java config, logs it, and validates it before allocating a
// native OnlinePunctuation.
//
// @return The address of the newly allocated object as a jlong handle, or 0
//         when config.Validate() reports errors. The handle is released by
//         the matching `delete` entry point.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_newFromFile(JNIEnv *env,
                                                         jobject /*obj*/,
                                                         jobject _config) {
  auto punct_config = sherpa_onnx::GetOnlinePunctuationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", punct_config.ToString().c_str());

  if (punct_config.Validate()) {
    auto *handle = new sherpa_onnx::OnlinePunctuation(punct_config);
    return reinterpret_cast<jlong>(handle);
  }

  SHERPA_ONNX_LOGE("Errors found in config!");
  return 0;
}
// JNI entry point: OnlinePunctuation.delete(ptr).
//
// Destroys the native OnlinePunctuation whose address is stored in `ptr`.
// A `ptr` of 0 is harmless because deleting a null pointer is a no-op.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_delete(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *punct = reinterpret_cast<sherpa_onnx::OnlinePunctuation *>(ptr);
  delete punct;
}
// JNI entry point: OnlinePunctuation.addPunctuation(ptr, text).
//
// Runs AddPunctuationWithCase() of the native OnlinePunctuation referenced by
// `ptr` on `text` and returns the result as a new Java string.
//
// @param env   JNI environment of the calling thread.
// @param ptr   Native handle, i.e. the address of an OnlinePunctuation.
// @param text  The input text.
// @return A new Java string holding the processed text, or null with a
//         pending Java exception when `ptr` is 0, `text` is null, or the
//         characters of `text` cannot be obtained.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_addPunctuation(JNIEnv *env,
                                                            jobject /*obj*/,
                                                            jlong ptr,
                                                            jstring text) {
  // The newFromAsset/newFromFile entry points return 0 on failure, so a zero
  // handle can reach this function; dereferencing it would crash the process.
  if (ptr == 0) {
    jclass ex = env->FindClass("java/lang/IllegalStateException");
    if (ex != nullptr) {
      env->ThrowNew(ex, "OnlinePunctuation: native handle is null");
    }
    return nullptr;
  }

  // GetStringUTFChars() must not be called with a null string reference.
  if (text == nullptr) {
    jclass ex = env->FindClass("java/lang/NullPointerException");
    if (ex != nullptr) {
      env->ThrowNew(ex, "OnlinePunctuation: text is null");
    }
    return nullptr;
  }

  auto punct = reinterpret_cast<const sherpa_onnx::OnlinePunctuation *>(ptr);

  const char *ptext = env->GetStringUTFChars(text, nullptr);
  if (ptext == nullptr) {
    // The JNI call failed; bail out instead of building a std::string from a
    // null pointer.
    return nullptr;
  }

  const std::string result = punct->AddPunctuationWithCase(ptext);
  env->ReleaseStringUTFChars(text, ptext);

  return env->NewStringUTF(result.c_str());
}