Committed by
GitHub
Implement Tokens in Swift and Kotlin (#227)
Co-authored-by: duc <duc@appiphany.com.au>
正在显示
4 个修改的文件
包含
121 行增加
和
1 行删除
| @@ -128,15 +128,60 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( | @@ -128,15 +128,60 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( | ||
| 128 | const auto &text = result.text; | 128 | const auto &text = result.text; |
| 129 | 129 | ||
| 130 | auto r = new SherpaOnnxOnlineRecognizerResult; | 130 | auto r = new SherpaOnnxOnlineRecognizerResult; |
| 131 | + // copy text | ||
| 131 | r->text = new char[text.size() + 1]; | 132 | r->text = new char[text.size() + 1]; |
| 132 | std::copy(text.begin(), text.end(), const_cast<char *>(r->text)); | 133 | std::copy(text.begin(), text.end(), const_cast<char *>(r->text)); |
| 133 | const_cast<char *>(r->text)[text.size()] = 0; | 134 | const_cast<char *>(r->text)[text.size()] = 0; |
| 134 | 135 | ||
| 136 | + // copy json | ||
| 137 | + const auto &json = result.AsJsonString(); | ||
| 138 | + r->json = new char[json.size() + 1]; | ||
| 139 | + std::copy(json.begin(), json.end(), const_cast<char *>(r->json)); | ||
| 140 | + const_cast<char *>(r->json)[json.size()] = 0; | ||
| 141 | + | ||
| 142 | + // copy tokens | ||
| 143 | + auto count = result.tokens.size(); | ||
| 144 | + if (count > 0) { | ||
| 145 | + size_t total_length = 0; | ||
| 146 | + for (const auto& token : result.tokens) { | ||
| 147 | + // +1 for the null character at the end of each token | ||
| 148 | + total_length += token.size() + 1; | ||
| 149 | + } | ||
| 150 | + | ||
| 151 | + r->count = count; | ||
| 152 | + // Each token ends with a null character ('\0') | ||
| 153 | + r->tokens = new char[total_length]; | ||
| 154 | + memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0, | ||
| 155 | + total_length); | ||
| 156 | + r->timestamps = new float[r->count]; | ||
| 157 | + char **tokens_temp = new char*[r->count]; | ||
| 158 | + int32_t pos = 0; | ||
| 159 | + for (int32_t i = 0; i < r->count; ++i) { | ||
| 160 | + tokens_temp[i] = const_cast<char*>(r->tokens) + pos; | ||
| 161 | + memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)), | ||
| 162 | + result.tokens[i].c_str(), result.tokens[i].size()); | ||
| 163 | + // +1 to move past the null character | ||
| 164 | + pos += result.tokens[i].size() + 1; | ||
| 165 | + r->timestamps[i] = result.timestamps[i]; | ||
| 166 | + } | ||
| 167 | + | ||
| 168 | + r->tokens_arr = tokens_temp; | ||
| 169 | + } else { | ||
| 170 | + r->count = 0; | ||
| 171 | + r->timestamps = nullptr; | ||
| 172 | + r->tokens = nullptr; | ||
| 173 | + r->tokens_arr = nullptr; | ||
| 174 | + } | ||
| 175 | + | ||
| 135 | return r; | 176 | return r; |
| 136 | } | 177 | } |
| 137 | 178 | ||
| 138 | void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) { | 179 | void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) { |
| 139 | delete[] r->text; | 180 | delete[] r->text; |
| 181 | + delete[] r->json; | ||
| 182 | + delete[] r->tokens; | ||
| 183 | + delete[] r->tokens_arr; | ||
| 184 | + delete[] r->timestamps; | ||
| 140 | delete r; | 185 | delete r; |
| 141 | } | 186 | } |
| 142 | 187 |
| @@ -101,8 +101,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { | @@ -101,8 +101,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { | ||
| 101 | } SherpaOnnxOnlineRecognizerConfig; | 101 | } SherpaOnnxOnlineRecognizerConfig; |
| 102 | 102 | ||
| 103 | SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult { | 103 | SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult { |
| 104 | + // Recognized text | ||
| 104 | const char *text; | 105 | const char *text; |
| 105 | - // TODO(fangjun): Add more fields | 106 | + |
| 107 | + // Pointer to contiguous memory which holds string-based tokens | ||
| 108 | + // which are separated by \0 | ||
| 109 | + const char *tokens; | ||
| 110 | + | ||
| 111 | + // An array of pointers; each entry holds the address of the first character of one token in tokens | ||
| 112 | + const char *const *tokens_arr; | ||
| 113 | + | ||
| 114 | + // Pointer to contiguous memory which holds timestamps | ||
| 115 | + float *timestamps; | ||
| 116 | + | ||
| 117 | + // The number of entries in tokens_arr and in timestamps | ||
| 118 | + int32_t count; | ||
| 119 | + | ||
| 120 | + /** The result serialized as a JSON string. | ||
| 121 | + * | ||
| 122 | + * The returned string contains: | ||
| 123 | + * { | ||
| 124 | + * "text": "The recognition result", | ||
| 125 | + * "tokens": [x, x, x], | ||
| 126 | + * "timestamps": [x, x, x], | ||
| 127 | + * "segment": x, | ||
| 128 | + * "start_time": x, | ||
| 129 | + * "is_final": true|false | ||
| 130 | + * } | ||
| 131 | + */ | ||
| 132 | + const char *json; | ||
| 106 | } SherpaOnnxOnlineRecognizerResult; | 133 | } SherpaOnnxOnlineRecognizerResult; |
| 107 | 134 | ||
| 108 | /// Note: OnlineRecognizer here means StreamingRecognizer. | 135 | /// Note: OnlineRecognizer here means StreamingRecognizer. |
| @@ -58,6 +58,11 @@ class SherpaOnnx { | @@ -58,6 +58,11 @@ class SherpaOnnx { | ||
| 58 | return result.text; | 58 | return result.text; |
| 59 | } | 59 | } |
| 60 | 60 | ||
| 61 | + const std::vector<std::string> GetTokens() const { | ||
| 62 | + auto result = recognizer_.GetResult(stream_.get()); | ||
| 63 | + return result.tokens; | ||
| 64 | + } | ||
| 65 | + | ||
| 61 | bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); } | 66 | bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); } |
| 62 | 67 | ||
| 63 | bool IsReady() const { return recognizer_.IsReady(stream_.get()); } | 68 | bool IsReady() const { return recognizer_.IsReady(stream_.get()); } |
| @@ -312,6 +317,29 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( | @@ -312,6 +317,29 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( | ||
| 312 | return env->NewStringUTF(text.c_str()); | 317 | return env->NewStringUTF(text.c_str()); |
| 313 | } | 318 | } |
| 314 | 319 | ||
| 320 | +SHERPA_ONNX_EXTERN_C | ||
| 321 | +JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens( | ||
| 322 | + JNIEnv *env, jobject /*obj*/, jlong ptr) { | ||
| 323 | + auto tokens = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetTokens(); | ||
| 324 | + int size = tokens.size(); | ||
| 325 | + jclass stringClass = env->FindClass("java/lang/String"); | ||
| 326 | + | ||
| 327 | + // Convert the C++ vector of strings into a JNI string array | ||
| 328 | + jobjectArray result = env->NewObjectArray(size, stringClass, NULL); | ||
| 329 | + for (int i = 0; i < size; i++) { | ||
| 330 | + // Convert the C++ string to a C string | ||
| 331 | + const char* cstr = tokens[i].c_str(); | ||
| 332 | + | ||
| 333 | + // Convert the C string to a jstring | ||
| 334 | + jstring jstr = env->NewStringUTF(cstr); | ||
| 335 | + | ||
| 336 | + // Set the array element | ||
| 337 | + env->SetObjectArrayElement(result, i, jstr); | ||
| 338 | + } | ||
| 339 | + | ||
| 340 | + return result; | ||
| 341 | +} | ||
| 342 | + | ||
| 315 | // see | 343 | // see |
| 316 | // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables | 344 | // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables |
| 317 | static jobject NewInteger(JNIEnv *env, int32_t value) { | 345 | static jobject NewInteger(JNIEnv *env, int32_t value) { |
| @@ -99,6 +99,26 @@ class SherpaOnnxOnlineRecongitionResult { | @@ -99,6 +99,26 @@ class SherpaOnnxOnlineRecongitionResult { | ||
| 99 | return String(cString: result.pointee.text) | 99 | return String(cString: result.pointee.text) |
| 100 | } | 100 | } |
| 101 | 101 | ||
| 102 | + var count: Int32 { | ||
| 103 | + return result.pointee.count | ||
| 104 | + } | ||
| 105 | + | ||
| 106 | + var tokens: [String] { | ||
| 107 | + if let tokensPointer = result.pointee.tokens_arr { | ||
| 108 | + var tokens: [String] = [] | ||
| 109 | + for index in 0..<count { | ||
| 110 | + if let tokenPointer = tokensPointer[Int(index)] { | ||
| 111 | + let token = String(cString: tokenPointer) | ||
| 112 | + tokens.append(token) | ||
| 113 | + } | ||
| 114 | + } | ||
| 115 | + return tokens | ||
| 116 | + } else { | ||
| 117 | + let tokens: [String] = [] | ||
| 118 | + return tokens | ||
| 119 | + } | ||
| 120 | + } | ||
| 121 | + | ||
| 102 | init(result: UnsafePointer<SherpaOnnxOnlineRecognizerResult>!) { | 122 | init(result: UnsafePointer<SherpaOnnxOnlineRecognizerResult>!) { |
| 103 | self.result = result | 123 | self.result = result |
| 104 | } | 124 | } |
-
请 注册 或 登录 后发表评论