Wilson Wongso
Committed by GitHub

Implement Tokens in Swift and Kotlin (#227)

Co-authored-by: duc <duc@appiphany.com.au>
@@ -128,15 +128,60 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( @@ -128,15 +128,60 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
128 const auto &text = result.text; 128 const auto &text = result.text;
129 129
130 auto r = new SherpaOnnxOnlineRecognizerResult; 130 auto r = new SherpaOnnxOnlineRecognizerResult;
  131 + // copy text
131 r->text = new char[text.size() + 1]; 132 r->text = new char[text.size() + 1];
132 std::copy(text.begin(), text.end(), const_cast<char *>(r->text)); 133 std::copy(text.begin(), text.end(), const_cast<char *>(r->text));
133 const_cast<char *>(r->text)[text.size()] = 0; 134 const_cast<char *>(r->text)[text.size()] = 0;
134 135
  136 + // copy json
  137 + const auto &json = result.AsJsonString();
  138 + r->json = new char[json.size() + 1];
  139 + std::copy(json.begin(), json.end(), const_cast<char *>(r->json));
  140 + const_cast<char *>(r->json)[json.size()] = 0;
  141 +
  142 + // copy tokens
  143 + auto count = result.tokens.size();
  144 + if (count > 0) {
  145 + size_t total_length = 0;
  146 + for (const auto& token : result.tokens) {
  147 + // +1 for the null character at the end of each token
  148 + total_length += token.size() + 1;
  149 + }
  150 +
  151 + r->count = count;
  152 + // Each word ends with nullptr
  153 + r->tokens = new char[total_length];
  154 + memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
  155 + total_length);
  156 + r->timestamps = new float[r->count];
  157 + char **tokens_temp = new char*[r->count];
  158 + int32_t pos = 0;
  159 + for (int32_t i = 0; i < r->count; ++i) {
  160 + tokens_temp[i] = const_cast<char*>(r->tokens) + pos;
  161 + memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
  162 + result.tokens[i].c_str(), result.tokens[i].size());
  163 + // +1 to move past the null character
  164 + pos += result.tokens[i].size() + 1;
  165 + r->timestamps[i] = result.timestamps[i];
  166 + }
  167 +
  168 + r->tokens_arr = tokens_temp;
  169 + } else {
  170 + r->count = 0;
  171 + r->timestamps = nullptr;
  172 + r->tokens = nullptr;
  173 + r->tokens_arr = nullptr;
  174 + }
  175 +
135 return r; 176 return r;
136 } 177 }
137 178
138 void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) { 179 void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) {
139 delete[] r->text; 180 delete[] r->text;
  181 + delete[] r->json;
  182 + delete[] r->tokens;
  183 + delete[] r->tokens_arr;
  184 + delete[] r->timestamps;
140 delete r; 185 delete r;
141 } 186 }
142 187
@@ -101,8 +101,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { @@ -101,8 +101,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
101 } SherpaOnnxOnlineRecognizerConfig; 101 } SherpaOnnxOnlineRecognizerConfig;
102 102
103 SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult { 103 SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {
  104 + // Recognized text
104 const char *text; 105 const char *text;
105 - // TODO(fangjun): Add more fields 106 +
  107 + // Pointer to continuous memory which holds string based tokens
  108 + // which are seperated by \0
  109 + const char *tokens;
  110 +
  111 + // a pointer array contains the address of the first item in tokens
  112 + const char *const *tokens_arr;
  113 +
  114 + // Pointer to continuous memory which holds timestamps
  115 + float *timestamps;
  116 +
  117 + // The number of tokens/timestamps in above pointer
  118 + int32_t count;
  119 +
  120 + /** Return a json string.
  121 + *
  122 + * The returned string contains:
  123 + * {
  124 + * "text": "The recognition result",
  125 + * "tokens": [x, x, x],
  126 + * "timestamps": [x, x, x],
  127 + * "segment": x,
  128 + * "start_time": x,
  129 + * "is_final": true|false
  130 + * }
  131 + */
  132 + const char *json;
106 } SherpaOnnxOnlineRecognizerResult; 133 } SherpaOnnxOnlineRecognizerResult;
107 134
108 /// Note: OnlineRecognizer here means StreamingRecognizer. 135 /// Note: OnlineRecognizer here means StreamingRecognizer.
@@ -58,6 +58,11 @@ class SherpaOnnx { @@ -58,6 +58,11 @@ class SherpaOnnx {
58 return result.text; 58 return result.text;
59 } 59 }
60 60
  61 + const std::vector<std::string> GetTokens() const {
  62 + auto result = recognizer_.GetResult(stream_.get());
  63 + return result.tokens;
  64 + }
  65 +
61 bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); } 66 bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
62 67
63 bool IsReady() const { return recognizer_.IsReady(stream_.get()); } 68 bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
@@ -312,6 +317,29 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText( @@ -312,6 +317,29 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
312 return env->NewStringUTF(text.c_str()); 317 return env->NewStringUTF(text.c_str());
313 } 318 }
314 319
  320 +SHERPA_ONNX_EXTERN_C
  321 +JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
  322 + JNIEnv *env, jobject /*obj*/, jlong ptr) {
  323 + auto tokens = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetTokens();
  324 + int size = tokens.size();
  325 + jclass stringClass = env->FindClass("java/lang/String");
  326 +
  327 + // convert C++ list into jni string array
  328 + jobjectArray result = env->NewObjectArray(size, stringClass, NULL);
  329 + for (int i = 0; i < size; i++) {
  330 + // Convert the C++ string to a C string
  331 + const char* cstr = tokens[i].c_str();
  332 +
  333 + // Convert the C string to a jstring
  334 + jstring jstr = env->NewStringUTF(cstr);
  335 +
  336 + // Set the array element
  337 + env->SetObjectArrayElement(result, i, jstr);
  338 + }
  339 +
  340 + return result;
  341 +}
  342 +
315 // see 343 // see
316 // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables 344 // https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
317 static jobject NewInteger(JNIEnv *env, int32_t value) { 345 static jobject NewInteger(JNIEnv *env, int32_t value) {
@@ -99,6 +99,26 @@ class SherpaOnnxOnlineRecongitionResult { @@ -99,6 +99,26 @@ class SherpaOnnxOnlineRecongitionResult {
99 return String(cString: result.pointee.text) 99 return String(cString: result.pointee.text)
100 } 100 }
101 101
  102 + var count: Int32 {
  103 + return result.pointee.count
  104 + }
  105 +
  106 + var tokens: [String] {
  107 + if let tokensPointer = result.pointee.tokens_arr {
  108 + var tokens: [String] = []
  109 + for index in 0..<count {
  110 + if let tokenPointer = tokensPointer[Int(index)] {
  111 + let token = String(cString: tokenPointer)
  112 + tokens.append(token)
  113 + }
  114 + }
  115 + return tokens
  116 + } else {
  117 + let tokens: [String] = []
  118 + return tokens
  119 + }
  120 + }
  121 +
102 init(result: UnsafePointer<SherpaOnnxOnlineRecognizerResult>!) { 122 init(result: UnsafePointer<SherpaOnnxOnlineRecognizerResult>!) {
103 self.result = result 123 self.result = result
104 } 124 }