Fangjun Kuang
Committed by GitHub

Set is_final and start_time for online websocket server. (#342)

* Set is_final and start_time for online websocket server.

* Convert timestamps to a json array
@@ -174,8 +174,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) @@ -174,8 +174,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
174 include(asio) 174 include(asio)
175 endif() 175 endif()
176 176
177 -include(json)  
178 -  
179 add_subdirectory(sherpa-onnx) 177 add_subdirectory(sherpa-onnx)
180 178
181 if(SHERPA_ONNX_ENABLE_C_API) 179 if(SHERPA_ONNX_ENABLE_C_API)
1 -function(download_json)  
2 - include(FetchContent)  
3 -  
4 - set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz")  
5 - set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz")  
6 - set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273")  
7 -  
8 - # If you don't have access to the Internet,  
9 - # please pre-download json  
10 - set(possible_file_locations  
11 - $ENV{HOME}/Downloads/json-3.11.2.tar.gz  
12 - ${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz  
13 - ${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz  
14 - /tmp/json-3.11.2.tar.gz  
15 - /star-fj/fangjun/download/github/json-3.11.2.tar.gz  
16 - )  
17 -  
18 - foreach(f IN LISTS possible_file_locations)  
19 - if(EXISTS ${f})  
20 - set(json_URL "${f}")  
21 - file(TO_CMAKE_PATH "${json_URL}" json_URL)  
22 - message(STATUS "Found local downloaded json: ${json_URL}")  
23 - set(json_URL2)  
24 - break()  
25 - endif()  
26 - endforeach()  
27 -  
28 - FetchContent_Declare(json  
29 - URL  
30 - ${json_URL}  
31 - ${json_URL2}  
32 - URL_HASH ${json_HASH}  
33 - )  
34 -  
35 - FetchContent_GetProperties(json)  
36 - if(NOT json_POPULATED)  
37 - message(STATUS "Downloading json from ${json_URL}")  
38 - FetchContent_Populate(json)  
39 - endif()  
40 - message(STATUS "json is downloaded to ${json_SOURCE_DIR}")  
41 - include_directories(${json_SOURCE_DIR}/include)  
42 - # Use #include "nlohmann/json.hpp"  
43 -endfunction()  
44 -  
45 -download_json()  
@@ -28,197 +28,6 @@ @@ -28,197 +28,6 @@
28 28
29 namespace sherpa_onnx { 29 namespace sherpa_onnx {
30 30
31 -static std::string FixInvalidUtf8(const std::string &s) {  
32 - int32_t s_size = s.size();  
33 -  
34 - std::string ans;  
35 - ans.reserve(s_size);  
36 -  
37 - for (int32_t i = 0; i < s_size;) {  
38 - uint8_t c = s[i];  
39 - if (c < 0x80) {  
40 - // valid  
41 - ans.append(1, c);  
42 - ++i;  
43 - continue;  
44 - } else if ((c >= 0xc0) && (c < 0xe0)) {  
45 - // beginning of two bytes  
46 - if ((i + 1) > (s_size - 1)) {  
47 - // no subsequent byte. invalid!  
48 - i += 1;  
49 - continue;  
50 - }  
51 - uint8_t next = s[i + 1];  
52 - if (!(next >= 0x80 && next < 0xc0)) {  
53 - // invalid  
54 - i += 1;  
55 - continue;  
56 - }  
57 - // valid 2-byte utf-8  
58 - ans.append(1, c);  
59 - ans.append(1, next);  
60 - i += 2;  
61 - continue;  
62 - } else if ((c >= 0xe0) && (c < 0xf0)) {  
63 - // beginning of 3 bytes  
64 - if ((i + 2) > (s_size - 1)) {  
65 - // invalid  
66 - i += 1;  
67 - continue;  
68 - }  
69 -  
70 - uint8_t next = s[i + 1];  
71 - if (!(next >= 0x80 && next < 0xc0)) {  
72 - // invalid  
73 - i += 1;  
74 - continue;  
75 - }  
76 -  
77 - uint8_t next2 = s[i + 2];  
78 - if (!(next2 >= 0x80 && next2 < 0xc0)) {  
79 - // invalid  
80 - i += 1;  
81 - continue;  
82 - }  
83 -  
84 - ans.append(1, c);  
85 - ans.append(1, next);  
86 - ans.append(1, next2);  
87 - i += 3;  
88 - continue;  
89 - } else if ((c >= 0xf0) && (c < 0xf8)) {  
90 - // 4 bytes  
91 - if ((i + 3) > (s_size - 1)) {  
92 - // invalid  
93 - i += 1;  
94 - continue;  
95 - }  
96 -  
97 - uint8_t next = s[i + 1];  
98 - if (!(next >= 0x80 && next < 0xc0)) {  
99 - // invalid  
100 - i += 1;  
101 - continue;  
102 - }  
103 -  
104 - uint8_t next2 = s[i + 2];  
105 - if (!(next2 >= 0x80 && next2 < 0xc0)) {  
106 - // invalid  
107 - i += 1;  
108 - continue;  
109 - }  
110 -  
111 - uint8_t next3 = s[i + 3];  
112 - if (!(next3 >= 0x80 && next3 < 0xc0)) {  
113 - // invalid  
114 - i += 1;  
115 - continue;  
116 - }  
117 - ans.append(1, c);  
118 - ans.append(1, next);  
119 - ans.append(1, next2);  
120 - ans.append(1, next3);  
121 - i += 4;  
122 - continue;  
123 - } else if ((c >= 0xf8) && (c < 0xfc)) {  
124 - // 5 bytes  
125 - if ((i + 4) > (s_size - 1)) {  
126 - // invalid  
127 - i += 1;  
128 - continue;  
129 - }  
130 -  
131 - uint8_t next = s[i + 1];  
132 - if (!(next >= 0x80 && next < 0xc0)) {  
133 - // invalid  
134 - i += 1;  
135 - continue;  
136 - }  
137 -  
138 - uint8_t next2 = s[i + 2];  
139 - if (!(next2 >= 0x80 && next2 < 0xc0)) {  
140 - // invalid  
141 - i += 1;  
142 - continue;  
143 - }  
144 -  
145 - uint8_t next3 = s[i + 3];  
146 - if (!(next3 >= 0x80 && next3 < 0xc0)) {  
147 - // invalid  
148 - i += 1;  
149 - continue;  
150 - }  
151 -  
152 - uint8_t next4 = s[i + 4];  
153 - if (!(next4 >= 0x80 && next4 < 0xc0)) {  
154 - // invalid  
155 - i += 1;  
156 - continue;  
157 - }  
158 - ans.append(1, c);  
159 - ans.append(1, next);  
160 - ans.append(1, next2);  
161 - ans.append(1, next3);  
162 - ans.append(1, next4);  
163 - i += 5;  
164 - continue;  
165 - } else if ((c >= 0xfc) && (c < 0xfe)) {  
166 - // 6 bytes  
167 - if ((i + 5) > (s_size - 1)) {  
168 - // invalid  
169 - i += 1;  
170 - continue;  
171 - }  
172 -  
173 - uint8_t next = s[i + 1];  
174 - if (!(next >= 0x80 && next < 0xc0)) {  
175 - // invalid  
176 - i += 1;  
177 - continue;  
178 - }  
179 -  
180 - uint8_t next2 = s[i + 2];  
181 - if (!(next2 >= 0x80 && next2 < 0xc0)) {  
182 - // invalid  
183 - i += 1;  
184 - continue;  
185 - }  
186 -  
187 - uint8_t next3 = s[i + 3];  
188 - if (!(next3 >= 0x80 && next3 < 0xc0)) {  
189 - // invalid  
190 - i += 1;  
191 - continue;  
192 - }  
193 -  
194 - uint8_t next4 = s[i + 4];  
195 - if (!(next4 >= 0x80 && next4 < 0xc0)) {  
196 - // invalid  
197 - i += 1;  
198 - continue;  
199 - }  
200 -  
201 - uint8_t next5 = s[i + 5];  
202 - if (!(next5 >= 0x80 && next5 < 0xc0)) {  
203 - // invalid  
204 - i += 1;  
205 - continue;  
206 - }  
207 - ans.append(1, c);  
208 - ans.append(1, next);  
209 - ans.append(1, next2);  
210 - ans.append(1, next3);  
211 - ans.append(1, next4);  
212 - ans.append(1, next5);  
213 - i += 6;  
214 - continue;  
215 - } else {  
216 - i += 1;  
217 - }  
218 - }  
219 - return ans;  
220 -}  
221 -  
222 static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, 31 static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
223 const SymbolTable &sym_table) { 32 const SymbolTable &sym_table) {
224 OfflineRecognitionResult r; 33 OfflineRecognitionResult r;
@@ -235,19 +44,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, @@ -235,19 +44,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
235 r.tokens.push_back(s); 44 r.tokens.push_back(s);
236 } 45 }
237 46
238 - // TODO(fangjun): Fix the following error in offline-stream.cc  
239 - //  
240 - // j["text"] = text;  
241 -  
242 - // libc++abi: terminating with uncaught exception of type  
243 - // nlohmann::json_abi_v3_11_2::detail::type_error:  
244 - // [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86  
245 -  
246 -#if 0  
247 - r.text = FixInvalidUtf8(text);  
248 -#else  
249 r.text = text; 47 r.text = text;
250 -#endif  
251 48
252 return r; 49 return r;
253 } 50 }
@@ -267,14 +267,14 @@ std::string OfflineRecognitionResult::AsJsonString() const { @@ -267,14 +267,14 @@ std::string OfflineRecognitionResult::AsJsonString() const {
267 << "timestamps" 267 << "timestamps"
268 << "\"" 268 << "\""
269 << ": "; 269 << ": ";
270 - os << "\"["; 270 + os << "[";
271 271
272 std::string sep = ""; 272 std::string sep = "";
273 for (auto t : timestamps) { 273 for (auto t : timestamps) {
274 os << sep << std::fixed << std::setprecision(2) << t; 274 os << sep << std::fixed << std::setprecision(2) << t;
275 sep = ", "; 275 sep = ", ";
276 } 276 }
277 - os << "]\", "; 277 + os << "], ";
278 278
279 os << "\"" 279 os << "\""
280 << "tokens" 280 << "tokens"
@@ -28,9 +28,10 @@ namespace sherpa_onnx { @@ -28,9 +28,10 @@ namespace sherpa_onnx {
28 28
29 static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, 29 static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
30 const SymbolTable &sym_table, 30 const SymbolTable &sym_table,
31 - int32_t frame_shift_ms, 31 + float frame_shift_ms,
32 int32_t subsampling_factor, 32 int32_t subsampling_factor,
33 - int32_t segment) { 33 + int32_t segment,
  34 + int32_t frames_since_start) {
34 OnlineRecognizerResult r; 35 OnlineRecognizerResult r;
35 r.tokens.reserve(src.tokens.size()); 36 r.tokens.reserve(src.tokens.size());
36 r.timestamps.reserve(src.tokens.size()); 37 r.timestamps.reserve(src.tokens.size());
@@ -49,6 +50,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, @@ -49,6 +50,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
49 } 50 }
50 51
51 r.segment = segment; 52 r.segment = segment;
  53 + r.start_time = frames_since_start * frame_shift_ms / 1000.;
52 54
53 return r; 55 return r;
54 } 56 }
@@ -216,7 +218,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -216,7 +218,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
216 int32_t frame_shift_ms = 10; 218 int32_t frame_shift_ms = 10;
217 int32_t subsampling_factor = 4; 219 int32_t subsampling_factor = 4;
218 return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, 220 return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
219 - s->GetCurrentSegment()); 221 + s->GetCurrentSegment(), s->GetNumFramesSinceStart());
220 } 222 }
221 223
222 bool IsEndpoint(OnlineStream *s) const override { 224 bool IsEndpoint(OnlineStream *s) const override {
@@ -14,37 +14,61 @@ @@ -14,37 +14,61 @@
14 #include <utility> 14 #include <utility>
15 #include <vector> 15 #include <vector>
16 16
17 -#include "nlohmann/json.hpp"  
18 #include "sherpa-onnx/csrc/online-recognizer-impl.h" 17 #include "sherpa-onnx/csrc/online-recognizer-impl.h"
19 18
20 namespace sherpa_onnx { 19 namespace sherpa_onnx {
21 20
22 std::string OnlineRecognizerResult::AsJsonString() const { 21 std::string OnlineRecognizerResult::AsJsonString() const {
23 - using json = nlohmann::json;  
24 - json j;  
25 - j["text"] = text;  
26 - j["tokens"] = tokens;  
27 - j["start_time"] = start_time;  
28 -#if 1  
29 - // This branch chooses number of decimal points to keep in  
30 - // the return json string  
31 std::ostringstream os; 22 std::ostringstream os;
  23 + os << "{";
  24 + os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
  25 + os << "\"segment\":" << segment << ", ";
  26 + os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
  27 + << ", ";
  28 +
  29 + os << "\"text\""
  30 + << ": ";
  31 + os << "\"" << text << "\""
  32 + << ", ";
  33 +
  34 + os << "\""
  35 + << "timestamps"
  36 + << "\""
  37 + << ": ";
32 os << "["; 38 os << "[";
  39 +
33 std::string sep = ""; 40 std::string sep = "";
34 for (auto t : timestamps) { 41 for (auto t : timestamps) {
35 os << sep << std::fixed << std::setprecision(2) << t; 42 os << sep << std::fixed << std::setprecision(2) << t;
36 sep = ", "; 43 sep = ", ";
37 } 44 }
38 - os << "]";  
39 - j["timestamps"] = os.str();  
40 -#else  
41 - j["timestamps"] = timestamps;  
42 -#endif 45 + os << "], ";
  46 +
  47 + os << "\""
  48 + << "tokens"
  49 + << "\""
  50 + << ":";
  51 + os << "[";
43 52
44 - j["segment"] = segment;  
45 - j["is_final"] = is_final; 53 + sep = "";
  54 + auto oldFlags = os.flags();
  55 + for (const auto &t : tokens) {
  56 + if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
  57 + const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
  58 + os << sep << "\""
  59 + << "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
  60 + << ">"
  61 + << "\"";
  62 + os.flags(oldFlags);
  63 + } else {
  64 + os << sep << "\"" << t << "\"";
  65 + }
  66 + sep = ", ";
  67 + }
  68 + os << "]";
  69 + os << "}";
46 70
47 - return j.dump(); 71 + return os.str();
48 } 72 }
49 73
50 void OnlineRecognizerConfig::Register(ParseOptions *po) { 74 void OnlineRecognizerConfig::Register(ParseOptions *po) {
@@ -44,11 +44,11 @@ struct OnlineRecognizerResult { @@ -44,11 +44,11 @@ struct OnlineRecognizerResult {
44 /// When an endpoint is detected, it is incremented 44 /// When an endpoint is detected, it is incremented
45 int32_t segment = 0; 45 int32_t segment = 0;
46 46
47 - /// Starting frame of this segment. 47 + /// Starting time of this segment.
48 /// When an endpoint is detected, it will change 48 /// When an endpoint is detected, it will change
49 float start_time = 0; 49 float start_time = 0;
50 50
51 - /// True if this is the last segment. 51 + /// True if the end of this segment is reached
52 bool is_final = false; 52 bool is_final = false;
53 53
54 /** Return a json string. 54 /** Return a json string.
@@ -43,6 +43,8 @@ class OnlineStream::Impl { @@ -43,6 +43,8 @@ class OnlineStream::Impl {
43 43
44 int32_t &GetNumProcessedFrames() { return num_processed_frames_; } 44 int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
45 45
  46 + int32_t GetNumFramesSinceStart() const { return start_frame_index_; }
  47 +
46 int32_t &GetCurrentSegment() { return segment_; } 48 int32_t &GetCurrentSegment() { return segment_; }
47 49
48 void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } 50 void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
@@ -126,6 +128,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() { @@ -126,6 +128,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() {
126 return impl_->GetNumProcessedFrames(); 128 return impl_->GetNumProcessedFrames();
127 } 129 }
128 130
  131 +int32_t OnlineStream::GetNumFramesSinceStart() const {
  132 + return impl_->GetNumFramesSinceStart();
  133 +}
  134 +
129 int32_t &OnlineStream::GetCurrentSegment() { 135 int32_t &OnlineStream::GetCurrentSegment() {
130 return impl_->GetCurrentSegment(); 136 return impl_->GetCurrentSegment();
131 } 137 }
@@ -66,7 +66,9 @@ class OnlineStream { @@ -66,7 +66,9 @@ class OnlineStream {
66 // Initially, it is 0. It is always less than NumFramesReady(). 66 // Initially, it is 0. It is always less than NumFramesReady().
67 // 67 //
68 // The returned reference is valid as long as this object is alive. 68 // The returned reference is valid as long as this object is alive.
69 - int32_t &GetNumProcessedFrames(); 69 + int32_t &GetNumProcessedFrames(); // It's reset after calling Reset()
  70 +
  71 + int32_t GetNumFramesSinceStart() const;
70 72
71 int32_t &GetCurrentSegment(); 73 int32_t &GetCurrentSegment();
72 74
@@ -195,9 +195,14 @@ void OnlineWebsocketDecoder::Decode() { @@ -195,9 +195,14 @@ void OnlineWebsocketDecoder::Decode() {
195 for (auto c : c_vec) { 195 for (auto c : c_vec) {
196 auto result = recognizer_->GetResult(c->s.get()); 196 auto result = recognizer_->GetResult(c->s.get());
197 if (recognizer_->IsEndpoint(c->s.get())) { 197 if (recognizer_->IsEndpoint(c->s.get())) {
  198 + result.is_final = true;
198 recognizer_->Reset(c->s.get()); 199 recognizer_->Reset(c->s.get());
199 } 200 }
200 201
  202 + if (!recognizer_->IsReady(c->s.get()) && c->eof) {
  203 + result.is_final = true;
  204 + }
  205 +
201 asio::post(server_->GetConnectionContext(), 206 asio::post(server_->GetConnectionContext(),
202 [this, hdl = c->hdl, str = result.AsJsonString()]() { 207 [this, hdl = c->hdl, str = result.AsJsonString()]() {
203 server_->Send(hdl, str); 208 server_->Send(hdl, str);