Fangjun Kuang
Committed by GitHub

Set is_final and start_time for online websocket server. (#342)

* Set is_final and start_time for online websocket server.

* Convert timestamps to a json array
... ... @@ -174,8 +174,6 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
include(asio)
endif()
include(json)
add_subdirectory(sherpa-onnx)
if(SHERPA_ONNX_ENABLE_C_API)
... ...
function(download_json)
include(FetchContent)
set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz")
set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz")
set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273")
# If you don't have access to the Internet,
# please pre-download json
set(possible_file_locations
$ENV{HOME}/Downloads/json-3.11.2.tar.gz
${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz
${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz
/tmp/json-3.11.2.tar.gz
/star-fj/fangjun/download/github/json-3.11.2.tar.gz
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(json_URL "${f}")
file(TO_CMAKE_PATH "${json_URL}" json_URL)
message(STATUS "Found local downloaded json: ${json_URL}")
set(json_URL2)
break()
endif()
endforeach()
FetchContent_Declare(json
URL
${json_URL}
${json_URL2}
URL_HASH ${json_HASH}
)
FetchContent_GetProperties(json)
if(NOT json_POPULATED)
message(STATUS "Downloading json from ${json_URL}")
FetchContent_Populate(json)
endif()
message(STATUS "json is downloaded to ${json_SOURCE_DIR}")
include_directories(${json_SOURCE_DIR}/include)
# Use #include "nlohmann/json.hpp"
endfunction()
download_json()
... ... @@ -28,197 +28,6 @@
namespace sherpa_onnx {
static std::string FixInvalidUtf8(const std::string &s) {
int32_t s_size = s.size();
std::string ans;
ans.reserve(s_size);
for (int32_t i = 0; i < s_size;) {
uint8_t c = s[i];
if (c < 0x80) {
// valid
ans.append(1, c);
++i;
continue;
} else if ((c >= 0xc0) && (c < 0xe0)) {
// beginning of two bytes
if ((i + 1) > (s_size - 1)) {
// no subsequent byte. invalid!
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
// valid 2-byte utf-8
ans.append(1, c);
ans.append(1, next);
i += 2;
continue;
} else if ((c >= 0xe0) && (c < 0xf0)) {
// beginning of 3 bytes
if ((i + 2) > (s_size - 1)) {
// invalid
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next2 = s[i + 2];
if (!(next2 >= 0x80 && next2 < 0xc0)) {
// invalid
i += 1;
continue;
}
ans.append(1, c);
ans.append(1, next);
ans.append(1, next2);
i += 3;
continue;
} else if ((c >= 0xf0) && (c < 0xf8)) {
// 4 bytes
if ((i + 3) > (s_size - 1)) {
// invalid
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next2 = s[i + 2];
if (!(next2 >= 0x80 && next2 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next3 = s[i + 3];
if (!(next3 >= 0x80 && next3 < 0xc0)) {
// invalid
i += 1;
continue;
}
ans.append(1, c);
ans.append(1, next);
ans.append(1, next2);
ans.append(1, next3);
i += 4;
continue;
} else if ((c >= 0xf8) && (c < 0xfc)) {
// 5 bytes
if ((i + 4) > (s_size - 1)) {
// invalid
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next2 = s[i + 2];
if (!(next2 >= 0x80 && next2 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next3 = s[i + 3];
if (!(next3 >= 0x80 && next3 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next4 = s[i + 4];
if (!(next4 >= 0x80 && next4 < 0xc0)) {
// invalid
i += 1;
continue;
}
ans.append(1, c);
ans.append(1, next);
ans.append(1, next2);
ans.append(1, next3);
ans.append(1, next4);
i += 5;
continue;
} else if ((c >= 0xfc) && (c < 0xfe)) {
// 6 bytes
if ((i + 5) > (s_size - 1)) {
// invalid
i += 1;
continue;
}
uint8_t next = s[i + 1];
if (!(next >= 0x80 && next < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next2 = s[i + 2];
if (!(next2 >= 0x80 && next2 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next3 = s[i + 3];
if (!(next3 >= 0x80 && next3 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next4 = s[i + 4];
if (!(next4 >= 0x80 && next4 < 0xc0)) {
// invalid
i += 1;
continue;
}
uint8_t next5 = s[i + 5];
if (!(next5 >= 0x80 && next5 < 0xc0)) {
// invalid
i += 1;
continue;
}
ans.append(1, c);
ans.append(1, next);
ans.append(1, next2);
ans.append(1, next3);
ans.append(1, next4);
ans.append(1, next5);
i += 6;
continue;
} else {
i += 1;
}
}
return ans;
}
static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
const SymbolTable &sym_table) {
OfflineRecognitionResult r;
... ... @@ -235,19 +44,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
r.tokens.push_back(s);
}
// TODO(fangjun): Fix the following error in offline-stream.cc
//
// j["text"] = text;
// libc++abi: terminating with uncaught exception of type
// nlohmann::json_abi_v3_11_2::detail::type_error:
// [json.exception.type_error.316] incomplete UTF-8 string; last byte: 0x86
#if 0
r.text = FixInvalidUtf8(text);
#else
r.text = text;
#endif
return r;
}
... ...
... ... @@ -267,14 +267,14 @@ std::string OfflineRecognitionResult::AsJsonString() const {
<< "timestamps"
<< "\""
<< ": ";
os << "\"[";
os << "[";
std::string sep = "";
for (auto t : timestamps) {
os << sep << std::fixed << std::setprecision(2) << t;
sep = ", ";
}
os << "]\", ";
os << "], ";
os << "\""
<< "tokens"
... ...
... ... @@ -28,9 +28,10 @@ namespace sherpa_onnx {
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
int32_t frame_shift_ms,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment) {
int32_t segment,
int32_t frames_since_start) {
OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
... ... @@ -49,6 +50,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
}
r.segment = segment;
r.start_time = frames_since_start * frame_shift_ms / 1000.;
return r;
}
... ... @@ -216,7 +218,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment());
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
}
bool IsEndpoint(OnlineStream *s) const override {
... ...
... ... @@ -14,37 +14,61 @@
#include <utility>
#include <vector>
#include "nlohmann/json.hpp"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
namespace sherpa_onnx {
std::string OnlineRecognizerResult::AsJsonString() const {
using json = nlohmann::json;
json j;
j["text"] = text;
j["tokens"] = tokens;
j["start_time"] = start_time;
#if 1
// This branch chooses number of decimal points to keep in
// the return json string
std::ostringstream os;
os << "{";
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
os << "\"segment\":" << segment << ", ";
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
<< ", ";
os << "\"text\""
<< ": ";
os << "\"" << text << "\""
<< ", ";
os << "\""
<< "timestamps"
<< "\""
<< ": ";
os << "[";
std::string sep = "";
for (auto t : timestamps) {
os << sep << std::fixed << std::setprecision(2) << t;
sep = ", ";
}
os << "]";
j["timestamps"] = os.str();
#else
j["timestamps"] = timestamps;
#endif
os << "], ";
os << "\""
<< "tokens"
<< "\""
<< ":";
os << "[";
j["segment"] = segment;
j["is_final"] = is_final;
sep = "";
auto oldFlags = os.flags();
for (const auto &t : tokens) {
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
os << sep << "\""
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
<< ">"
<< "\"";
os.flags(oldFlags);
} else {
os << sep << "\"" << t << "\"";
}
sep = ", ";
}
os << "]";
os << "}";
return j.dump();
return os.str();
}
void OnlineRecognizerConfig::Register(ParseOptions *po) {
... ...
... ... @@ -44,11 +44,11 @@ struct OnlineRecognizerResult {
/// When an endpoint is detected, it is incremented
int32_t segment = 0;
/// Starting frame of this segment.
/// Starting time of this segment.
/// When an endpoint is detected, it will change
float start_time = 0;
/// True if this is the last segment.
/// True if the end of this segment is reached
bool is_final = false;
/** Return a json string.
... ...
... ... @@ -43,6 +43,8 @@ class OnlineStream::Impl {
int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
int32_t GetNumFramesSinceStart() const { return start_frame_index_; }
int32_t &GetCurrentSegment() { return segment_; }
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
... ... @@ -126,6 +128,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() {
return impl_->GetNumProcessedFrames();
}
int32_t OnlineStream::GetNumFramesSinceStart() const {
return impl_->GetNumFramesSinceStart();
}
int32_t &OnlineStream::GetCurrentSegment() {
return impl_->GetCurrentSegment();
}
... ...
... ... @@ -66,7 +66,9 @@ class OnlineStream {
// Initially, it is 0. It is always less than NumFramesReady().
//
// The returned reference is valid as long as this object is alive.
int32_t &GetNumProcessedFrames();
int32_t &GetNumProcessedFrames(); // It's reset after calling Reset()
int32_t GetNumFramesSinceStart() const;
int32_t &GetCurrentSegment();
... ...
... ... @@ -195,9 +195,14 @@ void OnlineWebsocketDecoder::Decode() {
for (auto c : c_vec) {
auto result = recognizer_->GetResult(c->s.get());
if (recognizer_->IsEndpoint(c->s.get())) {
result.is_final = true;
recognizer_->Reset(c->s.get());
}
if (!recognizer_->IsReady(c->s.get()) && c->eof) {
result.is_final = true;
}
asio::post(server_->GetConnectionContext(),
[this, hdl = c->hdl, str = result.AsJsonString()]() {
server_->Send(hdl, str);
... ...