online-websocket-server-impl.h
5.0 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
// sherpa-onnx/csrc/online-websocket-server-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_
#include <chrono>  // NOLINT
#include <cstdint>
#include <deque>
#include <fstream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "asio.hpp"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/tee-stream.h"
#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
#include "websocketpp/server.hpp"
// Convenience aliases for the websocketpp types used throughout this header.
// NOTE: both aliases live at global scope (outside namespace sherpa_onnx), so
// every translation unit that includes this header sees them unqualified.
//
// Websocket server endpoint type built on websocketpp's asio config
// (the non-TLS config header is the one included above).
using server = websocketpp::server<websocketpp::config::asio>;
// Handle referring to a single client connection. It is used below as the key
// of std::map/std::set containers ordered with std::owner_less.
using connection_hdl = websocketpp::connection_hdl;
namespace sherpa_onnx {
// State kept for one websocket client: its connection handle, its
// recognition stream and the audio samples received but not yet consumed.
struct Connection {
  // handle to the connection. We can use it to send messages to the client
  connection_hdl hdl;

  // The recognition stream associated with this client.
  std::shared_ptr<OnlineStream> s;

  // set it to true when InputFinished() is called
  bool eof = false;

  // The last time we received a message from the client
  // TODO(fangjun): Use it to disconnect from a client if it is inactive
  // for a specified time.
  //
  // Note: a default-constructed Connection leaves this at the clock's epoch;
  // only the two-argument constructor stamps it with the current time.
  std::chrono::steady_clock::time_point last_active;

  std::mutex mutex;  // protect samples

  // Audio samples received from the client.
  //
  // The I/O threads receive audio samples into this queue
  // and invoke work threads to compute features
  std::deque<std::vector<float>> samples;

  Connection() = default;

  // @param hdl Handle of the client connection.
  // @param s   Stream used to recognize the audio of this client.
  //
  // Both arguments are sink parameters: they are taken by value and moved
  // into the members, so no extra reference-count round trip is paid on top
  // of the copy (or move) made at the call site.
  Connection(connection_hdl hdl, std::shared_ptr<OnlineStream> s)
      : hdl(std::move(hdl)),
        s(std::move(s)),
        last_active(std::chrono::steady_clock::now()) {}
};
// Options for the decoding side of the websocket server.
struct OnlineWebsocketDecoderConfig {
  // Configuration forwarded to the online recognizer.
  OnlineRecognizerConfig recognizer_config;

  // Period, in milliseconds, of the decoder loop: it determines how often
  // the loop runs.
  int32_t loop_interval_ms{10};

  // NOTE(review): presumably the maximum number of streams decoded together
  // in one batch -- confirm against the implementation file.
  int32_t max_batch_size{5};

  // Registers the options of this struct with the option parser `po`.
  void Register(ParseOptions *po);

  // Validates the configured values (defined in the implementation file).
  void Validate() const;
};
// Forward declaration: OnlineWebsocketDecoder only stores a pointer to the
// server, whose full definition appears later in this header.
class OnlineWebsocketServer;
// Keeps track of the per-client Connection objects and of the connections
// that are ready for, or currently undergoing, decoding.
//
// The method definitions are not in this header.
class OnlineWebsocketDecoder {
public:
/**
* @param server Not owned.
*/
explicit OnlineWebsocketDecoder(OnlineWebsocketServer *server);
// NOTE(review): judging by the name, returns the Connection already
// associated with `hdl` or creates one if there is none -- confirm in the
// implementation file.
std::shared_ptr<Connection> GetOrCreateConnection(connection_hdl hdl);
// Compute features for a stream given audio samples
void AcceptWaveform(std::shared_ptr<Connection> c);
// Signal that there will be no more audio samples for a stream
void InputFinished(std::shared_ptr<Connection> c);
// NOTE(review): presumably starts the periodic decoder loop driven by
// `timer_` -- confirm in the implementation file.
void Run();
private:
// Takes an asio::error_code, i.e. it has the signature of an asio wait
// handler. NOTE(review): presumably the callback of `timer_` -- confirm.
void ProcessConnections(const asio::error_code &ec);
/** It is called by one of the worker threads.
*/
void Decode();
private:
OnlineWebsocketServer *server_; // not owned
std::unique_ptr<OnlineRecognizer> recognizer_;
OnlineWebsocketDecoderConfig config_;
asio::steady_timer timer_;
// It protects `connections_`, `ready_connections_`, and `active_`
std::mutex mutex_;
// Maps a connection handle to its Connection object. Handles are ordered
// with std::owner_less.
std::map<connection_hdl, std::shared_ptr<Connection>,
std::owner_less<connection_hdl>>
connections_;
// Whenever a connection has enough feature frames for decoding, we put
// it in this queue
std::deque<std::shared_ptr<Connection>> ready_connections_;
// If we are decoding a stream, we put it in the active_ set so that
// only one thread can decode a stream at a time.
std::set<connection_hdl, std::owner_less<connection_hdl>> active_;
};
struct OnlineWebsocketServerConfig {
OnlineWebsocketDecoderConfig decoder_config;
std::string log_file = "./log.txt";
void Register(sherpa_onnx::ParseOptions *po);
void Validate() const;
};
// Websocket server: owns the websocketpp endpoint, the log sinks and the
// decoder, and keeps the set of currently known connection handles.
//
// The method definitions are not in this header.
class OnlineWebsocketServer {
public:
// @param io_conn Context returned by GetConnectionContext(); stored by
//                reference, so it must outlive this object.
// @param io_work Context returned by GetWorkContext(); stored by reference,
//                so it must outlive this object.
// @param config  Server configuration.
explicit OnlineWebsocketServer(asio::io_context &io_conn, // NOLINT
asio::io_context &io_work, // NOLINT
const OnlineWebsocketServerConfig &config);
// NOTE(review): presumably starts listening for clients on `port` --
// confirm in the implementation file.
void Run(uint16_t port);
// Accessors for the configuration, the two io contexts and the endpoint.
const OnlineWebsocketServerConfig &GetConfig() const { return config_; }
asio::io_context &GetConnectionContext() { return io_conn_; }
asio::io_context &GetWorkContext() { return io_work_; }
server &GetServer() { return server_; }
// NOTE(review): presumably sends `text` to the client referred to by `hdl`
// -- confirm in the implementation file.
void Send(connection_hdl hdl, const std::string &text);
// NOTE(review): presumably reports whether `hdl` is in `connections_`; it is
// const while `mutex_` is mutable, which is consistent with locking inside
// -- confirm in the implementation file.
bool Contains(connection_hdl hdl) const;
private:
// NOTE(review): presumably wires up `log_` and `tee_` -- confirm.
void SetupLog();
// When a websocket client is connected, it will invoke this method
// (Not for HTTP)
void OnOpen(connection_hdl hdl);
// When a websocket client is disconnected, it will invoke this method
void OnClose(connection_hdl hdl);
// NOTE(review): presumably invoked for every message received from a client
// -- confirm in the implementation file.
void OnMessage(connection_hdl hdl, server::message_ptr msg);
// Close a websocket connection with given code and reason
void Close(connection_hdl hdl, websocketpp::close::status::value code,
const std::string &reason);
private:
// NOTE: members are constructed in declaration order; `decoder_` is declared
// after `server_`, `log_` and `tee_`. Do not reorder without checking the
// constructor.
OnlineWebsocketServerConfig config_;
asio::io_context &io_conn_;
asio::io_context &io_work_;
server server_;
std::ofstream log_;
sherpa_onnx::TeeStream tee_;
OnlineWebsocketDecoder decoder_;
// Declared mutable so that it can be locked from const member functions.
// NOTE(review): presumably protects `connections_` -- confirm.
mutable std::mutex mutex_;
// Handles of the known connections, ordered with std::owner_less.
std::set<connection_hdl, std::owner_less<connection_hdl>> connections_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_