offline-websocket-server-impl.h 6.7 KB
// sherpa-onnx/csrc/offline-websocket-server-impl.h
//
// Copyright (c)  2022-2023  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_

#include <deque>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/tee-stream.h"
#include "websocketpp/config/asio_no_tls.hpp"  // TODO(fangjun): support TLS
#include "websocketpp/server.hpp"

using server = websocketpp::server<websocketpp::config::asio>;
using connection_hdl = websocketpp::connection_hdl;

namespace sherpa_onnx {

/** Communication protocol
 *
 * The client sends a byte stream to the server. The first 4 bytes in little
 * endian indicates the sample rate of the audio data that the client will send.
 * The next 4 bytes in little endian indicates the total samples in bytes the
 * client will send. The remaining bytes represent audio samples. Each audio
 * sample is a float occupying 4 bytes and is normalized into the range
 * [-1, 1].
 *
 * The byte stream can be broken into arbitrary number of messages.
 * We require that the first message has to be at least 8 bytes so that
 * we can get `sample_rate` and `expected_byte_size` from the first message.
 */
struct ConnectionData {
  // Sample rate of the audio samples the client
  int32_t sample_rate;

  // Number of expected bytes sent from the client
  int32_t expected_byte_size = 0;

  // Number of bytes received so far
  int32_t cur = 0;

  // It saves the received samples from the client.
  // We will **reinterpret_cast** it to float.
  // We expect that data.size() == expected_byte_size
  std::vector<int8_t> data;

  void Clear() {
    sample_rate = 0;
    expected_byte_size = 0;
    cur = 0;
    data.clear();
  }
};

using ConnectionDataPtr = std::shared_ptr<ConnectionData>;

struct OfflineWebsocketDecoderConfig {
  OfflineRecognizerConfig recognizer_config;

  int32_t max_batch_size = 5;

  float max_utterance_length = 300;  // seconds

  void Register(ParseOptions *po);
  void Validate() const;
};

class OfflineWebsocketServer;

class OfflineWebsocketDecoder {
 public:
  /**
   * @param config Configuration for the decoder.
   * @param server **Borrowed** from outside.
   */
  explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server);

  /** Insert received data to the queue for decoding.
   *
   * @param hdl A handle to the connection. We can use it to send the result
   *            back to the client once it finishes decoding.
   * @param d  The received data
   */
  void Push(connection_hdl hdl, ConnectionDataPtr d);

  /** It is called by one of the work thread.
   */
  void Decode();

  const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; }

 private:
  OfflineWebsocketDecoderConfig config_;

  /** When we have received all the data from the client, we put it into
   * this queue; the worker threads will get items from this queue for
   * decoding.
   *
   * Number of items to take from this queue is determined by
   * `--max-batch-size`. If there are not enough items in the queue, we won't
   * wait and take whatever we have for decoding.
   */
  std::mutex mutex_;
  std::deque<std::pair<connection_hdl, ConnectionDataPtr>> streams_;

  OfflineWebsocketServer *server_;  // Not owned
  OfflineRecognizer recognizer_;
};

struct OfflineWebsocketServerConfig {
  OfflineWebsocketDecoderConfig decoder_config;
  std::string log_file = "./log.txt";

  void Register(ParseOptions *po);
  void Validate() const;
};

class OfflineWebsocketServer {
 public:
  OfflineWebsocketServer(asio::io_context &io_conn,  // NOLINT
                         asio::io_context &io_work,  // NOLINT
                         const OfflineWebsocketServerConfig &config);

  asio::io_context &GetConnectionContext() { return io_conn_; }
  server &GetServer() { return server_; }

  void Run(uint16_t port);

  const OfflineWebsocketServerConfig &GetConfig() const { return config_; }

 private:
  void SetupLog();

  // When a websocket client is connected, it will invoke this method
  // (Not for HTTP)
  void OnOpen(connection_hdl hdl);

  // When a websocket client is disconnected, it will invoke this method
  void OnClose(connection_hdl hdl);

  // When a message is received from a websocket client, this method will
  // be invoked.
  //
  // The protocol between the client and the server is as follows:
  //
  // (1) The client connects to the server
  // (2) The client starts to send binary byte stream to the server.
  //     The byte stream can be broken into multiple messages or it can
  //     be put into a single message.
  //     The first message has to contain at least 8 bytes. The first
  //     4 bytes in little endian contains a int32_t indicating the
  //     sampling rate. The next 4 bytes in little endian contains a int32_t
  //     indicating total number of bytes of samples the client will send.
  //     We assume each sample is a float containing 4 bytes and has been
  //     normalized to the range [-1, 1].
  // (4) When the server receives all the samples from the client, it will
  //     start to decode them. Once decoded, the server sends a text message
  //     to the client containing the decoded results
  // (5) After receiving the decoded results from the server, if the client has
  //     another audio file to send, it repeats (2), (3), (4)
  // (6) If the client has no more audio files to decode, the client sends a
  //     text message containing "Done" to the server and closes the connection
  // (7) The server receives a text message "Done" and closes the connection
  //
  // Note:
  //  (a) All models in icefall use features extracted from audio samples
  //      normalized to the range [-1, 1]. Please send normalized audio samples
  //      if you use models from icefall.
  //  (b) Only sound files with a single channel is supported
  //  (c) Only audio samples are sent. For instance, if we want to decode
  //      a WAVE file, the RIFF header of the WAVE is not sent.
  void OnMessage(connection_hdl hdl, server::message_ptr msg);

  // Close a websocket connection with given code and reason
  void Close(connection_hdl hdl, websocketpp::close::status::value code,
             const std::string &reason);

 private:
  asio::io_context &io_conn_;
  asio::io_context &io_work_;
  server server_;

  std::map<connection_hdl, ConnectionDataPtr, std::owner_less<connection_hdl>>
      connections_;
  std::mutex mutex_;

  OfflineWebsocketServerConfig config_;

  std::ofstream log_;
  TeeStream tee_;

  OfflineWebsocketDecoder decoder_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_