Fangjun Kuang
Committed by GitHub

Add timestamps for streaming ASR. (#123)

@@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET) @@ -126,6 +126,8 @@ if(SHERPA_ONNX_ENABLE_WEBSOCKET)
126 include(asio) 126 include(asio)
127 endif() 127 endif()
128 128
  129 +include(json)
  130 +
129 add_subdirectory(sherpa-onnx) 131 add_subdirectory(sherpa-onnx)
130 132
131 if(SHERPA_ONNX_ENABLE_C_API) 133 if(SHERPA_ONNX_ENABLE_C_API)
  1 +function(download_json)
  2 + include(FetchContent)
  3 +
  4 + set(json_URL "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz")
  5 + set(json_URL2 "https://huggingface.co/csukuangfj/sherpa-cmake-deps/resolve/main/json-3.11.2.tar.gz")
  6 + set(json_HASH "SHA256=d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273")
  7 +
  8 + # If you don't have access to the Internet,
  9 + # please pre-download json
  10 + set(possible_file_locations
  11 + $ENV{HOME}/Downloads/json-3.11.2.tar.gz
  12 + ${PROJECT_SOURCE_DIR}/json-3.11.2.tar.gz
  13 + ${PROJECT_BINARY_DIR}/json-3.11.2.tar.gz
  14 + /tmp/json-3.11.2.tar.gz
  15 + /star-fj/fangjun/download/github/json-3.11.2.tar.gz
  16 + )
  17 +
  18 + foreach(f IN LISTS possible_file_locations)
  19 + if(EXISTS ${f})
  20 + set(json_URL "${f}")
  21 + file(TO_CMAKE_PATH "${json_URL}" json_URL)
  22 + set(json_URL2)
  23 + break()
  24 + endif()
  25 + endforeach()
  26 +
  27 + FetchContent_Declare(json
  28 + URL
  29 + ${json_URL}
  30 + ${json_URL2}
  31 + URL_HASH ${json_HASH}
  32 + )
  33 +
  34 + FetchContent_GetProperties(json)
  35 + if(NOT json_POPULATED)
  36 + message(STATUS "Downloading json from ${json_URL}")
  37 + FetchContent_Populate(json)
  38 + endif()
  39 + message(STATUS "json is downloaded to ${json_SOURCE_DIR}")
  40 + include_directories(${json_SOURCE_DIR}/include)
  41 + # Use #include "nlohmann/json.hpp"
  42 +endfunction()
  43 +
  44 +download_json()
@@ -8,11 +8,13 @@ @@ -8,11 +8,13 @@
8 #include <assert.h> 8 #include <assert.h>
9 9
10 #include <algorithm> 10 #include <algorithm>
  11 +#include <iomanip>
11 #include <memory> 12 #include <memory>
12 #include <sstream> 13 #include <sstream>
13 #include <utility> 14 #include <utility>
14 #include <vector> 15 #include <vector>
15 16
  17 +#include "nlohmann/json.hpp"
16 #include "sherpa-onnx/csrc/file-utils.h" 18 #include "sherpa-onnx/csrc/file-utils.h"
17 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 19 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
18 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" 20 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
@@ -22,16 +24,56 @@ @@ -22,16 +24,56 @@
22 24
23 namespace sherpa_onnx { 25 namespace sherpa_onnx {
24 26
  27 +std::string OnlineRecognizerResult::AsJsonString() const {
  28 + using json = nlohmann::json;
  29 + json j;
  30 + j["text"] = text;
  31 + j["tokens"] = tokens;
  32 + j["start_time"] = start_time;
  33 +#if 1
  34 + // This branch chooses number of decimal points to keep in
  35 + // the return json string
  36 + std::ostringstream os;
  37 + os << "[";
  38 + std::string sep = "";
  39 + for (auto t : timestamps) {
  40 + os << sep << std::fixed << std::setprecision(2) << t;
  41 + sep = ", ";
  42 + }
  43 + os << "]";
  44 + j["timestamps"] = os.str();
  45 +#else
  46 + j["timestamps"] = timestamps;
  47 +#endif
  48 +
  49 + j["segment"] = segment;
  50 + j["is_final"] = is_final;
  51 +
  52 + return j.dump();
  53 +}
  54 +
25 static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, 55 static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
26 - const SymbolTable &sym_table) {  
27 - std::string text;  
28 - for (auto t : src.tokens) {  
29 - text += sym_table[t]; 56 + const SymbolTable &sym_table,
  57 + int32_t frame_shift_ms,
  58 + int32_t subsampling_factor) {
  59 + OnlineRecognizerResult r;
  60 + r.tokens.reserve(src.tokens.size());
  61 + r.timestamps.reserve(src.tokens.size());
  62 +
  63 + for (auto i : src.tokens) {
  64 + auto sym = sym_table[i];
  65 +
  66 + r.text.append(sym);
  67 + r.tokens.push_back(std::move(sym));
  68 + }
  69 +
  70 + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
  71 + for (auto t : src.timestamps) {
  72 + float time = frame_shift_s * t;
  73 + r.timestamps.push_back(time);
30 } 74 }
31 75
32 - OnlineRecognizerResult ans;  
33 - ans.text = std::move(text);  
34 - return ans; 76 + return r;
35 } 77 }
36 78
37 void OnlineRecognizerConfig::Register(ParseOptions *po) { 79 void OnlineRecognizerConfig::Register(ParseOptions *po) {
@@ -169,7 +211,10 @@ class OnlineRecognizer::Impl { @@ -169,7 +211,10 @@ class OnlineRecognizer::Impl {
169 OnlineTransducerDecoderResult decoder_result = s->GetResult(); 211 OnlineTransducerDecoderResult decoder_result = s->GetResult();
170 decoder_->StripLeadingBlanks(&decoder_result); 212 decoder_->StripLeadingBlanks(&decoder_result);
171 213
172 - return Convert(decoder_result, sym_); 214 + // TODO(fangjun): Remember to change these constants if needed
  215 + int32_t frame_shift_ms = 10;
  216 + int32_t subsampling_factor = 4;
  217 + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
173 } 218 }
174 219
175 bool IsEndpoint(OnlineStream *s) const { 220 bool IsEndpoint(OnlineStream *s) const {
@@ -7,6 +7,7 @@ @@ -7,6 +7,7 @@
7 7
8 #include <memory> 8 #include <memory>
9 #include <string> 9 #include <string>
  10 +#include <vector>
10 11
11 #if __ANDROID_API__ >= 9 12 #if __ANDROID_API__ >= 9
12 #include "android/asset_manager.h" 13 #include "android/asset_manager.h"
@@ -22,10 +23,45 @@ @@ -22,10 +23,45 @@
22 namespace sherpa_onnx { 23 namespace sherpa_onnx {
23 24
24 struct OnlineRecognizerResult { 25 struct OnlineRecognizerResult {
  26 + /// Recognition results.
  27 + /// For English, it consists of space separated words.
  28 + /// For Chinese, it consists of Chinese words without spaces.
  29 + /// Example 1: "hello world"
  30 + /// Example 2: "你好世界"
25 std::string text; 31 std::string text;
26 32
27 - // TODO(fangjun): Add a method to return a json string  
28 - std::string ToString() const { return text; } 33 + /// Decoded results at the token level.
  34 + /// For instance, for BPE-based models it consists of a list of BPE tokens.
  35 + std::vector<std::string> tokens;
  36 +
  37 + /// timestamps.size() == tokens.size()
  38 + /// timestamps[i] records the time in seconds when tokens[i] is decoded.
  39 + std::vector<float> timestamps;
  40 +
  41 + /// ID of this segment
  42 + /// When an endpoint is detected, it is incremented
  43 + int32_t segment = 0;
  44 +
  45 + /// Starting frame of this segment.
  46 + /// When an endpoint is detected, it will change
  47 + float start_time = 0;
  48 +
  49 + /// True if this is the last segment.
  50 + bool is_final = false;
  51 +
  52 + /** Return a json string.
  53 + *
  54 + * The returned string contains:
  55 + * {
  56 + * "text": "The recognition result",
  57 + * "tokens": [x, x, x],
  58 + * "timestamps": [x, x, x],
  59 + * "segment": x,
  60 + * "start_time": x,
  61 + * "is_final": true|false
  62 + * }
  63 + */
  64 + std::string AsJsonString() const;
29 }; 65 };
30 66
31 struct OnlineRecognizerConfig { 67 struct OnlineRecognizerConfig {
@@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( @@ -34,6 +34,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
34 34
35 hyps = other.hyps; 35 hyps = other.hyps;
36 36
  37 + frame_offset = other.frame_offset;
  38 + timestamps = other.timestamps;
  39 +
37 return *this; 40 return *this;
38 } 41 }
39 42
@@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( @@ -54,6 +57,9 @@ OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
54 decoder_out = std::move(other.decoder_out); 57 decoder_out = std::move(other.decoder_out);
55 hyps = std::move(other.hyps); 58 hyps = std::move(other.hyps);
56 59
  60 + frame_offset = other.frame_offset;
  61 + timestamps = std::move(other.timestamps);
  62 +
57 return *this; 63 return *this;
58 } 64 }
59 65
@@ -13,12 +13,18 @@ @@ -13,12 +13,18 @@
13 namespace sherpa_onnx { 13 namespace sherpa_onnx {
14 14
15 struct OnlineTransducerDecoderResult { 15 struct OnlineTransducerDecoderResult {
  16 + /// Number of frames after subsampling we have decoded so far
  17 + int32_t frame_offset = 0;
  18 +
16 /// The decoded token IDs so far 19 /// The decoded token IDs so far
17 std::vector<int64_t> tokens; 20 std::vector<int64_t> tokens;
18 21
19 /// number of trailing blank frames decoded so far 22 /// number of trailing blank frames decoded so far
20 int32_t num_trailing_blanks = 0; 23 int32_t num_trailing_blanks = 0;
21 24
  25 + /// timestamps[i] contains the output frame index where tokens[i] is decoded.
  26 + std::vector<int32_t> timestamps;
  27 +
22 // Cache decoder_out for endpointing 28 // Cache decoder_out for endpointing
23 Ort::Value decoder_out; 29 Ort::Value decoder_out;
24 30
@@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -102,16 +102,18 @@ void OnlineTransducerGreedySearchDecoder::Decode(
102 102
103 bool emitted = false; 103 bool emitted = false;
104 for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { 104 for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
  105 + auto &r = (*result)[i];
105 auto y = static_cast<int32_t>(std::distance( 106 auto y = static_cast<int32_t>(std::distance(
106 static_cast<const float *>(p_logit), 107 static_cast<const float *>(p_logit),
107 std::max_element(static_cast<const float *>(p_logit), 108 std::max_element(static_cast<const float *>(p_logit),
108 static_cast<const float *>(p_logit) + vocab_size))); 109 static_cast<const float *>(p_logit) + vocab_size)));
109 if (y != 0) { 110 if (y != 0) {
110 emitted = true; 111 emitted = true;
111 - (*result)[i].tokens.push_back(y);  
112 - (*result)[i].num_trailing_blanks = 0; 112 + r.tokens.push_back(y);
  113 + r.timestamps.push_back(t + r.frame_offset);
  114 + r.num_trailing_blanks = 0;
113 } else { 115 } else {
114 - ++(*result)[i].num_trailing_blanks; 116 + ++r.num_trailing_blanks;
115 } 117 }
116 } 118 }
117 if (emitted) { 119 if (emitted) {
@@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -121,6 +123,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
121 } 123 }
122 124
123 UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result); 125 UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
  126 +
  127 + // Update frame_offset
  128 + for (auto &r : *result) {
  129 + r.frame_offset += num_frames;
  130 + }
124 } 131 }
125 132
126 } // namespace sherpa_onnx 133 } // namespace sherpa_onnx
@@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( @@ -87,6 +87,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks(
87 87
88 std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end()); 88 std::vector<int64_t> tokens(hyp.ys.begin() + context_size, hyp.ys.end());
89 r->tokens = std::move(tokens); 89 r->tokens = std::move(tokens);
  90 + r->timestamps = std::move(hyp.timestamps);
90 r->num_trailing_blanks = hyp.num_trailing_blanks; 91 r->num_trailing_blanks = hyp.num_trailing_blanks;
91 } 92 }
92 93
@@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -148,6 +149,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
148 float *p_logit = logit.GetTensorMutableData<float>(); 149 float *p_logit = logit.GetTensorMutableData<float>();
149 150
150 for (int32_t b = 0; b < batch_size; ++b) { 151 for (int32_t b = 0; b < batch_size; ++b) {
  152 + int32_t frame_offset = (*result)[b].frame_offset;
151 int32_t start = hyps_num_split[b]; 153 int32_t start = hyps_num_split[b];
152 int32_t end = hyps_num_split[b + 1]; 154 int32_t end = hyps_num_split[b + 1];
153 LogSoftmax(p_logit, vocab_size, (end - start)); 155 LogSoftmax(p_logit, vocab_size, (end - start));
@@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -162,6 +164,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
162 Hypothesis new_hyp = prev[hyp_index]; 164 Hypothesis new_hyp = prev[hyp_index];
163 if (new_token != 0) { 165 if (new_token != 0) {
164 new_hyp.ys.push_back(new_token); 166 new_hyp.ys.push_back(new_token);
  167 + new_hyp.timestamps.push_back(t + frame_offset);
165 new_hyp.num_trailing_blanks = 0; 168 new_hyp.num_trailing_blanks = 0;
166 } else { 169 } else {
167 ++new_hyp.num_trailing_blanks; 170 ++new_hyp.num_trailing_blanks;
@@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -177,10 +180,12 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
177 for (int32_t b = 0; b != batch_size; ++b) { 180 for (int32_t b = 0; b != batch_size; ++b) {
178 auto &hyps = cur[b]; 181 auto &hyps = cur[b];
179 auto best_hyp = hyps.GetMostProbable(true); 182 auto best_hyp = hyps.GetMostProbable(true);
  183 + auto &r = (*result)[b];
180 184
181 - (*result)[b].hyps = std::move(hyps);  
182 - (*result)[b].tokens = std::move(best_hyp.ys);  
183 - (*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks; 185 + r.hyps = std::move(hyps);
  186 + r.tokens = std::move(best_hyp.ys);
  187 + r.num_trailing_blanks = best_hyp.num_trailing_blanks;
  188 + r.frame_offset += num_frames;
184 } 189 }
185 } 190 }
186 191
@@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() { @@ -196,7 +196,7 @@ void OnlineWebsocketDecoder::Decode() {
196 auto result = recognizer_->GetResult(c->s.get()); 196 auto result = recognizer_->GetResult(c->s.get());
197 197
198 asio::post(server_->GetConnectionContext(), 198 asio::post(server_->GetConnectionContext(),
199 - [this, hdl = c->hdl, str = result.ToString()]() { 199 + [this, hdl = c->hdl, str = result.AsJsonString()]() {
200 server_->Send(hdl, str); 200 server_->Send(hdl, str);
201 }); 201 });
202 active_.erase(c->hdl); 202 active_.erase(c->hdl);
@@ -102,7 +102,7 @@ for a list of pre-trained models to download. @@ -102,7 +102,7 @@ for a list of pre-trained models to download.
102 recognizer.DecodeStream(s.get()); 102 recognizer.DecodeStream(s.get());
103 } 103 }
104 104
105 - std::string text = recognizer.GetResult(s.get()).text; 105 + std::string text = recognizer.GetResult(s.get()).AsJsonString();
106 106
107 fprintf(stderr, "Done!\n"); 107 fprintf(stderr, "Done!\n");
108 108
@@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult( @@ -434,7 +434,7 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(
434 sherpa_onnx::OnlineStream *s = 434 sherpa_onnx::OnlineStream *s =
435 reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr); 435 reinterpret_cast<sherpa_onnx::OnlineStream *>(s_ptr);
436 sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s); 436 sherpa_onnx::OnlineRecognizerResult result = model->GetResult(s);
437 - return env->NewStringUTF(result.ToString().c_str()); 437 + return env->NewStringUTF(result.text.c_str());
438 } 438 }
439 439
440 SHERPA_ONNX_EXTERN_C 440 SHERPA_ONNX_EXTERN_C