zhaomingwork
Committed by GitHub

Add java websocket support (#137)

* add decode example for mic

* some changes to README.md

* add java websocket srv

* change to readwav to static

* make some changes to code comments

* little change for readme.md

* fix bug about multiple threads

* made little modification

* add protocol in readme, removed static Queue and add lmConfig

---------

Co-authored-by: root <root@localhost.localdomain>
@@ -7,11 +7,21 @@ LIB_FILES = \ @@ -7,11 +7,21 @@ LIB_FILES = \
7 $(LIB_SRC_DIR)/EndpointRule.java \ 7 $(LIB_SRC_DIR)/EndpointRule.java \
8 $(LIB_SRC_DIR)/EndpointConfig.java \ 8 $(LIB_SRC_DIR)/EndpointConfig.java \
9 $(LIB_SRC_DIR)/FeatureConfig.java \ 9 $(LIB_SRC_DIR)/FeatureConfig.java \
  10 + $(LIB_SRC_DIR)/OnlineLMConfig.java \
10 $(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \ 11 $(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \
11 $(LIB_SRC_DIR)/OnlineRecognizerConfig.java \ 12 $(LIB_SRC_DIR)/OnlineRecognizerConfig.java \
12 $(LIB_SRC_DIR)/OnlineStream.java \ 13 $(LIB_SRC_DIR)/OnlineStream.java \
13 $(LIB_SRC_DIR)/OnlineRecognizer.java \ 14 $(LIB_SRC_DIR)/OnlineRecognizer.java \
14 15
  16 +WEBSOCKET_DIR:= ./src/websocketsrv
  17 +WEBSOCKET_FILES = \
  18 + $(WEBSOCKET_DIR)/ConnectionData.java \
  19 + $(WEBSOCKET_DIR)/DecoderThreadHandler.java \
  20 + $(WEBSOCKET_DIR)/StreamThreadHandler.java \
  21 + $(WEBSOCKET_DIR)/AsrWebsocketServer.java \
  22 + $(WEBSOCKET_DIR)/AsrWebsocketClient.java \
  23 +
  24 +
15 LIB_BUILD_DIR = ./lib 25 LIB_BUILD_DIR = ./lib
16 26
17 27
@@ -39,7 +49,13 @@ buildmic: @@ -39,7 +49,13 @@ buildmic:
39 49
40 rebuild: clean all 50 rebuild: clean all
41 51
42 -.PHONY: clean run 52 +.PHONY: clean run downjar
  53 +
  54 +downjar:
  55 + wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/
  56 + wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/
  57 + wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/
  58 +
43 59
44 clean: 60 clean:
45 rm -frv $(BUILD_DIR)/* 61 rm -frv $(BUILD_DIR)/*
@@ -56,6 +72,12 @@ runmic: @@ -56,6 +72,12 @@ runmic:
56 72
57 java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic 73 java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic
58 74
  75 +runsrv:
  76 + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer /sherpa-onnx/20230515/zhaoming/sherpa-onnx/build/lib/libsherpa-onnx-jni.so ./modelconfig.cfg
  77 +
  78 +runclient:
  79 + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient /sherpa-onnx/20230515/zhaoming/sherpa-onnx/build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
  80 +
59 buildlib: $(LIB_FILES:.java=.class) 81 buildlib: $(LIB_FILES:.java=.class)
60 82
61 83
@@ -63,10 +85,19 @@ buildlib: $(LIB_FILES:.java=.class) @@ -63,10 +85,19 @@ buildlib: $(LIB_FILES:.java=.class)
63 85
64 $(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $< 86 $(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $<
65 87
  88 +buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
  89 +
  90 +
  91 +%.class: %.java
  92 +
  93 + $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $<
  94 +
66 packjar: 95 packjar:
67 jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) . 96 jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) .
68 97
69 -all: clean buildlib packjar buildfile buildmic 98 +all: clean buildlib packjar buildfile buildmic downjar buildwebsocket
  99 +
  100 +
70 101
71 102
72 103
@@ -2,6 +2,7 @@ @@ -2,6 +2,7 @@
2 -------------- 2 --------------
3 3
4 Java wrapper `com.k2fsa.sherpa.onnx.OnlineRecognizer` for `sherpa-onnx`. Java is a cross-platform language; you can build jni .so lib according to your system, and then use the same java api for all your platform. 4 Java wrapper `com.k2fsa.sherpa.onnx.OnlineRecognizer` for `sherpa-onnx`. Java is a cross-platform language; you can build jni .so lib according to your system, and then use the same java api for all your platform.
  5 +now support multiple threads for websocket server
5 6
6 ```xml 7 ```xml
7 Depend on: 8 Depend on:
@@ -35,10 +36,10 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362: @@ -35,10 +36,10 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362:
35 36
36 3.Config model config.cfg 37 3.Config model config.cfg
37 ------------------------- 38 -------------------------
38 - 39 +/**change model path in config.cfg according to your env**/
39 ```xml 40 ```xml
40 - #model config  
41 - sample_rate=16000 41 + #model config
  42 + sample_rate=16000
42 feature_dim=80 43 feature_dim=80
43 rule1_min_trailing_silence=2.4 44 rule1_min_trailing_silence=2.4
44 rule2_min_trailing_silence=1.2 45 rule2_min_trailing_silence=1.2
@@ -51,6 +52,21 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362: @@ -51,6 +52,21 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362:
51 enable_endpoint_detection=false 52 enable_endpoint_detection=false
52 decoding_method=greedy_search 53 decoding_method=greedy_search
53 max_active_paths=4 54 max_active_paths=4
  55 +
  56 + #websocket server config
  57 + port=8890
  58 + #number of threads pool for network io
  59 + connection_thread_num=16
  60 + #number of threads pool for stream
  61 + stream_thread_num=16
  62 + #number of threads pool for decoder
  63 + decoder_thread_num=16
  64 + #size of streams for parallel decoding
  65 + parallel_decoder_num=16
  66 + #time(ms) idle for decoder thread when no job
  67 + decoder_time_idle=10
  68 + #time(ms) out for connection data
  69 + deocder_time_out=3000
54 ``` 70 ```
55 71
56 --- 72 ---
@@ -114,5 +130,58 @@ Build package path: /sherpa-onnx/java-api-examples/lib/sherpaonnx.jar @@ -114,5 +130,58 @@ Build package path: /sherpa-onnx/java-api-examples/lib/sherpaonnx.jar
114 make runmic 130 make runmic
115 ``` 131 ```
116 132
  133 +---
117 134
  135 +6.WebSocket Server
  136 +----------
  137 +
  138 +support multiple threads for websocket server
  139 +6.0 Protocol for communication
  140 +1) client connect to server
  141 +```shell
  142 + ws client -> srv ws address
  143 + ws address example: ws://127.0.0.1:8889/
  144 +```
  145 +2) client send 16k pcm_s16le binary stream data to server
  146 +```shell
  147 + PCM sampleRate 16000
  148 + single channel
  149 + sampleSize 16bit
  150 + little endian
  151 + type short
  152 +```
  153 +3) client send "Done" text to server when all data is sent
  154 +```shell
  155 + ws_socket.send("Done")
  156 +```
  157 +4) client will receive json message from server whenever asr engine decoded new text
  158 +```shell
  159 + json example: {"text":"甚至出现交易几乎停滞的情况","eof":false"}
  160 +```
  161 +
  162 +
  163 +6.1 Build
  164 +
  165 +```bash
  166 + cd sherpa-onnx/java-api-examples
  167 + make all
  168 +```
  169 +
  170 +6.2 Run srv example
  171 +
  172 +usage: AsrWebsocketServer soPath modelCfgPath
  173 +
  174 +```bash
  175 + make runsrv /**change path in Makefile according to your env**/
  176 +```
  177 +
  178 +6.3 Run multiple threads client example
  179 +
  180 +usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads
  181 +
  182 +json result example: {"text":"甚至出现交易几乎停滞的情况","eof":"true"}
  183 +
  184 +```bash
  185 + make runclient /**change path in Makefile according to your env**/
  186 +```
118 187
@@ -9,6 +9,17 @@ decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh- @@ -9,6 +9,17 @@ decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-
9 joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx 9 joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx
10 tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt 10 tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
11 num_threads=4 11 num_threads=4
12 -enable_endpoint_detection=false 12 +enable_endpoint_detection=true
13 decoding_method=modified_beam_search 13 decoding_method=modified_beam_search
14 -max_active_paths=4  
  14 +max_active_paths=4
  15 +lm_model=
  16 +lm_scale=0.5
  17 +
  18 +#websocket server config
  19 +port=8890
  20 +connection_thread_num=16
  21 +stream_thread_num=16
  22 +decoder_thread_num=16
  23 +parallel_decoder_num=16
  24 +decoder_time_idle=200
  25 +deocder_time_out=30000
@@ -49,7 +49,8 @@ public class DecodeFile { @@ -49,7 +49,8 @@ public class DecodeFile {
49 float rule3MinUtteranceLength = 20F; 49 float rule3MinUtteranceLength = 20F;
50 String decodingMethod = "greedy_search"; 50 String decodingMethod = "greedy_search";
51 int maxActivePaths = 4; 51 int maxActivePaths = 4;
52 - 52 + String lm_model="";
  53 + float lm_scale=0.5F;
53 rcgOjb = 54 rcgOjb =
54 new OnlineRecognizer( 55 new OnlineRecognizer(
55 tokens, 56 tokens,
@@ -64,6 +65,8 @@ public class DecodeFile { @@ -64,6 +65,8 @@ public class DecodeFile {
64 rule2MinTrailingSilence, 65 rule2MinTrailingSilence,
65 rule3MinUtteranceLength, 66 rule3MinUtteranceLength,
66 decodingMethod, 67 decodingMethod,
  68 + lm_model,
  69 + lm_scale,
67 maxActivePaths); 70 maxActivePaths);
68 streamObj = rcgOjb.createStream(); 71 streamObj = rcgOjb.createStream();
69 } catch (Exception e) { 72 } catch (Exception e) {
  1 +/*
  2 + * // Copyright 2022-2023 by zhaomingwork
  3 + */
  4 +// java AsrWebsocketClient
  5 +// usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads
  6 +package websocketsrv;
  7 +
  8 +import com.k2fsa.sherpa.onnx.OnlineRecognizer;
  9 +import java.net.URI;
  10 +import java.net.URISyntaxException;
  11 +import java.nio.*;
  12 +import java.util.Map;
  13 +import org.java_websocket.client.WebSocketClient;
  14 +import org.java_websocket.drafts.Draft;
  15 +import org.java_websocket.handshake.ServerHandshake;
  16 +import org.slf4j.Logger;
  17 +import org.slf4j.LoggerFactory;
  18 +
  19 +/** This example demonstrates how to connect to websocket server. */
  20 +public class AsrWebsocketClient extends WebSocketClient {
  21 + private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketClient.class);
  22 +
  23 + public AsrWebsocketClient(URI serverUri, Draft draft) {
  24 + super(serverUri, draft);
  25 + }
  26 +
  27 + public AsrWebsocketClient(URI serverURI) {
  28 + super(serverURI);
  29 + }
  30 +
  31 + public AsrWebsocketClient(URI serverUri, Map<String, String> httpHeaders) {
  32 + super(serverUri, httpHeaders);
  33 + }
  34 +
  35 + @Override
  36 + public void onOpen(ServerHandshake handshakedata) {
  37 +
  38 + float[] floats = OnlineRecognizer.readWavFile(AsrWebsocketClient.wavPath);
  39 + ByteBuffer buffer =
  40 + ByteBuffer.allocate(4 * floats.length)
  41 + .order(ByteOrder.LITTLE_ENDIAN); // float is sizeof 4. allocate enough buffer
  42 +
  43 + for (float f : floats) {
  44 + buffer.putFloat(f);
  45 + }
  46 + buffer.rewind();
  47 + buffer.flip();
  48 + buffer.order(ByteOrder.LITTLE_ENDIAN);
  49 +
  50 + send(buffer.array()); // send buf to server
  51 + send("Done"); // send 'Done' means finished
  52 + }
  53 +
  54 + @Override
  55 + public void onMessage(String message) {
  56 +
  57 + logger.info("received: " + message);
  58 + }
  59 +
  60 + @Override
  61 + public void onClose(int code, String reason, boolean remote) {
  62 +
  63 + logger.info(
  64 + "Connection closed by "
  65 + + (remote ? "remote peer" : "us")
  66 + + " Code: "
  67 + + code
  68 + + " Reason: "
  69 + + reason);
  70 + }
  71 +
  72 + @Override
  73 + public void onError(Exception ex) {
  74 + ex.printStackTrace();
  75 + // if the error is fatal then onClose will be called additionally
  76 + }
  77 +
  78 + public static OnlineRecognizer rcgobj;
  79 + public static String wavPath;
  80 +
  81 + public static void main(String[] args) throws URISyntaxException {
  82 +
  83 + if (args.length != 5) {
  84 + System.out.println("usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads");
  85 + return;
  86 + }
  87 +
  88 + String soPath = args[0];
  89 + String srvIp = args[1];
  90 + String srvPort = args[2];
  91 + String wavPath = args[3];
  92 + int numThreads = Integer.parseInt(args[4]);
  93 + System.out.println("serIp=" + srvIp + ",srvPort=" + srvPort + ",wavPath=" + wavPath);
  94 +
  95 + class ClientThread implements Runnable {
  96 +
  97 + String soPath;
  98 + String srvIp;
  99 + String srvPort;
  100 + String wavPath;
  101 +
  102 + ClientThread(String soPath, String srvIp, String srvPort, String wavPath) {
  103 + this.soPath = soPath;
  104 + this.srvIp = srvIp;
  105 + this.srvPort = srvPort;
  106 + this.wavPath = wavPath;
  107 + }
  108 +
  109 + public void run() {
  110 + try {
  111 +
  112 + OnlineRecognizer.setSoPath(soPath);
  113 +
  114 + AsrWebsocketClient.wavPath = wavPath;
  115 +
  116 + String wsAddress = "ws://" + srvIp + ":" + srvPort;
  117 + AsrWebsocketClient c = new AsrWebsocketClient(new URI(wsAddress));
  118 +
  119 + c.connect();
  120 + } catch (Exception e) {
  121 + e.printStackTrace();
  122 + }
  123 + }
  124 + }
  125 + for (int i = 0; i < numThreads; i++) {
  126 + System.out.println("Thread1 is running...");
  127 + Thread t = new Thread(new ClientThread(soPath, srvIp, srvPort, wavPath));
  128 + t.start();
  129 + }
  130 + }
  131 +}
  1 +/*
  2 + * // Copyright 2022-2023 by zhaoming
  3 + */
  4 +// java websocketServer
  5 +// usage: AsrWebsocketServer soPath modelCfgPath
  6 +package websocketsrv;
  7 +
  8 +import com.k2fsa.sherpa.onnx.OnlineRecognizer;
  9 +import com.k2fsa.sherpa.onnx.OnlineStream;
  10 +import java.io.*;
  11 +import java.io.IOException;
  12 +import java.net.InetSocketAddress;
  13 +import java.net.UnknownHostException;
  14 +import java.nio.ByteBuffer;
  15 +import java.nio.ByteOrder;
  16 +import java.nio.FloatBuffer;
  17 +import java.util.*;
  18 +import java.util.Collections;
  19 +import java.util.concurrent.*;
  20 +import java.util.concurrent.LinkedBlockingQueue;
  21 +import org.java_websocket.WebSocket;
  22 +import org.java_websocket.drafts.Draft;
  23 +import org.java_websocket.drafts.Draft_6455;
  24 +import org.java_websocket.handshake.ClientHandshake;
  25 +import org.java_websocket.server.WebSocketServer;
  26 +import org.slf4j.Logger;
  27 +import org.slf4j.LoggerFactory;
  28 +
  29 +/**
  30 + * AsrWebSocketServer has three threads pools, one pool for network io, one pool for asr stream and
  31 + * one pool for asr decoder.
  32 + */
  33 +public class AsrWebsocketServer extends WebSocketServer {
  34 + private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketServer.class);
  35 + // Queue between io network io thread pool and stream thread pool, use websocket as the key
  36 + private LinkedBlockingQueue<WebSocket> streamQueue = new LinkedBlockingQueue<WebSocket>();
  37 + // Queue waiting for deocdeing, use websocket as the key
  38 + private LinkedBlockingQueue<WebSocket> decoderQueue = new LinkedBlockingQueue<WebSocket>();
  39 +
  40 + // recogizer object
  41 + private OnlineRecognizer rcgOjb = null;
  42 +
  43 + // mapping between websocket connection and connection data
  44 + private ConcurrentHashMap<WebSocket, ConnectionData> connectionMap =
  45 + new ConcurrentHashMap<WebSocket, ConnectionData>();
  46 +
  47 + public AsrWebsocketServer(int port, int numThread) throws UnknownHostException {
  48 + // server port and num of threads for network io
  49 + super(new InetSocketAddress(port), numThread);
  50 + }
  51 +
  52 + public AsrWebsocketServer(InetSocketAddress address) {
  53 + super(address);
  54 + }
  55 +
  56 + public AsrWebsocketServer(int port, Draft_6455 draft) {
  57 + super(new InetSocketAddress(port), Collections.<Draft>singletonList(draft));
  58 + }
  59 +
  60 + @Override
  61 + public void onOpen(WebSocket conn, ClientHandshake handshake) {}
  62 +
  63 + @Override
  64 + public void onClose(WebSocket conn, int code, String reason, boolean remote) {
  65 + connectionMap.remove(conn);
  66 + logger.info(
  67 + conn
  68 + + " remove one connection!, now connection number="
  69 + + String.valueOf(connectionMap.size()));
  70 + }
  71 +
  72 + @Override
  73 + public void onMessage(WebSocket conn, String message) {
  74 + // this is text message
  75 + try {
  76 + // if rec "Done" msg from client
  77 + if (message.equals("Done")) {
  78 + ConnectionData connData = creatOrGetConnectionData(conn);
  79 + connData.setEof(true);
  80 + if (!streamQueueFind(conn)) {
  81 + streamQueue.put(conn);
  82 + }
  83 + }
  84 +
  85 + } catch (Exception e) {
  86 + e.printStackTrace();
  87 + }
  88 + }
  89 +
  90 + private ConnectionData creatOrGetConnectionData(WebSocket conn) {
  91 + // create a new connection data if not in connection map or return the existed one
  92 +
  93 + ConnectionData connData = null;
  94 + try {
  95 + if (!connectionMap.containsKey(conn)) {
  96 + OnlineStream stream = rcgOjb.createStream();
  97 + connData = new ConnectionData(conn, stream);
  98 + connectionMap.put(conn, connData);
  99 + } else {
  100 + connData = connectionMap.get(conn);
  101 + }
  102 +
  103 + logger.info(
  104 + conn.getRemoteSocketAddress().getAddress().getHostAddress()
  105 + + " open one connection,, now connection number="
  106 + + String.valueOf(connectionMap.size()));
  107 +
  108 + } catch (Exception e) {
  109 + System.err.println(e);
  110 + e.printStackTrace();
  111 + }
  112 + return connData;
  113 + }
  114 +
  115 + @Override
  116 + public void onMessage(WebSocket conn, ByteBuffer blob) {
  117 + try {
  118 +
  119 + // for handle binary data
  120 + blob.order(ByteOrder.LITTLE_ENDIAN); // set little endian
  121 +
  122 + // set to float
  123 + FloatBuffer floatbuf = blob.asFloatBuffer();
  124 +
  125 + if (floatbuf.capacity() > 0) {
  126 + // allocate memory for float data
  127 + float[] arr = new float[floatbuf.capacity()];
  128 +
  129 + floatbuf.get(arr);
  130 + ConnectionData connData = creatOrGetConnectionData(conn);
  131 + // put websocket to stream queue with binary type==1
  132 + connData.addSamplesToData(arr);
  133 +
  134 + if (!streamQueueFind(conn)) {
  135 + streamQueue.put(conn);
  136 + }
  137 + }
  138 + } catch (Exception e) {
  139 + e.printStackTrace();
  140 + }
  141 + }
  142 +
  143 +
  144 +
  145 + public boolean streamQueueFind(WebSocket conn) {
  146 + return streamQueue.contains(conn);
  147 + }
  148 +
  149 + public void initModelWithCfg(Map<String, String> cfgMap, String cfgPath) {
  150 + try {
  151 +
  152 + rcgOjb = new OnlineRecognizer(cfgPath);
  153 + // size of stream thread pool
  154 + int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num"));
  155 + // size of decoder thread pool
  156 + int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num"));
  157 +
  158 + // time(ms) idle for decoder thread when no job
  159 + int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle"));
  160 + // size of streams for parallel decoding
  161 + int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num"));
  162 + // time(ms) out for connection data
  163 + int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out"));
  164 +
  165 + // create stream threads
  166 + for (int i = 0; i < streamThreadNum; i++) {
  167 + new StreamThreadHandler(streamQueue, decoderQueue, connectionMap).start();
  168 + }
  169 + // create decoder threads
  170 + for (int i = 0; i < decoderThreadNum; i++) {
  171 + new DecoderThreadHandler(
  172 + decoderQueue,
  173 + connectionMap,
  174 + rcgOjb,
  175 + decoderTimeIdle,
  176 + parallelDecoderNum,
  177 + deocderTimeOut)
  178 + .start();
  179 + }
  180 + } catch (Exception e) {
  181 + System.err.println(e);
  182 + e.printStackTrace();
  183 + }
  184 + }
  185 +
  186 + public static Map<String, String> readProperties(String CfgPath) {
  187 + // read and parse config file
  188 + Properties props = new Properties();
  189 + Map<String, String> proMap = new HashMap<String, String>();
  190 + try {
  191 +
  192 + File file = new File(CfgPath);
  193 + if (!file.exists()) {
  194 + logger.info(String.valueOf(CfgPath) + " cfg file not exists!");
  195 + System.exit(0);
  196 + }
  197 + InputStream in = new BufferedInputStream(new FileInputStream(CfgPath));
  198 + props.load(in);
  199 + Enumeration en = props.propertyNames();
  200 + while (en.hasMoreElements()) {
  201 + String key = (String) en.nextElement();
  202 + String Property = props.getProperty(key);
  203 + proMap.put(key, Property);
  204 + }
  205 +
  206 + } catch (Exception e) {
  207 + e.printStackTrace();
  208 + }
  209 + return proMap;
  210 + }
  211 +
  212 + public static void main(String[] args) throws InterruptedException, IOException {
  213 + if (args.length != 2) {
  214 + logger.info("usage: AsrWebsocketServer soPath modelCfgPath");
  215 +
  216 + return;
  217 + }
  218 +
  219 + String soPath = args[0];
  220 + String cfgPath = args[1];
  221 +
  222 + OnlineRecognizer.setSoPath(soPath);
  223 +
  224 + Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
  225 + int port = Integer.valueOf(cfgMap.get("port"));
  226 +
  227 + int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
  228 + AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
  229 + s.initModelWithCfg(cfgMap, cfgPath);
  230 + logger.info("Server started on port: " + s.getPort());
  231 + s.start();
  232 + }
  233 +
  234 + @Override
  235 + public void onError(WebSocket conn, Exception ex) {
  236 + ex.printStackTrace();
  237 + if (conn != null) {
  238 + // some errors like port binding failed may not be assignable to a specific websocket
  239 + }
  240 + }
  241 +
  242 + @Override
  243 + public void onStart() {
  244 + logger.info("Server started!");
  245 + setConnectionLostTimeout(0);
  246 + setConnectionLostTimeout(100);
  247 + }
  248 +}
  1 +/*
  2 + * // Copyright 2022-2023 by zhaoming
  3 + */
  4 +// connection data act as a bridge between different threads pools
  5 +
  6 +package websocketsrv;
  7 +
  8 +import com.k2fsa.sherpa.onnx.OnlineStream;
  9 +import java.time.LocalDateTime;
  10 +import java.util.LinkedList;
  11 +import java.util.Queue;
  12 +import java.util.concurrent.*;
  13 +import org.java_websocket.WebSocket;
  14 +
  15 +public class ConnectionData {
  16 +
  17 + private WebSocket webSocket; // the websocket for this connection data
  18 +
  19 + private OnlineStream stream; // connection stream
  20 +
  21 + private Queue<float[]> queueSamples =
  22 + new LinkedList<float[]>(); // binary data rec from the client
  23 +
  24 + private boolean eof = false; // connection data is done
  25 +
  26 + private LocalDateTime lastHandleTime; // used for time out in ms
  27 +
  28 + public ConnectionData(WebSocket webSocket, OnlineStream stream) {
  29 + this.webSocket = webSocket;
  30 +
  31 + this.stream = stream;
  32 + }
  33 +
  34 + public void addSamplesToData(float[] samples) {
  35 + this.queueSamples.add(samples);
  36 + }
  37 +
  38 + public LocalDateTime getLastHandleTime() {
  39 + return this.lastHandleTime;
  40 + }
  41 +
  42 + public void setLastHandleTime(LocalDateTime now) {
  43 + this.lastHandleTime = now;
  44 + }
  45 +
  46 + public boolean getEof() {
  47 + return this.eof;
  48 + }
  49 +
  50 + public void setEof(boolean eof) {
  51 + this.eof = eof;
  52 + }
  53 +
  54 + public WebSocket getWebSocket() {
  55 + return this.webSocket;
  56 + }
  57 +
  58 + public Queue<float[]> getQueueSamples() {
  59 + return this.queueSamples;
  60 + }
  61 +
  62 + public OnlineStream getStream() {
  63 + return this.stream;
  64 + }
  65 +}
  1 +/*
  2 + * // Copyright 2022-2023 by zhaoming
  3 + */
  4 +// java DecoderThreadHandler
  5 +package websocketsrv;
  6 +
  7 +import com.k2fsa.sherpa.onnx.OnlineRecognizer;
  8 +import com.k2fsa.sherpa.onnx.OnlineStream;
  9 +import java.nio.*;
  10 +import java.nio.charset.StandardCharsets;
  11 +import java.time.LocalDateTime;
  12 +import java.util.*;
  13 +import java.util.List;
  14 +import java.util.concurrent.*;
  15 +import java.util.concurrent.LinkedBlockingQueue;
  16 +import org.java_websocket.WebSocket;
  17 +import org.java_websocket.drafts.Draft;
  18 +import org.java_websocket.framing.Framedata;
  19 +import org.slf4j.Logger;
  20 +import org.slf4j.LoggerFactory;
  21 +
  22 +public class DecoderThreadHandler extends Thread {
  23 + private static final Logger logger = LoggerFactory.getLogger(DecoderThreadHandler.class);
  24 + // Websocket Queue that waiting for decoding
  25 + private LinkedBlockingQueue<WebSocket> decoderQueue;
  26 + // the mapping between websocket and connection data
  27 + private ConcurrentHashMap<WebSocket, ConnectionData> connMap;
  28 +
  29 + private OnlineRecognizer rcgOjb = null; // recgnizer object
  30 +
  31 + // connection data list for this thread to decode in parallel
  32 + private List<ConnectionData> connDataList = new ArrayList<ConnectionData>();
  33 +
  34 + private int parallelDecoderNum = 10; // parallel decoding number
  35 + private int deocderTimeIdle = 10; // idle time(ms) when no job
  36 + private int deocderTimeOut = 3000; // if it is timeout(ms), the connection data will be removed
  37 +
  38 + public DecoderThreadHandler(
  39 + LinkedBlockingQueue<WebSocket> decoderQueue,
  40 + ConcurrentHashMap<WebSocket, ConnectionData> connMap,
  41 + OnlineRecognizer rcgOjb,
  42 + int deocderTimeIdle,
  43 + int parallelDecoderNum,
  44 + int deocderTimeOut) {
  45 + this.decoderQueue = decoderQueue;
  46 + this.connMap = connMap;
  47 + this.rcgOjb = rcgOjb;
  48 + this.deocderTimeIdle = deocderTimeIdle;
  49 + this.parallelDecoderNum = parallelDecoderNum;
  50 + this.deocderTimeOut = deocderTimeOut;
  51 + }
  52 +
  53 + public void run() {
  54 + while (true) {
  55 + try {
  56 + // time(ms) idle if there is no job
  57 +
  58 + Thread.sleep(deocderTimeIdle);
  59 + // clear data list for this threads
  60 + connDataList.clear();
  61 + if (rcgOjb == null) continue;
  62 +
  63 + // loop for total decoder Queue
  64 + while (!decoderQueue.isEmpty()) {
  65 +
  66 + // get websocket
  67 + WebSocket conn = decoderQueue.take();
  68 + // get connection data according to websocket
  69 + ConnectionData connData = connMap.get(conn);
  70 +
  71 + // if the websocket closed, continue
  72 + if (connData == null) continue;
  73 + // get the stream
  74 + OnlineStream stream = connData.getStream();
  75 +
  76 + // put to decoder list if 1) stream is ready; 2) and
  77 + // size not > parallelDecoderNum
  78 + if ((rcgOjb.isReady(stream) && connDataList.size() < parallelDecoderNum)) {
  79 +
  80 + // add to this thread's decoder list
  81 + connDataList.add(connData);
  82 + // change the handled time for this connection data
  83 + connData.setLastHandleTime(LocalDateTime.now());
  84 + }
  85 + // break when decoder list size >= parallelDecoderNum
  86 + if (connDataList.size() >= parallelDecoderNum) {
  87 + break;
  88 + }
  89 + }
  90 +
  91 + // if decoder data list for this thread >0
  92 + if (connDataList.size() > 0) {
  93 +
  94 + // create a stream array for parallel decoding
  95 + OnlineStream[] arr = new OnlineStream[connDataList.size()];
  96 + for (int i = 0; i < connDataList.size(); i++) {
  97 +
  98 + arr[i] = connDataList.get(i).getStream();
  99 + }
  100 +
  101 + // parallel decoding
  102 + rcgOjb.decodeStreams(arr);
  103 + }
  104 +
  105 + // get result for each connection
  106 + for (ConnectionData connData : connDataList) {
  107 +
  108 + OnlineStream stream = connData.getStream();
  109 + WebSocket webSocket = connData.getWebSocket();
  110 +
  111 + String txtResult = rcgOjb.getResult(stream);
  112 +
  113 + // decode text in utf-8
  114 + byte[] utf8Data = txtResult.getBytes(StandardCharsets.UTF_8);
  115 +
  116 + boolean isEof = (connData.getEof() == true && !rcgOjb.isReady(stream));
  117 + // result
  118 + if (utf8Data.length > 0) {
  119 +
  120 + String jsonResult =
  121 + "{\"text\":\"" + txtResult + "\",\"eof\":" + String.valueOf(isEof) + "\"}";
  122 +
  123 + if (webSocket.isOpen()) {
  124 + // create a TEXT Frame for send back json result
  125 + Draft draft = webSocket.getDraft();
  126 + List<Framedata> frames = null;
  127 + frames = draft.createFrames(jsonResult, false);
  128 + // send to client
  129 + webSocket.sendFrame(frames);
  130 + }
  131 + }
  132 + }
  133 + // loop for each connection data in this thread
  134 + for (ConnectionData connData : connDataList) {
  135 + OnlineStream stream = connData.getStream();
  136 + WebSocket webSocket = connData.getWebSocket();
  137 + // if the stream is still ready, put it to decoder Queue again for next decoding
  138 + if (rcgOjb.isReady(stream)) {
  139 + decoderQueue.put(webSocket);
  140 + }
  141 + // the duration between last handled time and now
  142 + java.time.Duration duration =
  143 + java.time.Duration.between(connData.getLastHandleTime(), LocalDateTime.now());
  144 + // close the websocket if 1) data is done and stream not ready; 2) or data is time out;
  145 + // 3) or
  146 + // connection is closed
  147 + if ((connData.getEof() == true
  148 + && !rcgOjb.isReady(stream)
  149 + && connData.getQueueSamples().isEmpty())
  150 + || duration.toMillis() > deocderTimeOut
  151 + || !connData.getWebSocket().isOpen()) {
  152 +
  153 + logger.info("close websocket!!!");
  154 +
  155 + // delay close web socket as data may still in processing
  156 + Timer timer = new Timer();
  157 + timer.schedule(
  158 + new TimerTask() {
  159 + public void run() {
  160 +
  161 + webSocket.close();
  162 + }
  163 + },
  164 + 5000); // 5 seconds
  165 + }
  166 + }
  167 +
  168 + } catch (Exception e) {
  169 + e.printStackTrace();
  170 + }
  171 + }
  172 + }
  173 +}
  1 +/*
  2 + * // Copyright 2022-2023 by zhaoming
  3 + */
  4 +// java StreamThreadHandler
  5 +package websocketsrv;
  6 +
  7 +import com.k2fsa.sherpa.onnx.OnlineStream;
  8 +import java.nio.*;
  9 +import java.util.*;
  10 +import java.util.concurrent.*;
  11 +import java.util.concurrent.LinkedBlockingQueue;
  12 +import org.java_websocket.WebSocket;
  13 +// thread for processing stream
  14 +
  15 +public class StreamThreadHandler extends Thread {
  16 + // Queue between io network io thread pool and stream thread pool, use websocket as the key
  17 + private LinkedBlockingQueue<WebSocket> streamQueue;
  18 + // Queue waiting for deocdeing, use websocket as the key
  19 + private LinkedBlockingQueue<WebSocket> decoderQueue;
  20 + // mapping between websocket connection and connection data
  21 + private ConcurrentHashMap<WebSocket, ConnectionData> connMap;
  22 +
  23 + public StreamThreadHandler(
  24 + LinkedBlockingQueue<WebSocket> streamQueue,
  25 + LinkedBlockingQueue<WebSocket> decoderQueue,
  26 + ConcurrentHashMap<WebSocket, ConnectionData> connMap) {
  27 + this.streamQueue = streamQueue;
  28 + this.decoderQueue = decoderQueue;
  29 + this.connMap = connMap;
  30 + }
  31 +
  32 + public void run() {
  33 + while (true) {
  34 + try {
  35 + // fetch one websocket from queue
  36 + WebSocket conn = (WebSocket) this.streamQueue.take();
  37 + // get the connection data according to websocket
  38 + ConnectionData connData = connMap.get(conn);
  39 + OnlineStream stream = connData.getStream();
  40 +
  41 + // handle received binary data
  42 + if (!connData.getQueueSamples().isEmpty()) {
  43 + // loop to put all received binary data to stream
  44 + while (!connData.getQueueSamples().isEmpty()) {
  45 +
  46 + float[] samples = connData.getQueueSamples().poll();
  47 +
  48 + stream.acceptWaveform(samples);
  49 + }
  50 + // if data is finished
  51 + if (connData.getEof() == true) {
  52 +
  53 + stream.inputFinished();
  54 + }
  55 + // add this websocket to decoder Queue if not in the Queue
  56 + if (!decoderQueue.contains(conn)) {
  57 +
  58 + decoderQueue.put(conn);
  59 + }
  60 + }
  61 +
  62 + } catch (Exception e) {
  63 + e.printStackTrace();
  64 + }
  65 + }
  66 + }
  67 +}
  1 +/*
  2 + * // Copyright 2022-2023 by zhaoming
  3 + */
  4 +
  5 +package com.k2fsa.sherpa.onnx;
  6 +
  7 +public class OnlineLMConfig {
  8 + private final String model;
  9 + private final float scale;
  10 +
  11 + public OnlineLMConfig(String model, float scale) {
  12 + this.model = model;
  13 + this.scale = scale;
  14 + }
  15 +
  16 + public String getModel() {
  17 + return model;
  18 + }
  19 +
  20 + public float getScale() {
  21 + return scale;
  22 + }
  23 +}
@@ -65,11 +65,14 @@ public class OnlineRecognizer { @@ -65,11 +65,14 @@ public class OnlineRecognizer {
65 false); 65 false);
66 FeatureConfig featConfig = 66 FeatureConfig featConfig =
67 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 67 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
68 - OnlineRecognizerConfig rcgCfg = 68 + OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));
  69 +
  70 + OnlineRecognizerConfig rcgCfg =
69 new OnlineRecognizerConfig( 71 new OnlineRecognizerConfig(
70 featConfig, 72 featConfig,
71 modelCfg, 73 modelCfg,
72 endCfg, 74 endCfg,
  75 + onlineLmConfig,
73 Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), 76 Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
74 proMap.get("decoding_method").trim(), 77 proMap.get("decoding_method").trim(),
75 Integer.parseInt(proMap.get("max_active_paths").trim())); 78 Integer.parseInt(proMap.get("max_active_paths").trim()));
@@ -107,11 +110,15 @@ public class OnlineRecognizer { @@ -107,11 +110,15 @@ public class OnlineRecognizer {
107 false); 110 false);
108 FeatureConfig featConfig = 111 FeatureConfig featConfig =
109 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 112 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
110 - OnlineRecognizerConfig rcgCfg = 113 +
  114 + OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));
  115 +
  116 + OnlineRecognizerConfig rcgCfg =
111 new OnlineRecognizerConfig( 117 new OnlineRecognizerConfig(
112 featConfig, 118 featConfig,
113 modelCfg, 119 modelCfg,
114 endCfg, 120 endCfg,
  121 + onlineLmConfig,
115 Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), 122 Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
116 proMap.get("decoding_method").trim(), 123 proMap.get("decoding_method").trim(),
117 Integer.parseInt(proMap.get("max_active_paths").trim())); 124 Integer.parseInt(proMap.get("max_active_paths").trim()));
@@ -137,6 +144,8 @@ public class OnlineRecognizer { @@ -137,6 +144,8 @@ public class OnlineRecognizer {
137 float rule2MinTrailingSilence, 144 float rule2MinTrailingSilence,
138 float rule3MinUtteranceLength, 145 float rule3MinUtteranceLength,
139 String decodingMethod, 146 String decodingMethod,
  147 + String lm_model,
  148 + float lm_scale,
140 int maxActivePaths) { 149 int maxActivePaths) {
141 this.sampleRate = sampleRate; 150 this.sampleRate = sampleRate;
142 EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); 151 EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
@@ -146,14 +155,10 @@ public class OnlineRecognizer { @@ -146,14 +155,10 @@ public class OnlineRecognizer {
146 OnlineTransducerModelConfig modelCfg = 155 OnlineTransducerModelConfig modelCfg =
147 new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false); 156 new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false);
148 FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); 157 FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
149 - OnlineRecognizerConfig rcgCfg = 158 + OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale);
  159 + OnlineRecognizerConfig rcgCfg =
150 new OnlineRecognizerConfig( 160 new OnlineRecognizerConfig(
151 - featConfig,  
152 - modelCfg,  
153 - endCfg,  
154 - enableEndpointDetection,  
155 - decodingMethod,  
156 - maxActivePaths); 161 + featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths);
157 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 162 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
158 this.ptr = createOnlineRecognizer(new Object(), rcgCfg); 163 this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
159 } 164 }
@@ -241,7 +246,7 @@ public class OnlineRecognizer { @@ -241,7 +246,7 @@ public class OnlineRecognizer {
241 return stream; 246 return stream;
242 } 247 }
243 248
244 - public float[] readWavFile(String fileName) { 249 + public static float[] readWavFile(String fileName) {
245 // read data from the filename 250 // read data from the filename
246 Object[] wavdata = readWave(fileName); 251 Object[] wavdata = readWave(fileName);
247 Object data = wavdata[0]; // data[0] is float data, data[1] sample rate 252 Object data = wavdata[0]; // data[0] is float data, data[1] sample rate
@@ -281,7 +286,7 @@ public class OnlineRecognizer { @@ -281,7 +286,7 @@ public class OnlineRecognizer {
281 } 286 }
282 // JNI interface libsherpa-onnx-jni.so 287 // JNI interface libsherpa-onnx-jni.so
283 288
284 - private native Object[] readWave(String fileName); 289 + private static native Object[] readWave(String fileName); // static
285 290
286 private native String getResult(long ptr, long streamPtr); 291 private native String getResult(long ptr, long streamPtr);
287 292
@@ -8,25 +8,33 @@ public class OnlineRecognizerConfig { @@ -8,25 +8,33 @@ public class OnlineRecognizerConfig {
8 private final FeatureConfig featConfig; 8 private final FeatureConfig featConfig;
9 private final OnlineTransducerModelConfig modelConfig; 9 private final OnlineTransducerModelConfig modelConfig;
10 private final EndpointConfig endpointConfig; 10 private final EndpointConfig endpointConfig;
  11 + private final OnlineLMConfig lmConfig;
11 private final boolean enableEndpoint; 12 private final boolean enableEndpoint;
12 private final String decodingMethod; 13 private final String decodingMethod;
13 private final int maxActivePaths; 14 private final int maxActivePaths;
  15 +
14 16
15 public OnlineRecognizerConfig( 17 public OnlineRecognizerConfig(
16 FeatureConfig featConfig, 18 FeatureConfig featConfig,
17 OnlineTransducerModelConfig modelConfig, 19 OnlineTransducerModelConfig modelConfig,
18 EndpointConfig endpointConfig, 20 EndpointConfig endpointConfig,
  21 + OnlineLMConfig lmConfig,
19 boolean enableEndpoint, 22 boolean enableEndpoint,
20 String decodingMethod, 23 String decodingMethod,
21 int maxActivePaths) { 24 int maxActivePaths) {
22 this.featConfig = featConfig; 25 this.featConfig = featConfig;
23 this.modelConfig = modelConfig; 26 this.modelConfig = modelConfig;
24 this.endpointConfig = endpointConfig; 27 this.endpointConfig = endpointConfig;
  28 + this.lmConfig = lmConfig;
25 this.enableEndpoint = enableEndpoint; 29 this.enableEndpoint = enableEndpoint;
26 this.decodingMethod = decodingMethod; 30 this.decodingMethod = decodingMethod;
27 this.maxActivePaths = maxActivePaths; 31 this.maxActivePaths = maxActivePaths;
28 } 32 }
29 33
  34 + public OnlineLMConfig getLmConfig() {
  35 + return lmConfig;
  36 + }
  37 +
30 public FeatureConfig getFeatConfig() { 38 public FeatureConfig getFeatConfig() {
31 return featConfig; 39 return featConfig;
32 } 40 }