Committed by
GitHub
add endpointing for online websocket server (#294)
正在显示
4 个修改的文件
包含
27 行增加
和
2 行删除
| @@ -26,7 +26,8 @@ namespace sherpa_onnx { | @@ -26,7 +26,8 @@ namespace sherpa_onnx { | ||
| 26 | static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | 26 | static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, |
| 27 | const SymbolTable &sym_table, | 27 | const SymbolTable &sym_table, |
| 28 | int32_t frame_shift_ms, | 28 | int32_t frame_shift_ms, |
| 29 | - int32_t subsampling_factor) { | 29 | + int32_t subsampling_factor, |
| 30 | + int32_t segment) { | ||
| 30 | OnlineRecognizerResult r; | 31 | OnlineRecognizerResult r; |
| 31 | r.tokens.reserve(src.tokens.size()); | 32 | r.tokens.reserve(src.tokens.size()); |
| 32 | r.timestamps.reserve(src.tokens.size()); | 33 | r.timestamps.reserve(src.tokens.size()); |
| @@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | @@ -44,6 +45,8 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, | ||
| 44 | r.timestamps.push_back(time); | 45 | r.timestamps.push_back(time); |
| 45 | } | 46 | } |
| 46 | 47 | ||
| 48 | + r.segment = segment; | ||
| 49 | + | ||
| 47 | return r; | 50 | return r; |
| 48 | } | 51 | } |
| 49 | 52 | ||
| @@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -192,7 +195,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 192 | // TODO(fangjun): Remember to change these constants if needed | 195 | // TODO(fangjun): Remember to change these constants if needed |
| 193 | int32_t frame_shift_ms = 10; | 196 | int32_t frame_shift_ms = 10; |
| 194 | int32_t subsampling_factor = 4; | 197 | int32_t subsampling_factor = 4; |
| 195 | - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor); | 198 | + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, |
| 199 | + s->GetCurrentSegment()); | ||
| 196 | } | 200 | } |
| 197 | 201 | ||
| 198 | bool IsEndpoint(OnlineStream *s) const override { | 202 | bool IsEndpoint(OnlineStream *s) const override { |
| @@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -213,6 +217,15 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 213 | } | 217 | } |
| 214 | 218 | ||
| 215 | void Reset(OnlineStream *s) const override { | 219 | void Reset(OnlineStream *s) const override { |
| 220 | + { | ||
| 221 | + // segment is incremented only when the last | ||
| 222 | + // result is not empty | ||
| 223 | + const auto &r = s->GetResult(); | ||
| 224 | + if (!r.tokens.empty() && r.tokens.back() != 0) { | ||
| 225 | + s->GetCurrentSegment() += 1; | ||
| 226 | + } | ||
| 227 | + } | ||
| 228 | + | ||
| 216 | // we keep the decoder_out | 229 | // we keep the decoder_out |
| 217 | decoder_->UpdateDecoderOut(&s->GetResult()); | 230 | decoder_->UpdateDecoderOut(&s->GetResult()); |
| 218 | Ort::Value decoder_out = std::move(s->GetResult().decoder_out); | 231 | Ort::Value decoder_out = std::move(s->GetResult().decoder_out); |
| @@ -43,6 +43,8 @@ class OnlineStream::Impl { | @@ -43,6 +43,8 @@ class OnlineStream::Impl { | ||
| 43 | 43 | ||
| 44 | int32_t &GetNumProcessedFrames() { return num_processed_frames_; } | 44 | int32_t &GetNumProcessedFrames() { return num_processed_frames_; } |
| 45 | 45 | ||
| 46 | + int32_t &GetCurrentSegment() { return segment_; } | ||
| 47 | + | ||
| 46 | void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } | 48 | void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } |
| 47 | 49 | ||
| 48 | OnlineTransducerDecoderResult &GetResult() { return result_; } | 50 | OnlineTransducerDecoderResult &GetResult() { return result_; } |
| @@ -83,6 +85,7 @@ class OnlineStream::Impl { | @@ -83,6 +85,7 @@ class OnlineStream::Impl { | ||
| 83 | ContextGraphPtr context_graph_; | 85 | ContextGraphPtr context_graph_; |
| 84 | int32_t num_processed_frames_ = 0; // before subsampling | 86 | int32_t num_processed_frames_ = 0; // before subsampling |
| 85 | int32_t start_frame_index_ = 0; // never reset | 87 | int32_t start_frame_index_ = 0; // never reset |
| 88 | + int32_t segment_ = 0; | ||
| 86 | OnlineTransducerDecoderResult result_; | 89 | OnlineTransducerDecoderResult result_; |
| 87 | std::vector<Ort::Value> states_; | 90 | std::vector<Ort::Value> states_; |
| 88 | std::vector<float> paraformer_feat_cache_; | 91 | std::vector<float> paraformer_feat_cache_; |
| @@ -123,6 +126,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() { | @@ -123,6 +126,10 @@ int32_t &OnlineStream::GetNumProcessedFrames() { | ||
| 123 | return impl_->GetNumProcessedFrames(); | 126 | return impl_->GetNumProcessedFrames(); |
| 124 | } | 127 | } |
| 125 | 128 | ||
| 129 | +int32_t &OnlineStream::GetCurrentSegment() { | ||
| 130 | + return impl_->GetCurrentSegment(); | ||
| 131 | +} | ||
| 132 | + | ||
| 126 | void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { | 133 | void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { |
| 127 | impl_->SetResult(r); | 134 | impl_->SetResult(r); |
| 128 | } | 135 | } |
| @@ -68,6 +68,8 @@ class OnlineStream { | @@ -68,6 +68,8 @@ class OnlineStream { | ||
| 68 | // The returned reference is valid as long as this object is alive. | 68 | // The returned reference is valid as long as this object is alive. |
| 69 | int32_t &GetNumProcessedFrames(); | 69 | int32_t &GetNumProcessedFrames(); |
| 70 | 70 | ||
| 71 | + int32_t &GetCurrentSegment(); | ||
| 72 | + | ||
| 71 | void SetResult(const OnlineTransducerDecoderResult &r); | 73 | void SetResult(const OnlineTransducerDecoderResult &r); |
| 72 | OnlineTransducerDecoderResult &GetResult(); | 74 | OnlineTransducerDecoderResult &GetResult(); |
| 73 | 75 |
| @@ -194,6 +194,9 @@ void OnlineWebsocketDecoder::Decode() { | @@ -194,6 +194,9 @@ void OnlineWebsocketDecoder::Decode() { | ||
| 194 | 194 | ||
| 195 | for (auto c : c_vec) { | 195 | for (auto c : c_vec) { |
| 196 | auto result = recognizer_->GetResult(c->s.get()); | 196 | auto result = recognizer_->GetResult(c->s.get()); |
| 197 | + if (recognizer_->IsEndpoint(c->s.get())) { | ||
| 198 | + recognizer_->Reset(c->s.get()); | ||
| 199 | + } | ||
| 197 | 200 | ||
| 198 | asio::post(server_->GetConnectionContext(), | 201 | asio::post(server_->GetConnectionContext(), |
| 199 | [this, hdl = c->hdl, str = result.AsJsonString()]() { | 202 | [this, hdl = c->hdl, str = result.AsJsonString()]() { |
-
请 注册 或 登录 后发表评论