Fangjun Kuang
Committed by GitHub

add endpointing for online websocket server (#294)

@@ -26,7 +26,8 @@ namespace sherpa_onnx { @@ -26,7 +26,8 @@ namespace sherpa_onnx {
26 static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, 26 static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
27 const SymbolTable &sym_table, 27 const SymbolTable &sym_table,
28 int32_t frame_shift_ms, 28 int32_t frame_shift_ms,
29 - int32_t subsampling_factor) { 29 + int32_t subsampling_factor,
  30 + int32_t segment) {
30 OnlineRecognizerResult r; 31 OnlineRecognizerResult r;
31 r.tokens.reserve(src.tokens.size()); 32 r.tokens.reserve(src.tokens.size());
32 r.timestamps.reserve(src.tokens.size()); 33 r.timestamps.reserve(src.tokens.size());
@@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, @@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
44 r.timestamps.push_back(time); 45 r.timestamps.push_back(time);
45 } 46 }
46 47
  48 + r.segment = segment;
  49 +
47 return r; 50 return r;
48 } 51 }
49 52
@@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
192 // TODO(fangjun): Remember to change these constants if needed 195 // TODO(fangjun): Remember to change these constants if needed
193 int32_t frame_shift_ms = 10; 196 int32_t frame_shift_ms = 10;
194 int32_t subsampling_factor = 4; 197 int32_t subsampling_factor = 4;
195 - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor); 198 + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
  199 + s->GetCurrentSegment());
196 } 200 }
197 201
198 bool IsEndpoint(OnlineStream *s) const override { 202 bool IsEndpoint(OnlineStream *s) const override {
@@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
213 } 217 }
214 218
215 void Reset(OnlineStream *s) const override { 219 void Reset(OnlineStream *s) const override {
  220 + {
  221 + // segment is incremented only when the last
  222 + // result is not empty
  223 + const auto &r = s->GetResult();
  224 + if (!r.tokens.empty() && r.tokens.back() != 0) {
  225 + s->GetCurrentSegment() += 1;
  226 + }
  227 + }
  228 +
216 // we keep the decoder_out 229 // we keep the decoder_out
217 decoder_->UpdateDecoderOut(&s->GetResult()); 230 decoder_->UpdateDecoderOut(&s->GetResult());
218 Ort::Value decoder_out = std::move(s->GetResult().decoder_out); 231 Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
@@ -43,6 +43,8 @@ class OnlineStream::Impl { @@ -43,6 +43,8 @@ class OnlineStream::Impl {
43 43
44 int32_t &GetNumProcessedFrames() { return num_processed_frames_; } 44 int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
45 45
  46 + int32_t &GetCurrentSegment() { return segment_; }
  47 +
46 void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } 48 void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
47 49
48 OnlineTransducerDecoderResult &GetResult() { return result_; } 50 OnlineTransducerDecoderResult &GetResult() { return result_; }
@@ -83,6 +85,7 @@ class OnlineStream::Impl { @@ -83,6 +85,7 @@ class OnlineStream::Impl {
83 ContextGraphPtr context_graph_; 85 ContextGraphPtr context_graph_;
84 int32_t num_processed_frames_ = 0; // before subsampling 86 int32_t num_processed_frames_ = 0; // before subsampling
85 int32_t start_frame_index_ = 0; // never reset 87 int32_t start_frame_index_ = 0; // never reset
  88 + int32_t segment_ = 0;
86 OnlineTransducerDecoderResult result_; 89 OnlineTransducerDecoderResult result_;
87 std::vector<Ort::Value> states_; 90 std::vector<Ort::Value> states_;
88 std::vector<float> paraformer_feat_cache_; 91 std::vector<float> paraformer_feat_cache_;
@@ -123,6 +126,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() { @@ -123,6 +126,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() {
123 return impl_->GetNumProcessedFrames(); 126 return impl_->GetNumProcessedFrames();
124 } 127 }
125 128
  129 +int32_t &OnlineStream::GetCurrentSegment() {
  130 + return impl_->GetCurrentSegment();
  131 +}
  132 +
126 void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { 133 void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
127 impl_->SetResult(r); 134 impl_->SetResult(r);
128 } 135 }
@@ -68,6 +68,8 @@ class OnlineStream { @@ -68,6 +68,8 @@ class OnlineStream {
68 // The returned reference is valid as long as this object is alive. 68 // The returned reference is valid as long as this object is alive.
69 int32_t &GetNumProcessedFrames(); 69 int32_t &GetNumProcessedFrames();
70 70
  71 + int32_t &GetCurrentSegment();
  72 +
71 void SetResult(const OnlineTransducerDecoderResult &r); 73 void SetResult(const OnlineTransducerDecoderResult &r);
72 OnlineTransducerDecoderResult &GetResult(); 74 OnlineTransducerDecoderResult &GetResult();
73 75
@@ -194,6 +194,9 @@ void OnlineWebsocketDecoder::Decode() { @@ -194,6 +194,9 @@ void OnlineWebsocketDecoder::Decode() {
194 194
195 for (auto c : c_vec) { 195 for (auto c : c_vec) {
196 auto result = recognizer_->GetResult(c->s.get()); 196 auto result = recognizer_->GetResult(c->s.get());
  197 + if (recognizer_->IsEndpoint(c->s.get())) {
  198 + recognizer_->Reset(c->s.get());
  199 + }
197 200
198 asio::post(server_->GetConnectionContext(), 201 asio::post(server_->GetConnectionContext(),
199 [this, hdl = c->hdl, str = result.AsJsonString()]() { 202 [this, hdl = c->hdl, str = result.AsJsonString()]() {