Fangjun Kuang
Committed by GitHub

Add timestamps for streaming ASR. (#123)

... ... @@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
include(asio)
endif()
include(json)
add_subdirectory(sherpa-onnx)
if(SHERPA_ONNX_ENABLE_C_API)
... ...
function(download_json)
include(FetchContent)
set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz")
set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz")
set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273")
# If you don't have access to the Internet,
# please pre-download json
set(possible_file_locations
$ENV{HOME}/Downloads/json-3.11.2.tar.gz
${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz
${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz
/tmp/json-3.11.2.tar.gz
/star-fj/fangjun/download/github/json-3.11.2.tar.gz
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(json_URL "${f}")
file(TO_CMAKE_PATH "${json_URL}" json_URL)
set(json_URL2)
break()
endif()
endforeach()
FetchContent_Declare(json
URL
${json_URL}
${json_URL2}
URL_HASH ${json_HASH}
)
FetchContent_GetProperties(json)
if(NOT json_POPULATED)
message(STATUS "Downloading json from ${json_URL}")
FetchContent_Populate(json)
endif()
message(STATUS "json is downloaded to ${json_SOURCE_DIR}")
include_directories(${json_SOURCE_DIR}/include)
# Use #include "nlohmann/json.hpp"
endfunction()
download_json()
... ...
... ... @@ -8,11 +8,13 @@
#include <assert.h>
#include <algorithm>
#include <iomanip>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
#include "nlohmann/json.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
... ... @@ -22,16 +24,56 @@
namespace sherpa_onnx {
std::string OnlineRecognizerResult::AsJsonString() const {
using json = nlohmann::json;
json j;
j["text"] = text;
j["tokens"] = tokens;
j["start_time"] = start_time;
#if 1
// This branch chooses number of decimal points to keep in
// the return json string
std::ostringstream os;
os << "[";
std::string sep = "";
for (auto t : timestamps) {
os << sep << std::fixed << std::setprecision(2) << t;
sep = ", ";
}
os << "]";
j["timestamps"] = os.str();
#else
j["timestamps"] = timestamps;
#endif
j["segment"] = segment;
j["is_final"] = is_final;
return j.dump();
}
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table) {
std::string text;
for (auto t : src.tokens) {
text += sym_table[t];
const SymbolTable &sym_table,
int32_t frame_shift_ms,
int32_t subsampling_factor) {
OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.tokens.size());
for (auto i : src.tokens) {
auto sym = sym_table[i];
r.text.append(sym);
r.tokens.push_back(std::move(sym));
}
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (auto t : src.timestamps) {
float time = frame_shift_s * t;
r.timestamps.push_back(time);
}
OnlineRecognizerResult ans;
ans.text = std::move(text);
return ans;
return r;
}
void OnlineRecognizerConfig::Register(ParseOptions *po) {
... ... @@ -169,7 +211,10 @@ class OnlineRecognizer::Impl {
OnlineTransducerDecoderResult decoder_result = s->GetResult();
decoder_->StripLeadingBlanks(&decoder_result);
return Convert(decoder_result, sym_);
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 4;
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
}
bool IsEndpoint(OnlineStream *s) const {
... ...
... ... @@ -7,6 +7,7 @@
#include <memory>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
... ... @@ -22,10 +23,45 @@
namespace sherpa_onnx {
struct OnlineRecognizerResult {
/// Recognition results.
/// For English, it consists of space separated words.
/// For Chinese, it consists of Chinese words without spaces.
/// Example 1: "hello world"
/// Example 2: "你好世界"
std::string text;
// TODO(fangjun): Add a method to return a json string
std::string ToString() const { return text; }
/// Decoded results at the token level.
/// For instance, for BPE-based models it consists of a list of BPE tokens.
std::vector<std::string> tokens;
/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;
/// ID of this segment
/// When an endpoint is detected, it is incremented
int32_t segment = 0;
/// Starting frame of this segment.
/// When an endpoint is detected, it will change
float start_time = 0;
/// True if this is the last segment.
bool is_final = false;
/** Return a json string.
*
* The returned string contains:
* {
* "text": "The recognition result",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "segment": x,
* "start_time": x,
* "is_final": true|false
* }
*/
std::string AsJsonString() const;
};
struct OnlineRecognizerConfig {
... ...
... ... @@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
hyps = other.hyps;
frame_offset = other.frame_offset;
timestamps = other.timestamps;
return *this;
}
... ... @@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
decoder_out = std::move(other.decoder_out);
hyps = std::move(other.hyps);
frame_offset = other.frame_offset;
timestamps = std::move(other.timestamps);
return *this;
}
... ...
... ... @@ -13,12 +13,18 @@
namespace sherpa_onnx {
struct OnlineTransducerDecoderResult {
/// Number of frames after subsampling we have decoded so far
int32_t frame_offset = 0;
/// The decoded token IDs so far
std::vector<int64_t> tokens;
/// number of trailing blank frames decoded so far
int32_t num_trailing_blanks = 0;
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
std::vector<int32_t> timestamps;
// Cache decoder_out for endpointing
Ort::Value decoder_out;
... ...
... ... @@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
bool emitted = false;
for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
auto &r = (*result)[i];
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
if (y != 0) {
emitted = true;
(*result)[i].tokens.push_back(y);
(*result)[i].num_trailing_blanks = 0;
r.tokens.push_back(y);
r.timestamps.push_back(t + r.frame_offset);
r.num_trailing_blanks = 0;
} else {
++(*result)[i].num_trailing_blanks;
++r.num_trailing_blanks;
}
}
if (emitted) {
... ... @@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
}
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
// Update frame_offset
for (auto &r : *result) {
r.frame_offset += num_frames;
}
}
} // namespace sherpa_onnx
... ...
... ... @@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
r->tokens = std::move(tokens);
r->timestamps = std::move(hyp.timestamps);
r->num_trailing_blanks = hyp.num_trailing_blanks;
}
... ... @@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
float *p_logit = logit.GetTensorMutableData<float>();
for (int32_t b = 0; b < batch_size; ++b) {
int32_t frame_offset = (*result)[b].frame_offset;
int32_t start = hyps_num_split[b];
int32_t end = hyps_num_split[b + 1];
LogSoftmax(p_logit, vocab_size, (end - start));
... ... @@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Hypothesis new_hyp = prev[hyp_index];
if (new_token != 0) {
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t + frame_offset);
new_hyp.num_trailing_blanks = 0;
} else {
++new_hyp.num_trailing_blanks;
... ... @@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);
auto &r = (*result)[b];
(*result)[b].hyps = std::move(hyps);
(*result)[b].tokens = std::move(best_hyp.ys);
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
r.hyps = std::move(hyps);
r.tokens = std::move(best_hyp.ys);
r.num_trailing_blanks = best_hyp.num_trailing_blanks;
r.frame_offset += num_frames;
}
}
... ...
... ... @@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() {
auto result = recognizer_->GetResult(c->s.get());
asio::post(server_->GetConnectionContext(),
[this, hdl = c->hdl, str = result.ToString()]() {
[this, hdl = c->hdl, str = result.AsJsonString()]() {
server_->Send(hdl, str);
});
active_.erase(c->hdl);
... ...
... ... @@ -102,7 +102,7 @@ for a list of pre-trained models to download.
recognizer.DecodeStream(s.get());
}
std::string text = recognizer.GetResult(s.get()).text;
std::string text = recognizer.GetResult(s.get()).AsJsonString();
fprintf(stderr, "Done!\n");
... ...
... ... @@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(
sherpa_onnx::OnlineStream *s =
reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s);
return env->NewStringUTF(result.ToString().c_str());
return env->NewStringUTF(result.text.c_str());
}
SHERPA_ONNX_EXTERN_C
... ...