offline-websocket-server-impl.h
6.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
// sherpa-onnx/csrc/offline-websocket-server-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
#include <cstdint>
#include <deque>
#include <fstream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/tee-stream.h"
#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS
#include "websocketpp/server.hpp"
// Handle identifying a single websocket connection.
using connection_hdl = websocketpp::connection_hdl;
// Websocket server type built on websocketpp's non-TLS asio configuration
// (websocketpp/config/asio_no_tls.hpp).
using server = websocketpp::server<websocketpp::config::asio>;
namespace sherpa_onnx {
/** Communication protocol
*
* The client sends a byte stream to the server. The first 4 bytes, in little
* endian, give the sample rate of the audio data that the client will send.
* The next 4 bytes, in little endian, give the total size in bytes of the
* audio samples the client will send. The remaining bytes are the audio
* samples. Each audio sample is a float occupying 4 bytes and is normalized
* into the range [-1, 1].
*
* The byte stream can be broken into an arbitrary number of messages.
* We require the first message to be at least 8 bytes so that
* we can get `sample_rate` and `expected_byte_size` from the first message.
*/
// Per-connection state that accumulates the bytes of one utterance sent by
// a client.
struct ConnectionData {
  // Sample rate of the audio samples sent by the client.
  // Defaults to 0 (and is reset to 0 by Clear()) so that a freshly
  // default-initialized object never holds an indeterminate value.
  int32_t sample_rate = 0;

  // Number of bytes of audio samples the client said it will send.
  int32_t expected_byte_size = 0;

  // Number of bytes received so far.
  int32_t cur = 0;

  // Raw bytes of the samples received from the client.
  // They will be **reinterpret_cast** to float, so we expect
  // data.size() == expected_byte_size once everything has arrived.
  std::vector<int8_t> data;

  // Reset every field to its initial value.
  // data.clear() keeps the allocated capacity of the buffer.
  void Clear() {
    sample_rate = 0;
    expected_byte_size = 0;
    cur = 0;
    data.clear();
  }
};

// Shared handle to the data of one connection.
using ConnectionDataPtr = std::shared_ptr<ConnectionData>;
// Options of the offline websocket decoder.
struct OfflineWebsocketDecoderConfig {
  // Configuration of the offline recognizer that performs the decoding.
  OfflineRecognizerConfig recognizer_config;

  // Maximum number of queued utterances decoded together in one batch
  // (`--max-batch-size`).
  int32_t max_batch_size{5};

  // Maximum utterance length, in seconds.
  float max_utterance_length{300.0F};

  // Register the options of this config with `po`.
  void Register(ParseOptions *po);

  // Check the option values.
  void Validate() const;
};
class OfflineWebsocketServer;

// Queues utterances that have been completely received by an
// OfflineWebsocketServer and decodes them in batches with an
// OfflineRecognizer. Results are sent back to the clients through the
// connection handles queued alongside the data.
class OfflineWebsocketDecoder {
 public:
  /**
   * @param server **Borrowed** from outside. The pointer is stored in
   *               `server_` and is not owned by the decoder.
   */
  explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server);

  /** Insert a completely received utterance into the queue for decoding.
   *
   * @param hdl A handle to the connection. We can use it to send the result
   *            back to the client once it finishes decoding.
   * @param d The received data
   */
  void Push(connection_hdl hdl, ConnectionDataPtr d);

  /** Decode one batch of queued utterances.
   *
   * It is called by one of the worker threads. The batch size is bounded by
   * `--max-batch-size`; if fewer items are queued, whatever is available is
   * decoded without waiting for more.
   */
  void Decode();

  const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; }

 private:
  OfflineWebsocketDecoderConfig config_;

  // NOTE(review): presumably protects `streams_`, which is filled by Push()
  // and drained by Decode() on worker threads -- confirm in the .cc file.
  std::mutex mutex_;

  /** When we have received all the data from the client, we put it into
   * this queue; the worker threads will get items from this queue for
   * decoding.
   *
   * Number of items to take from this queue is determined by
   * `--max-batch-size`. If there are not enough items in the queue, we won't
   * wait and take whatever we have for decoding.
   */
  std::deque<std::pair<connection_hdl, ConnectionDataPtr>> streams_;

  OfflineWebsocketServer *server_;  // Not owned
  OfflineRecognizer recognizer_;
};
// Top-level configuration of OfflineWebsocketServer.
struct OfflineWebsocketServerConfig {
  // Options of the decoder that recognizes the received utterances.
  OfflineWebsocketDecoderConfig decoder_config;

  // Path of the log file (presumably opened by
  // OfflineWebsocketServer::SetupLog()).
  std::string log_file{"./log.txt"};

  // Register the options of this config with `po`.
  void Register(ParseOptions *po);

  // Check the configuration.
  void Validate() const;
};
// Websocket server that receives complete utterances from clients and hands
// them to an OfflineWebsocketDecoder for recognition.
class OfflineWebsocketServer {
 public:
  // NOTE(review): judging by the names, `io_conn` presumably drives the
  // network connections and `io_work` the decoding work -- confirm in the .cc.
  OfflineWebsocketServer(asio::io_context &io_conn,  // NOLINT
                         asio::io_context &io_work,  // NOLINT
                         const OfflineWebsocketServerConfig &config);

  asio::io_context &GetConnectionContext() { return io_conn_; }
  server &GetServer() { return server_; }

  void Run(uint16_t port);

  const OfflineWebsocketServerConfig &GetConfig() const { return config_; }

 private:
  void SetupLog();

  // When a websocket client is connected, it will invoke this method
  // (Not for HTTP)
  void OnOpen(connection_hdl hdl);

  // When a websocket client is disconnected, it will invoke this method
  void OnClose(connection_hdl hdl);

  // When a message is received from a websocket client, this method will
  // be invoked.
  //
  // The protocol between the client and the server is as follows:
  //
  // (1) The client connects to the server.
  // (2) The client starts to send a binary byte stream to the server.
  //     The byte stream can be broken into multiple messages or it can
  //     be put into a single message.
  //     The first message has to contain at least 8 bytes. The first
  //     4 bytes, in little endian, contain an int32_t indicating the
  //     sampling rate. The next 4 bytes, in little endian, contain an int32_t
  //     indicating the total number of bytes of samples the client will send.
  //     We assume each sample is a float occupying 4 bytes that has been
  //     normalized to the range [-1, 1].
  // (3) When the server has received all the samples from the client, it
  //     starts to decode them. Once decoded, the server sends a text message
  //     to the client containing the decoded results.
  // (4) After receiving the decoded results from the server, if the client
  //     has another audio file to send, it repeats (2) and (3).
  // (5) If the client has no more audio files to decode, the client sends a
  //     text message containing "Done" to the server and closes the
  //     connection.
  // (6) The server receives the text message "Done" and closes the
  //     connection.
  //
  // Note:
  //  (a) All models in icefall use features extracted from audio samples
  //      normalized to the range [-1, 1]. Please send normalized audio samples
  //      if you use models from icefall.
  //  (b) Only sound files with a single channel are supported.
  //  (c) Only audio samples are sent. For instance, if we want to decode
  //      a WAVE file, the RIFF header of the WAVE file is not sent.
  void OnMessage(connection_hdl hdl, server::message_ptr msg);

  // Close a websocket connection with given code and reason
  void Close(connection_hdl hdl, websocketpp::close::status::value code,
             const std::string &reason);

  // NOTE: member declaration order determines construction order; it is
  // kept unchanged from the original layout.
  asio::io_context &io_conn_;
  asio::io_context &io_work_;
  server server_;

  // Data of every open connection, keyed by connection handle (compared by
  // owner, so handles to the same connection map to the same entry).
  std::map<connection_hdl, ConnectionDataPtr, std::owner_less<connection_hdl>>
      connections_;
  // NOTE(review): presumably guards `connections_` -- confirm in the .cc.
  std::mutex mutex_;

  OfflineWebsocketServerConfig config_;

  std::ofstream log_;
  TeeStream tee_;

  // NOTE(review): declared last, after `config_`. If its constructor reads
  // the server's config through the pointer it receives, this order matters
  // -- do not move it before `config_`.
  OfflineWebsocketDecoder decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_