ivan provalov
Committed by GitHub

JNI Exception Handling (#1452)

@@ -5,6 +5,8 @@ @@ -5,6 +5,8 @@
5 #ifndef SHERPA_ONNX_JNI_COMMON_H_ 5 #ifndef SHERPA_ONNX_JNI_COMMON_H_
6 #define SHERPA_ONNX_JNI_COMMON_H_ 6 #define SHERPA_ONNX_JNI_COMMON_H_
7 7
  8 +#include <string>
  9 +
8 #if __ANDROID_API__ >= 9 10 #if __ANDROID_API__ >= 9
9 #include <strstream> 11 #include <strstream>
10 12
@@ -42,4 +44,62 @@ @@ -42,4 +44,62 @@
42 jobject NewInteger(JNIEnv *env, int32_t value); 44 jobject NewInteger(JNIEnv *env, int32_t value);
43 jobject NewFloat(JNIEnv *env, float value); 45 jobject NewFloat(JNIEnv *env, float value);
44 46
  47 +// Template function for non-void return types
  48 +template <typename Func, typename ReturnType>
  49 +ReturnType SafeJNI(JNIEnv *env, const char *functionName, Func func,
  50 + ReturnType defaultValue) {
  51 + try {
  52 + return func();
  53 + } catch (const std::exception &e) {
  54 + jclass exClass = env->FindClass("java/lang/RuntimeException");
  55 + if (exClass != nullptr) {
  56 + std::string errorMessage = std::string(functionName) + ": " + e.what();
  57 + env->ThrowNew(exClass, errorMessage.c_str());
  58 + }
  59 + } catch (...) {
  60 + jclass exClass = env->FindClass("java/lang/RuntimeException");
  61 + if (exClass != nullptr) {
  62 + std::string errorMessage = std::string(functionName) +
  63 + ": Native exception: caught unknown exception";
  64 + env->ThrowNew(exClass, errorMessage.c_str());
  65 + }
  66 + }
  67 + return defaultValue;
  68 +}
  69 +
  70 +// Specialization for void return type
  71 +template <typename Func>
  72 +void SafeJNI(JNIEnv *env, const char *functionName, Func func) {
  73 + try {
  74 + func();
  75 + } catch (const std::exception &e) {
  76 + jclass exClass = env->FindClass("java/lang/RuntimeException");
  77 + if (exClass != nullptr) {
  78 + std::string errorMessage = std::string(functionName) + ": " + e.what();
  79 + env->ThrowNew(exClass, errorMessage.c_str());
  80 + }
  81 + } catch (...) {
  82 + jclass exClass = env->FindClass("java/lang/RuntimeException");
  83 + if (exClass != nullptr) {
  84 + std::string errorMessage = std::string(functionName) +
  85 + ": Native exception: caught unknown exception";
  86 + env->ThrowNew(exClass, errorMessage.c_str());
  87 + }
  88 + }
  89 +}
  90 +
  91 +// Helper function to validate JNI pointers
  92 +inline bool ValidatePointer(JNIEnv *env, jlong ptr,
  93 + const char *functionName, const char *message) {
  94 + if (ptr == 0) {
  95 + jclass exClass = env->FindClass("java/lang/NullPointerException");
  96 + if (exClass != nullptr) {
  97 + std::string errorMessage = std::string(functionName) + ": " + message;
  98 + env->ThrowNew(exClass, errorMessage.c_str());
  99 + }
  100 + return false;
  101 + }
  102 + return true;
  103 +}
  104 +
45 #endif // SHERPA_ONNX_JNI_COMMON_H_ 105 #endif // SHERPA_ONNX_JNI_COMMON_H_
@@ -353,11 +353,19 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_createStream(JNIEnv * /*env*/, @@ -353,11 +353,19 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_createStream(JNIEnv * /*env*/,
353 353
354 SHERPA_ONNX_EXTERN_C 354 SHERPA_ONNX_EXTERN_C
355 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_decode( 355 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_decode(
356 - JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong streamPtr) {  
357 - auto recognizer = reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);  
358 - auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);  
359 -  
360 - recognizer->DecodeStream(stream); 356 + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr) {
  357 + SafeJNI(env, "OfflineRecognizer_decode", [&] {
  358 + if (!ValidatePointer(env, ptr, "OfflineRecognizer_decode",
  359 + "OfflineRecognizer pointer is null.") ||
  360 + !ValidatePointer(env, streamPtr, "OfflineRecognizer_decode",
  361 + "OfflineStream pointer is null.")) {
  362 + return;
  363 + }
  364 +
  365 + auto recognizer = reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
  366 + auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
  367 + recognizer->DecodeStream(stream);
  368 + });
361 } 369 }
362 370
363 SHERPA_ONNX_EXTERN_C 371 SHERPA_ONNX_EXTERN_C
@@ -220,16 +220,17 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset( @@ -220,16 +220,17 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset(
220 SHERPA_ONNX_EXTERN_C 220 SHERPA_ONNX_EXTERN_C
221 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile( 221 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile(
222 JNIEnv *env, jobject /*obj*/, jobject _config) { 222 JNIEnv *env, jobject /*obj*/, jobject _config) {
223 - auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);  
224 - SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());  
225 -  
226 - if (!config.Validate()) {  
227 - SHERPA_ONNX_LOGE("Errors found in config!");  
228 - } 223 + return SafeJNI(env, "OfflineTts_newFromFile", [&] -> jlong {
  224 + auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
  225 + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
229 226
230 - auto tts = new sherpa_onnx::OfflineTts(config); 227 + if (!config.Validate()) {
  228 + SHERPA_ONNX_LOGE("Errors found in config!");
  229 + }
231 230
232 - return (jlong)tts; 231 + auto tts = new sherpa_onnx::OfflineTts(config);
  232 + return reinterpret_cast<jlong>(tts);
  233 + }, 0L);
233 } 234 }
234 235
235 SHERPA_ONNX_EXTERN_C 236 SHERPA_ONNX_EXTERN_C
@@ -112,14 +112,20 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv * /*env*/, @@ -112,14 +112,20 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv * /*env*/,
112 SHERPA_ONNX_EXTERN_C 112 SHERPA_ONNX_EXTERN_C
113 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform( 113 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform(
114 JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) { 114 JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
115 - auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr); 115 + SafeJNI(env, "Vad_acceptWaveform", [&] {
  116 + if (!ValidatePointer(env, ptr, "Vad_acceptWaveform",
  117 + "VoiceActivityDetector pointer is null.")) {
  118 + return;
  119 + }
116 120
117 - jfloat *p = env->GetFloatArrayElements(samples, nullptr);  
118 - jsize n = env->GetArrayLength(samples); 121 + auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  122 + jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  123 + jsize n = env->GetArrayLength(samples);
119 124
120 - model->AcceptWaveform(p, n); 125 + model->AcceptWaveform(p, n);
121 126
122 - env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); 127 + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
  128 + });
123 } 129 }
124 130
125 SHERPA_ONNX_EXTERN_C 131 SHERPA_ONNX_EXTERN_C
@@ -173,11 +179,17 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected( @@ -173,11 +179,17 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected(
173 } 179 }
174 180
175 SHERPA_ONNX_EXTERN_C 181 SHERPA_ONNX_EXTERN_C
176 -JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(JNIEnv * /*env*/,  
177 - jobject /*obj*/,  
178 - jlong ptr) {  
179 - auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);  
180 - model->Reset(); 182 +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(
  183 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  184 + SafeJNI(env, "Vad_reset", [&] {
  185 + if (!ValidatePointer(env, ptr, "Vad_reset",
  186 + "VoiceActivityDetector pointer is null.")) {
  187 + return;
  188 + }
  189 +
  190 + auto model = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
  191 + model->Reset();
  192 + });
181 } 193 }
182 194
183 SHERPA_ONNX_EXTERN_C 195 SHERPA_ONNX_EXTERN_C