AsrWebsocketServer.java 7.8 KB
/*
 * // Copyright 2022-2023 by zhaoming
 */
// java websocketServer
// usage: AsrWebsocketServer soPath modelCfgPath
package websocketsrv;

import com.k2fsa.sherpa.onnx.OnlineRecognizer;
import com.k2fsa.sherpa.onnx.OnlineStream;
import java.io.*;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.*;
import java.util.Collections;
import java.util.concurrent.*;
import java.util.concurrent.LinkedBlockingQueue;
import org.java_websocket.WebSocket;
import org.java_websocket.drafts.Draft;
import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * AsrWebSocketServer has three threads pools, one pool for network io, one pool for asr stream and
 * one pool for asr decoder.
 */
public class AsrWebsocketServer extends WebSocketServer {
  private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketServer.class);
  //  Queue between io network io thread pool and stream thread pool, use websocket as the key
  private LinkedBlockingQueue<WebSocket> streamQueue = new LinkedBlockingQueue<WebSocket>();
  //  Queue waiting for deocdeing, use websocket as the key
  private LinkedBlockingQueue<WebSocket> decoderQueue = new LinkedBlockingQueue<WebSocket>();

  // recogizer object
  private OnlineRecognizer rcgOjb = null;

  // mapping between websocket connection and connection data
  private ConcurrentHashMap<WebSocket, ConnectionData> connectionMap =
      new ConcurrentHashMap<WebSocket, ConnectionData>();

  public AsrWebsocketServer(int port, int numThread) throws UnknownHostException {
    // server port and num of threads for  network io
    super(new InetSocketAddress(port), numThread);
  }

  public AsrWebsocketServer(InetSocketAddress address) {
    super(address);
  }

  public AsrWebsocketServer(int port, Draft_6455 draft) {
    super(new InetSocketAddress(port), Collections.<Draft>singletonList(draft));
  }

  @Override
  public void onOpen(WebSocket conn, ClientHandshake handshake) {}

  @Override
  public void onClose(WebSocket conn, int code, String reason, boolean remote) {
    connectionMap.remove(conn);
    logger.info(
        conn
            + " remove one connection!, now connection number="
            + String.valueOf(connectionMap.size()));
  }

  @Override
  public void onMessage(WebSocket conn, String message) {
    // this is text message
    try {
      // if rec "Done" msg from client
      if (message.equals("Done")) {
        ConnectionData connData = creatOrGetConnectionData(conn);
        connData.setEof(true);
        if (!streamQueueFind(conn)) {
          streamQueue.put(conn);
        }
      }

    } catch (Exception e) {
      e.printStackTrace();
    }
  }

  private ConnectionData creatOrGetConnectionData(WebSocket conn) {
    // create a new connection data if not in connection map or return the existed one

    ConnectionData connData = null;
    try {
      if (!connectionMap.containsKey(conn)) {
        OnlineStream stream = rcgOjb.createStream();
        connData = new ConnectionData(conn, stream);
        connectionMap.put(conn, connData);
      } else {
        connData = connectionMap.get(conn);
      }

      logger.info(
          conn.getRemoteSocketAddress().getAddress().getHostAddress()
              + " open one connection,, now connection number="
              + String.valueOf(connectionMap.size()));

    } catch (Exception e) {
      System.err.println(e);
      e.printStackTrace();
    }
    return connData;
  }

  @Override
  public void onMessage(WebSocket conn, ByteBuffer blob) {
    try {

      // for handle binary data
      blob.order(ByteOrder.LITTLE_ENDIAN); // set little endian

      // set to float
      FloatBuffer floatbuf = blob.asFloatBuffer();

      if (floatbuf.capacity() > 0) {
        // allocate memory for float data
        float[] arr = new float[floatbuf.capacity()];

        floatbuf.get(arr);
        ConnectionData connData = creatOrGetConnectionData(conn);
        // put websocket  to stream queue with binary type==1
        connData.addSamplesToData(arr);

        if (!streamQueueFind(conn)) {
          streamQueue.put(conn);
        }
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
  }

  public boolean streamQueueFind(WebSocket conn) {
    return streamQueue.contains(conn);
  }

  public void initModelWithCfg(Map<String, String> cfgMap, String cfgPath) {
    try {

      rcgOjb = new OnlineRecognizer(cfgPath);
      // size of stream thread pool
      int streamThreadNum = Integer.valueOf(cfgMap.getOrDefault("stream_thread_num", "16"));
      // size of decoder thread pool
      int decoderThreadNum = Integer.valueOf(cfgMap.getOrDefault("decoder_thread_num", "16"));

      // time(ms) idle for decoder thread when no job
      int decoderTimeIdle = Integer.valueOf(cfgMap.getOrDefault("decoder_time_idle", "200"));
      // size of streams for parallel decoding
      int parallelDecoderNum = Integer.valueOf(cfgMap.getOrDefault("parallel_decoder_num", "16"));
      // time(ms) out for connection data
      int deocderTimeOut = Integer.valueOf(cfgMap.getOrDefault("deocder_time_out", "30000"));

      // create stream threads
      for (int i = 0; i < streamThreadNum; i++) {
        new StreamThreadHandler(streamQueue, decoderQueue, connectionMap).start();
      }
      // create decoder threads
      for (int i = 0; i < decoderThreadNum; i++) {
        new DecoderThreadHandler(
                decoderQueue,
                connectionMap,
                rcgOjb,
                decoderTimeIdle,
                parallelDecoderNum,
                deocderTimeOut)
            .start();
      }
    } catch (Exception e) {
      System.err.println(e);
      e.printStackTrace();
    }
  }

  public static Map<String, String> readProperties(String CfgPath) {
    // read and parse config file
    Properties props = new Properties();
    Map<String, String> proMap = new HashMap<String, String>();
    try {

      File file = new File(CfgPath);
      if (!file.exists()) {
        logger.info(String.valueOf(CfgPath) + " cfg file not exists!");
        System.exit(0);
      }
      InputStream in = new BufferedInputStream(new FileInputStream(CfgPath));
      props.load(in);
      Enumeration en = props.propertyNames();
      while (en.hasMoreElements()) {
        String key = (String) en.nextElement();
        String Property = props.getProperty(key);
        proMap.put(key, Property);
      }

    } catch (Exception e) {
      e.printStackTrace();
    }
    return proMap;
  }

  public static void main(String[] args) throws InterruptedException, IOException {
    if (args.length != 2) {
      logger.info("usage: AsrWebsocketServer soPath modelCfgPath");

      return;
    }

    String soPath = args[0];
    String cfgPath = args[1];

    OnlineRecognizer.setSoPath(soPath);
    logger.info("readProperties");
    Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
    int port = Integer.valueOf(cfgMap.getOrDefault("port", "8890"));

    int connectionThreadNum = Integer.valueOf(cfgMap.getOrDefault("connection_thread_num", "16"));
    AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
    logger.info("initModelWithCfg");
    s.initModelWithCfg(cfgMap, cfgPath);
    logger.info("Server started on port: " + s.getPort());
    s.start();
  }

  @Override
  public void onError(WebSocket conn, Exception ex) {
    ex.printStackTrace();
    if (conn != null) {
      // some errors like port binding failed may not be assignable to a specific websocket
    }
  }

  @Override
  public void onStart() {
    logger.info("Server started!");
    setConnectionLostTimeout(0);
    setConnectionLostTimeout(100);
  }
}