Committed by
GitHub
Add java websocket support (#137)
* add decode example for mic * some changes to README.md * add java websocket srv * change to readwav to static * make some changes to code comments * little change for readme.md * fix bug about multiple threads * made little modification * add protocol in readme, removed static Queue and add lmConfig --------- Co-authored-by: root <root@localhost.localdomain>
正在显示
12 个修改的文件
包含
853 行增加
和
19 行删除
| @@ -7,11 +7,21 @@ LIB_FILES = \ | @@ -7,11 +7,21 @@ LIB_FILES = \ | ||
| 7 | $(LIB_SRC_DIR)/EndpointRule.java \ | 7 | $(LIB_SRC_DIR)/EndpointRule.java \ |
| 8 | $(LIB_SRC_DIR)/EndpointConfig.java \ | 8 | $(LIB_SRC_DIR)/EndpointConfig.java \ |
| 9 | $(LIB_SRC_DIR)/FeatureConfig.java \ | 9 | $(LIB_SRC_DIR)/FeatureConfig.java \ |
| 10 | + $(LIB_SRC_DIR)/OnlineLMConfig.java \ | ||
| 10 | $(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \ | 11 | $(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \ |
| 11 | $(LIB_SRC_DIR)/OnlineRecognizerConfig.java \ | 12 | $(LIB_SRC_DIR)/OnlineRecognizerConfig.java \ |
| 12 | $(LIB_SRC_DIR)/OnlineStream.java \ | 13 | $(LIB_SRC_DIR)/OnlineStream.java \ |
| 13 | $(LIB_SRC_DIR)/OnlineRecognizer.java \ | 14 | $(LIB_SRC_DIR)/OnlineRecognizer.java \ |
| 14 | 15 | ||
| 16 | +WEBSOCKET_DIR:= ./src/websocketsrv | ||
| 17 | +WEBSOCKET_FILES = \ | ||
| 18 | + $(WEBSOCKET_DIR)/ConnectionData.java \ | ||
| 19 | + $(WEBSOCKET_DIR)/DecoderThreadHandler.java \ | ||
| 20 | + $(WEBSOCKET_DIR)/StreamThreadHandler.java \ | ||
| 21 | + $(WEBSOCKET_DIR)/AsrWebsocketServer.java \ | ||
| 22 | + $(WEBSOCKET_DIR)/AsrWebsocketClient.java \ | ||
| 23 | + | ||
| 24 | + | ||
| 15 | LIB_BUILD_DIR = ./lib | 25 | LIB_BUILD_DIR = ./lib |
| 16 | 26 | ||
| 17 | 27 | ||
| @@ -39,7 +49,13 @@ buildmic: | @@ -39,7 +49,13 @@ buildmic: | ||
| 39 | 49 | ||
| 40 | rebuild: clean all | 50 | rebuild: clean all |
| 41 | 51 | ||
| 42 | -.PHONY: clean run | 52 | +.PHONY: clean run downjar |
| 53 | + | ||
| 54 | +downjar: | ||
| 55 | + wget https://repo1.maven.org/maven2/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar -P ./lib/ | ||
| 56 | + wget https://repo1.maven.org/maven2/org/slf4j/slf4j-simple/1.7.25/slf4j-simple-1.7.25.jar -P ./lib/ | ||
| 57 | + wget https://github.com/TooTallNate/Java-WebSocket/releases/download/v1.5.3/Java-WebSocket-1.5.3.jar -P ./lib/ | ||
| 58 | + | ||
| 43 | 59 | ||
| 44 | clean: | 60 | clean: |
| 45 | rm -frv $(BUILD_DIR)/* | 61 | rm -frv $(BUILD_DIR)/* |
| @@ -56,6 +72,12 @@ runmic: | @@ -56,6 +72,12 @@ runmic: | ||
| 56 | 72 | ||
| 57 | java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic | 73 | java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic |
| 58 | 74 | ||
| 75 | +runsrv: | ||
| 76 | + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer /sherpa-onnx/20230515/zhaoming/sherpa-onnx/build/lib/libsherpa-onnx-jni.so ./modelconfig.cfg | ||
| 77 | + | ||
| 78 | +runclient: | ||
| 79 | + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient /sherpa-onnx/20230515/zhaoming/sherpa-onnx/build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32 | ||
| 80 | + | ||
| 59 | buildlib: $(LIB_FILES:.java=.class) | 81 | buildlib: $(LIB_FILES:.java=.class) |
| 60 | 82 | ||
| 61 | 83 | ||
| @@ -63,10 +85,19 @@ buildlib: $(LIB_FILES:.java=.class) | @@ -63,10 +85,19 @@ buildlib: $(LIB_FILES:.java=.class) | ||
| 63 | 85 | ||
| 64 | $(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $< | 86 | $(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $< |
| 65 | 87 | ||
| 88 | +buildwebsocket: $(WEBSOCKET_FILES:.java=.class) | ||
| 89 | + | ||
| 90 | + | ||
| 91 | +%.class: %.java | ||
| 92 | + | ||
| 93 | + $(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $< | ||
| 94 | + | ||
| 66 | packjar: | 95 | packjar: |
| 67 | jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) . | 96 | jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) . |
| 68 | 97 | ||
| 69 | -all: clean buildlib packjar buildfile buildmic | 98 | +all: clean buildlib packjar buildfile buildmic downjar buildwebsocket |
| 99 | + | ||
| 100 | + | ||
| 70 | 101 | ||
| 71 | 102 | ||
| 72 | 103 |
| @@ -2,6 +2,7 @@ | @@ -2,6 +2,7 @@ | ||
| 2 | -------------- | 2 | -------------- |
| 3 | 3 | ||
| 4 | Java wrapper `com.k2fsa.sherpa.onnx.OnlineRecognizer` for `sherpa-onnx`. Java is a cross-platform language; you can build jni .so lib according to your system, and then use the same java api for all your platform. | 4 | Java wrapper `com.k2fsa.sherpa.onnx.OnlineRecognizer` for `sherpa-onnx`. Java is a cross-platform language; you can build jni .so lib according to your system, and then use the same java api for all your platform. |
| 5 | +A multi-threaded websocket server is now supported. | ||
| 5 | 6 | ||
| 6 | ```xml | 7 | ```xml |
| 7 | Depend on: | 8 | Depend on: |
| @@ -35,10 +36,10 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362: | @@ -35,10 +36,10 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362: | ||
| 35 | 36 | ||
| 36 | 3.Config model config.cfg | 37 | 3.Config model config.cfg |
| 37 | ------------------------- | 38 | ------------------------- |
| 38 | - | 39 | +Change the model paths in config.cfg according to your environment. |
| 39 | ```xml | 40 | ```xml |
| 40 | - #model config | ||
| 41 | - sample_rate=16000 | 41 | + #model config |
| 42 | + sample_rate=16000 | ||
| 42 | feature_dim=80 | 43 | feature_dim=80 |
| 43 | rule1_min_trailing_silence=2.4 | 44 | rule1_min_trailing_silence=2.4 |
| 44 | rule2_min_trailing_silence=1.2 | 45 | rule2_min_trailing_silence=1.2 |
| @@ -51,6 +52,21 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362: | @@ -51,6 +52,21 @@ Example for Ubuntu 18.04 LTS, Openjdk 1.8.0_362: | ||
| 51 | enable_endpoint_detection=false | 52 | enable_endpoint_detection=false |
| 52 | decoding_method=greedy_search | 53 | decoding_method=greedy_search |
| 53 | max_active_paths=4 | 54 | max_active_paths=4 |
| 55 | + | ||
| 56 | + #websocket server config | ||
| 57 | + port=8890 | ||
| 58 | + #number of threads in the pool for network io | ||
| 59 | + connection_thread_num=16 | ||
| 60 | + #number of threads in the pool for streams | ||
| 61 | + stream_thread_num=16 | ||
| 62 | + #number of threads in the pool for decoders | ||
| 63 | + decoder_thread_num=16 | ||
| 64 | + #maximum number of streams one decoder thread decodes in parallel | ||
| 65 | + parallel_decoder_num=16 | ||
| 66 | + #idle time (ms) of a decoder thread when there is no job | ||
| 67 | + decoder_time_idle=10 | ||
| 68 | + #timeout (ms) for connection data | ||
| 69 | + deocder_time_out=3000 | ||
| 54 | ``` | 70 | ``` |
| 55 | 71 | ||
| 56 | --- | 72 | --- |
| @@ -114,5 +130,58 @@ Build package path: /sherpa-onnx/java-api-examples/lib/sherpaonnx.jar | @@ -114,5 +130,58 @@ Build package path: /sherpa-onnx/java-api-examples/lib/sherpaonnx.jar | ||
| 114 | make runmic | 130 | make runmic |
| 115 | ``` | 131 | ``` |
| 116 | 132 | ||
| 133 | +--- | ||
| 117 | 134 | ||
| 135 | +6.WebSocket Server | ||
| 136 | +---------- | ||
| 137 | + | ||
| 138 | +support multiple threads for websocket server | ||
| 139 | +6.0 Protocol for communication | ||
| 140 | +1) The client connects to the server | ||
| 141 | +```shell | ||
| 142 | + ws client -> srv ws address | ||
| 143 | + ws address example: ws://127.0.0.1:8890/ | ||
| 144 | +``` | ||
| 145 | +2) The client sends 16k float32 (little endian) binary stream data to the server | ||
| 146 | +```shell | ||
| 147 | + PCM sampleRate 16000 | ||
| 148 | + single channel | ||
| 149 | + sampleSize 32bit | ||
| 150 | + little endian | ||
| 151 | + type float | ||
| 152 | +``` | ||
| 153 | +3) The client sends the text message "Done" to the server once all data has been sent | ||
| 154 | +```shell | ||
| 155 | + ws_socket.send("Done") | ||
| 156 | +``` | ||
| 157 | +4) The client receives a json message from the server whenever the asr engine has decoded new text | ||
| 158 | +```shell | ||
| 159 | + json example: {"text":"甚至出现交易几乎停滞的情况","eof":false} | ||
| 160 | +``` | ||
| 161 | + | ||
| 162 | + | ||
| 163 | +6.1 Build | ||
| 164 | + | ||
| 165 | +```bash | ||
| 166 | + cd sherpa-onnx/java-api-examples | ||
| 167 | + make all | ||
| 168 | +``` | ||
| 169 | + | ||
| 170 | +6.2 Run srv example | ||
| 171 | + | ||
| 172 | +usage: AsrWebsocketServer soPath modelCfgPath | ||
| 173 | + | ||
| 174 | +```bash | ||
| 175 | + make runsrv # change the paths in the Makefile according to your environment | ||
| 176 | +``` | ||
| 177 | + | ||
| 178 | +6.3 Run multiple threads client example | ||
| 179 | + | ||
| 180 | +usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads | ||
| 181 | + | ||
| 182 | +json result example: {"text":"甚至出现交易几乎停滞的情况","eof":"true"} | ||
| 183 | + | ||
| 184 | +```bash | ||
| 185 | + make runclient # change the paths in the Makefile according to your environment | ||
| 186 | +``` | ||
| 118 | 187 |
| @@ -9,6 +9,17 @@ decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh- | @@ -9,6 +9,17 @@ decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh- | ||
| 9 | joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx | 9 | joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx |
| 10 | tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt | 10 | tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt |
| 11 | num_threads=4 | 11 | num_threads=4 |
| 12 | -enable_endpoint_detection=false | 12 | +enable_endpoint_detection=true |
| 13 | decoding_method=modified_beam_search | 13 | decoding_method=modified_beam_search |
| 14 | -max_active_paths=4 | ||
| 14 | +max_active_paths=4 | ||
| 15 | +lm_model= | ||
| 16 | +lm_scale=0.5 | ||
| 17 | + | ||
| 18 | +#websocket server config | ||
| 19 | +port=8890 | ||
| 20 | +connection_thread_num=16 | ||
| 21 | +stream_thread_num=16 | ||
| 22 | +decoder_thread_num=16 | ||
| 23 | +parallel_decoder_num=16 | ||
| 24 | +decoder_time_idle=200 | ||
| 25 | +deocder_time_out=30000 |
| @@ -49,7 +49,8 @@ public class DecodeFile { | @@ -49,7 +49,8 @@ public class DecodeFile { | ||
| 49 | float rule3MinUtteranceLength = 20F; | 49 | float rule3MinUtteranceLength = 20F; |
| 50 | String decodingMethod = "greedy_search"; | 50 | String decodingMethod = "greedy_search"; |
| 51 | int maxActivePaths = 4; | 51 | int maxActivePaths = 4; |
| 52 | - | 52 | + String lm_model=""; |
| 53 | + float lm_scale=0.5F; | ||
| 53 | rcgOjb = | 54 | rcgOjb = |
| 54 | new OnlineRecognizer( | 55 | new OnlineRecognizer( |
| 55 | tokens, | 56 | tokens, |
| @@ -64,6 +65,8 @@ public class DecodeFile { | @@ -64,6 +65,8 @@ public class DecodeFile { | ||
| 64 | rule2MinTrailingSilence, | 65 | rule2MinTrailingSilence, |
| 65 | rule3MinUtteranceLength, | 66 | rule3MinUtteranceLength, |
| 66 | decodingMethod, | 67 | decodingMethod, |
| 68 | + lm_model, | ||
| 69 | + lm_scale, | ||
| 67 | maxActivePaths); | 70 | maxActivePaths); |
| 68 | streamObj = rcgOjb.createStream(); | 71 | streamObj = rcgOjb.createStream(); |
| 69 | } catch (Exception e) { | 72 | } catch (Exception e) { |
| 1 | +/* | ||
| 2 | + * // Copyright 2022-2023 by zhaomingwork | ||
| 3 | + */ | ||
| 4 | +// java AsrWebsocketClient | ||
| 5 | +// usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads | ||
| 6 | +package websocketsrv; | ||
| 7 | + | ||
| 8 | +import com.k2fsa.sherpa.onnx.OnlineRecognizer; | ||
| 9 | +import java.net.URI; | ||
| 10 | +import java.net.URISyntaxException; | ||
| 11 | +import java.nio.*; | ||
| 12 | +import java.util.Map; | ||
| 13 | +import org.java_websocket.client.WebSocketClient; | ||
| 14 | +import org.java_websocket.drafts.Draft; | ||
| 15 | +import org.java_websocket.handshake.ServerHandshake; | ||
| 16 | +import org.slf4j.Logger; | ||
| 17 | +import org.slf4j.LoggerFactory; | ||
| 18 | + | ||
| 19 | +/** This example demonstrates how to connect to websocket server. */ | ||
| 20 | +public class AsrWebsocketClient extends WebSocketClient { | ||
| 21 | + private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketClient.class); | ||
| 22 | + | ||
| 23 | + public AsrWebsocketClient(URI serverUri, Draft draft) { | ||
| 24 | + super(serverUri, draft); | ||
| 25 | + } | ||
| 26 | + | ||
| 27 | + public AsrWebsocketClient(URI serverURI) { | ||
| 28 | + super(serverURI); | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + public AsrWebsocketClient(URI serverUri, Map<String, String> httpHeaders) { | ||
| 32 | + super(serverUri, httpHeaders); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + @Override | ||
| 36 | + public void onOpen(ServerHandshake handshakedata) { | ||
| 37 | + | ||
| 38 | + float[] floats = OnlineRecognizer.readWavFile(AsrWebsocketClient.wavPath); | ||
| 39 | + ByteBuffer buffer = | ||
| 40 | + ByteBuffer.allocate(4 * floats.length) | ||
| 41 | + .order(ByteOrder.LITTLE_ENDIAN); // float is sizeof 4. allocate enough buffer | ||
| 42 | + | ||
| 43 | + for (float f : floats) { | ||
| 44 | + buffer.putFloat(f); | ||
| 45 | + } | ||
| 46 | + buffer.rewind(); | ||
| 47 | + buffer.flip(); | ||
| 48 | + buffer.order(ByteOrder.LITTLE_ENDIAN); | ||
| 49 | + | ||
| 50 | + send(buffer.array()); // send buf to server | ||
| 51 | + send("Done"); // send 'Done' means finished | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + @Override | ||
| 55 | + public void onMessage(String message) { | ||
| 56 | + | ||
| 57 | + logger.info("received: " + message); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + @Override | ||
| 61 | + public void onClose(int code, String reason, boolean remote) { | ||
| 62 | + | ||
| 63 | + logger.info( | ||
| 64 | + "Connection closed by " | ||
| 65 | + + (remote ? "remote peer" : "us") | ||
| 66 | + + " Code: " | ||
| 67 | + + code | ||
| 68 | + + " Reason: " | ||
| 69 | + + reason); | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + @Override | ||
| 73 | + public void onError(Exception ex) { | ||
| 74 | + ex.printStackTrace(); | ||
| 75 | + // if the error is fatal then onClose will be called additionally | ||
| 76 | + } | ||
| 77 | + | ||
| 78 | + public static OnlineRecognizer rcgobj; | ||
| 79 | + public static String wavPath; | ||
| 80 | + | ||
| 81 | + public static void main(String[] args) throws URISyntaxException { | ||
| 82 | + | ||
| 83 | + if (args.length != 5) { | ||
| 84 | + System.out.println("usage: AsrWebsocketClient soPath srvIp srvPort wavPath numThreads"); | ||
| 85 | + return; | ||
| 86 | + } | ||
| 87 | + | ||
| 88 | + String soPath = args[0]; | ||
| 89 | + String srvIp = args[1]; | ||
| 90 | + String srvPort = args[2]; | ||
| 91 | + String wavPath = args[3]; | ||
| 92 | + int numThreads = Integer.parseInt(args[4]); | ||
| 93 | + System.out.println("serIp=" + srvIp + ",srvPort=" + srvPort + ",wavPath=" + wavPath); | ||
| 94 | + | ||
| 95 | + class ClientThread implements Runnable { | ||
| 96 | + | ||
| 97 | + String soPath; | ||
| 98 | + String srvIp; | ||
| 99 | + String srvPort; | ||
| 100 | + String wavPath; | ||
| 101 | + | ||
| 102 | + ClientThread(String soPath, String srvIp, String srvPort, String wavPath) { | ||
| 103 | + this.soPath = soPath; | ||
| 104 | + this.srvIp = srvIp; | ||
| 105 | + this.srvPort = srvPort; | ||
| 106 | + this.wavPath = wavPath; | ||
| 107 | + } | ||
| 108 | + | ||
| 109 | + public void run() { | ||
| 110 | + try { | ||
| 111 | + | ||
| 112 | + OnlineRecognizer.setSoPath(soPath); | ||
| 113 | + | ||
| 114 | + AsrWebsocketClient.wavPath = wavPath; | ||
| 115 | + | ||
| 116 | + String wsAddress = "ws://" + srvIp + ":" + srvPort; | ||
| 117 | + AsrWebsocketClient c = new AsrWebsocketClient(new URI(wsAddress)); | ||
| 118 | + | ||
| 119 | + c.connect(); | ||
| 120 | + } catch (Exception e) { | ||
| 121 | + e.printStackTrace(); | ||
| 122 | + } | ||
| 123 | + } | ||
| 124 | + } | ||
| 125 | + for (int i = 0; i < numThreads; i++) { | ||
| 126 | + System.out.println("Thread1 is running..."); | ||
| 127 | + Thread t = new Thread(new ClientThread(soPath, srvIp, srvPort, wavPath)); | ||
| 128 | + t.start(); | ||
| 129 | + } | ||
| 130 | + } | ||
| 131 | +} |
| 1 | +/* | ||
| 2 | + * // Copyright 2022-2023 by zhaoming | ||
| 3 | + */ | ||
| 4 | +// java websocketServer | ||
| 5 | +// usage: AsrWebsocketServer soPath modelCfgPath | ||
| 6 | +package websocketsrv; | ||
| 7 | + | ||
| 8 | +import com.k2fsa.sherpa.onnx.OnlineRecognizer; | ||
| 9 | +import com.k2fsa.sherpa.onnx.OnlineStream; | ||
| 10 | +import java.io.*; | ||
| 11 | +import java.io.IOException; | ||
| 12 | +import java.net.InetSocketAddress; | ||
| 13 | +import java.net.UnknownHostException; | ||
| 14 | +import java.nio.ByteBuffer; | ||
| 15 | +import java.nio.ByteOrder; | ||
| 16 | +import java.nio.FloatBuffer; | ||
| 17 | +import java.util.*; | ||
| 18 | +import java.util.Collections; | ||
| 19 | +import java.util.concurrent.*; | ||
| 20 | +import java.util.concurrent.LinkedBlockingQueue; | ||
| 21 | +import org.java_websocket.WebSocket; | ||
| 22 | +import org.java_websocket.drafts.Draft; | ||
| 23 | +import org.java_websocket.drafts.Draft_6455; | ||
| 24 | +import org.java_websocket.handshake.ClientHandshake; | ||
| 25 | +import org.java_websocket.server.WebSocketServer; | ||
| 26 | +import org.slf4j.Logger; | ||
| 27 | +import org.slf4j.LoggerFactory; | ||
| 28 | + | ||
| 29 | +/** | ||
| 30 | + * AsrWebSocketServer has three threads pools, one pool for network io, one pool for asr stream and | ||
| 31 | + * one pool for asr decoder. | ||
| 32 | + */ | ||
| 33 | +public class AsrWebsocketServer extends WebSocketServer { | ||
| 34 | + private static final Logger logger = LoggerFactory.getLogger(AsrWebsocketServer.class); | ||
| 35 | + // Queue between io network io thread pool and stream thread pool, use websocket as the key | ||
| 36 | + private LinkedBlockingQueue<WebSocket> streamQueue = new LinkedBlockingQueue<WebSocket>(); | ||
| 37 | + // Queue waiting for deocdeing, use websocket as the key | ||
| 38 | + private LinkedBlockingQueue<WebSocket> decoderQueue = new LinkedBlockingQueue<WebSocket>(); | ||
| 39 | + | ||
| 40 | + // recogizer object | ||
| 41 | + private OnlineRecognizer rcgOjb = null; | ||
| 42 | + | ||
| 43 | + // mapping between websocket connection and connection data | ||
| 44 | + private ConcurrentHashMap<WebSocket, ConnectionData> connectionMap = | ||
| 45 | + new ConcurrentHashMap<WebSocket, ConnectionData>(); | ||
| 46 | + | ||
| 47 | + public AsrWebsocketServer(int port, int numThread) throws UnknownHostException { | ||
| 48 | + // server port and num of threads for network io | ||
| 49 | + super(new InetSocketAddress(port), numThread); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + public AsrWebsocketServer(InetSocketAddress address) { | ||
| 53 | + super(address); | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + public AsrWebsocketServer(int port, Draft_6455 draft) { | ||
| 57 | + super(new InetSocketAddress(port), Collections.<Draft>singletonList(draft)); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + @Override | ||
| 61 | + public void onOpen(WebSocket conn, ClientHandshake handshake) {} | ||
| 62 | + | ||
| 63 | + @Override | ||
| 64 | + public void onClose(WebSocket conn, int code, String reason, boolean remote) { | ||
| 65 | + connectionMap.remove(conn); | ||
| 66 | + logger.info( | ||
| 67 | + conn | ||
| 68 | + + " remove one connection!, now connection number=" | ||
| 69 | + + String.valueOf(connectionMap.size())); | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + @Override | ||
| 73 | + public void onMessage(WebSocket conn, String message) { | ||
| 74 | + // this is text message | ||
| 75 | + try { | ||
| 76 | + // if rec "Done" msg from client | ||
| 77 | + if (message.equals("Done")) { | ||
| 78 | + ConnectionData connData = creatOrGetConnectionData(conn); | ||
| 79 | + connData.setEof(true); | ||
| 80 | + if (!streamQueueFind(conn)) { | ||
| 81 | + streamQueue.put(conn); | ||
| 82 | + } | ||
| 83 | + } | ||
| 84 | + | ||
| 85 | + } catch (Exception e) { | ||
| 86 | + e.printStackTrace(); | ||
| 87 | + } | ||
| 88 | + } | ||
| 89 | + | ||
| 90 | + private ConnectionData creatOrGetConnectionData(WebSocket conn) { | ||
| 91 | + // create a new connection data if not in connection map or return the existed one | ||
| 92 | + | ||
| 93 | + ConnectionData connData = null; | ||
| 94 | + try { | ||
| 95 | + if (!connectionMap.containsKey(conn)) { | ||
| 96 | + OnlineStream stream = rcgOjb.createStream(); | ||
| 97 | + connData = new ConnectionData(conn, stream); | ||
| 98 | + connectionMap.put(conn, connData); | ||
| 99 | + } else { | ||
| 100 | + connData = connectionMap.get(conn); | ||
| 101 | + } | ||
| 102 | + | ||
| 103 | + logger.info( | ||
| 104 | + conn.getRemoteSocketAddress().getAddress().getHostAddress() | ||
| 105 | + + " open one connection,, now connection number=" | ||
| 106 | + + String.valueOf(connectionMap.size())); | ||
| 107 | + | ||
| 108 | + } catch (Exception e) { | ||
| 109 | + System.err.println(e); | ||
| 110 | + e.printStackTrace(); | ||
| 111 | + } | ||
| 112 | + return connData; | ||
| 113 | + } | ||
| 114 | + | ||
| 115 | + @Override | ||
| 116 | + public void onMessage(WebSocket conn, ByteBuffer blob) { | ||
| 117 | + try { | ||
| 118 | + | ||
| 119 | + // for handle binary data | ||
| 120 | + blob.order(ByteOrder.LITTLE_ENDIAN); // set little endian | ||
| 121 | + | ||
| 122 | + // set to float | ||
| 123 | + FloatBuffer floatbuf = blob.asFloatBuffer(); | ||
| 124 | + | ||
| 125 | + if (floatbuf.capacity() > 0) { | ||
| 126 | + // allocate memory for float data | ||
| 127 | + float[] arr = new float[floatbuf.capacity()]; | ||
| 128 | + | ||
| 129 | + floatbuf.get(arr); | ||
| 130 | + ConnectionData connData = creatOrGetConnectionData(conn); | ||
| 131 | + // put websocket to stream queue with binary type==1 | ||
| 132 | + connData.addSamplesToData(arr); | ||
| 133 | + | ||
| 134 | + if (!streamQueueFind(conn)) { | ||
| 135 | + streamQueue.put(conn); | ||
| 136 | + } | ||
| 137 | + } | ||
| 138 | + } catch (Exception e) { | ||
| 139 | + e.printStackTrace(); | ||
| 140 | + } | ||
| 141 | + } | ||
| 142 | + | ||
| 143 | + | ||
| 144 | + | ||
| 145 | + public boolean streamQueueFind(WebSocket conn) { | ||
| 146 | + return streamQueue.contains(conn); | ||
| 147 | + } | ||
| 148 | + | ||
| 149 | + public void initModelWithCfg(Map<String, String> cfgMap, String cfgPath) { | ||
| 150 | + try { | ||
| 151 | + | ||
| 152 | + rcgOjb = new OnlineRecognizer(cfgPath); | ||
| 153 | + // size of stream thread pool | ||
| 154 | + int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num")); | ||
| 155 | + // size of decoder thread pool | ||
| 156 | + int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num")); | ||
| 157 | + | ||
| 158 | + // time(ms) idle for decoder thread when no job | ||
| 159 | + int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle")); | ||
| 160 | + // size of streams for parallel decoding | ||
| 161 | + int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num")); | ||
| 162 | + // time(ms) out for connection data | ||
| 163 | + int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out")); | ||
| 164 | + | ||
| 165 | + // create stream threads | ||
| 166 | + for (int i = 0; i < streamThreadNum; i++) { | ||
| 167 | + new StreamThreadHandler(streamQueue, decoderQueue, connectionMap).start(); | ||
| 168 | + } | ||
| 169 | + // create decoder threads | ||
| 170 | + for (int i = 0; i < decoderThreadNum; i++) { | ||
| 171 | + new DecoderThreadHandler( | ||
| 172 | + decoderQueue, | ||
| 173 | + connectionMap, | ||
| 174 | + rcgOjb, | ||
| 175 | + decoderTimeIdle, | ||
| 176 | + parallelDecoderNum, | ||
| 177 | + deocderTimeOut) | ||
| 178 | + .start(); | ||
| 179 | + } | ||
| 180 | + } catch (Exception e) { | ||
| 181 | + System.err.println(e); | ||
| 182 | + e.printStackTrace(); | ||
| 183 | + } | ||
| 184 | + } | ||
| 185 | + | ||
| 186 | + public static Map<String, String> readProperties(String CfgPath) { | ||
| 187 | + // read and parse config file | ||
| 188 | + Properties props = new Properties(); | ||
| 189 | + Map<String, String> proMap = new HashMap<String, String>(); | ||
| 190 | + try { | ||
| 191 | + | ||
| 192 | + File file = new File(CfgPath); | ||
| 193 | + if (!file.exists()) { | ||
| 194 | + logger.info(String.valueOf(CfgPath) + " cfg file not exists!"); | ||
| 195 | + System.exit(0); | ||
| 196 | + } | ||
| 197 | + InputStream in = new BufferedInputStream(new FileInputStream(CfgPath)); | ||
| 198 | + props.load(in); | ||
| 199 | + Enumeration en = props.propertyNames(); | ||
| 200 | + while (en.hasMoreElements()) { | ||
| 201 | + String key = (String) en.nextElement(); | ||
| 202 | + String Property = props.getProperty(key); | ||
| 203 | + proMap.put(key, Property); | ||
| 204 | + } | ||
| 205 | + | ||
| 206 | + } catch (Exception e) { | ||
| 207 | + e.printStackTrace(); | ||
| 208 | + } | ||
| 209 | + return proMap; | ||
| 210 | + } | ||
| 211 | + | ||
| 212 | + public static void main(String[] args) throws InterruptedException, IOException { | ||
| 213 | + if (args.length != 2) { | ||
| 214 | + logger.info("usage: AsrWebsocketServer soPath modelCfgPath"); | ||
| 215 | + | ||
| 216 | + return; | ||
| 217 | + } | ||
| 218 | + | ||
| 219 | + String soPath = args[0]; | ||
| 220 | + String cfgPath = args[1]; | ||
| 221 | + | ||
| 222 | + OnlineRecognizer.setSoPath(soPath); | ||
| 223 | + | ||
| 224 | + Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath); | ||
| 225 | + int port = Integer.valueOf(cfgMap.get("port")); | ||
| 226 | + | ||
| 227 | + int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num")); | ||
| 228 | + AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum); | ||
| 229 | + s.initModelWithCfg(cfgMap, cfgPath); | ||
| 230 | + logger.info("Server started on port: " + s.getPort()); | ||
| 231 | + s.start(); | ||
| 232 | + } | ||
| 233 | + | ||
| 234 | + @Override | ||
| 235 | + public void onError(WebSocket conn, Exception ex) { | ||
| 236 | + ex.printStackTrace(); | ||
| 237 | + if (conn != null) { | ||
| 238 | + // some errors like port binding failed may not be assignable to a specific websocket | ||
| 239 | + } | ||
| 240 | + } | ||
| 241 | + | ||
| 242 | + @Override | ||
| 243 | + public void onStart() { | ||
| 244 | + logger.info("Server started!"); | ||
| 245 | + setConnectionLostTimeout(0); | ||
| 246 | + setConnectionLostTimeout(100); | ||
| 247 | + } | ||
| 248 | +} |
| 1 | +/* | ||
| 2 | + * // Copyright 2022-2023 by zhaoming | ||
| 3 | + */ | ||
| 4 | +// connection data act as a bridge between different threads pools | ||
| 5 | + | ||
| 6 | +package websocketsrv; | ||
| 7 | + | ||
| 8 | +import com.k2fsa.sherpa.onnx.OnlineStream; | ||
| 9 | +import java.time.LocalDateTime; | ||
| 10 | +import java.util.LinkedList; | ||
| 11 | +import java.util.Queue; | ||
| 12 | +import java.util.concurrent.*; | ||
| 13 | +import org.java_websocket.WebSocket; | ||
| 14 | + | ||
| 15 | +public class ConnectionData { | ||
| 16 | + | ||
| 17 | + private WebSocket webSocket; // the websocket for this connection data | ||
| 18 | + | ||
| 19 | + private OnlineStream stream; // connection stream | ||
| 20 | + | ||
| 21 | + private Queue<float[]> queueSamples = | ||
| 22 | + new LinkedList<float[]>(); // binary data rec from the client | ||
| 23 | + | ||
| 24 | + private boolean eof = false; // connection data is done | ||
| 25 | + | ||
| 26 | + private LocalDateTime lastHandleTime; // used for time out in ms | ||
| 27 | + | ||
| 28 | + public ConnectionData(WebSocket webSocket, OnlineStream stream) { | ||
| 29 | + this.webSocket = webSocket; | ||
| 30 | + | ||
| 31 | + this.stream = stream; | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + public void addSamplesToData(float[] samples) { | ||
| 35 | + this.queueSamples.add(samples); | ||
| 36 | + } | ||
| 37 | + | ||
| 38 | + public LocalDateTime getLastHandleTime() { | ||
| 39 | + return this.lastHandleTime; | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + public void setLastHandleTime(LocalDateTime now) { | ||
| 43 | + this.lastHandleTime = now; | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + public boolean getEof() { | ||
| 47 | + return this.eof; | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + public void setEof(boolean eof) { | ||
| 51 | + this.eof = eof; | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + public WebSocket getWebSocket() { | ||
| 55 | + return this.webSocket; | ||
| 56 | + } | ||
| 57 | + | ||
| 58 | + public Queue<float[]> getQueueSamples() { | ||
| 59 | + return this.queueSamples; | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + public OnlineStream getStream() { | ||
| 63 | + return this.stream; | ||
| 64 | + } | ||
| 65 | +} |
| 1 | +/* | ||
| 2 | + * // Copyright 2022-2023 by zhaoming | ||
| 3 | + */ | ||
| 4 | +// java DecoderThreadHandler | ||
| 5 | +package websocketsrv; | ||
| 6 | + | ||
| 7 | +import com.k2fsa.sherpa.onnx.OnlineRecognizer; | ||
| 8 | +import com.k2fsa.sherpa.onnx.OnlineStream; | ||
| 9 | +import java.nio.*; | ||
| 10 | +import java.nio.charset.StandardCharsets; | ||
| 11 | +import java.time.LocalDateTime; | ||
| 12 | +import java.util.*; | ||
| 13 | +import java.util.List; | ||
| 14 | +import java.util.concurrent.*; | ||
| 15 | +import java.util.concurrent.LinkedBlockingQueue; | ||
| 16 | +import org.java_websocket.WebSocket; | ||
| 17 | +import org.java_websocket.drafts.Draft; | ||
| 18 | +import org.java_websocket.framing.Framedata; | ||
| 19 | +import org.slf4j.Logger; | ||
| 20 | +import org.slf4j.LoggerFactory; | ||
| 21 | + | ||
| 22 | +public class DecoderThreadHandler extends Thread { | ||
| 23 | + private static final Logger logger = LoggerFactory.getLogger(DecoderThreadHandler.class); | ||
| 24 | + // Websocket Queue that waiting for decoding | ||
| 25 | + private LinkedBlockingQueue<WebSocket> decoderQueue; | ||
| 26 | + // the mapping between websocket and connection data | ||
| 27 | + private ConcurrentHashMap<WebSocket, ConnectionData> connMap; | ||
| 28 | + | ||
| 29 | + private OnlineRecognizer rcgOjb = null; // recgnizer object | ||
| 30 | + | ||
| 31 | + // connection data list for this thread to decode in parallel | ||
| 32 | + private List<ConnectionData> connDataList = new ArrayList<ConnectionData>(); | ||
| 33 | + | ||
| 34 | + private int parallelDecoderNum = 10; // parallel decoding number | ||
| 35 | + private int deocderTimeIdle = 10; // idle time(ms) when no job | ||
| 36 | + private int deocderTimeOut = 3000; // if it is timeout(ms), the connection data will be removed | ||
| 37 | + | ||
| 38 | + public DecoderThreadHandler( | ||
| 39 | + LinkedBlockingQueue<WebSocket> decoderQueue, | ||
| 40 | + ConcurrentHashMap<WebSocket, ConnectionData> connMap, | ||
| 41 | + OnlineRecognizer rcgOjb, | ||
| 42 | + int deocderTimeIdle, | ||
| 43 | + int parallelDecoderNum, | ||
| 44 | + int deocderTimeOut) { | ||
| 45 | + this.decoderQueue = decoderQueue; | ||
| 46 | + this.connMap = connMap; | ||
| 47 | + this.rcgOjb = rcgOjb; | ||
| 48 | + this.deocderTimeIdle = deocderTimeIdle; | ||
| 49 | + this.parallelDecoderNum = parallelDecoderNum; | ||
| 50 | + this.deocderTimeOut = deocderTimeOut; | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + public void run() { | ||
| 54 | + while (true) { | ||
| 55 | + try { | ||
| 56 | + // time(ms) idle if there is no job | ||
| 57 | + | ||
| 58 | + Thread.sleep(deocderTimeIdle); | ||
| 59 | + // clear data list for this threads | ||
| 60 | + connDataList.clear(); | ||
| 61 | + if (rcgOjb == null) continue; | ||
| 62 | + | ||
| 63 | + // loop for total decoder Queue | ||
| 64 | + while (!decoderQueue.isEmpty()) { | ||
| 65 | + | ||
| 66 | + // get websocket | ||
| 67 | + WebSocket conn = decoderQueue.take(); | ||
| 68 | + // get connection data according to websocket | ||
| 69 | + ConnectionData connData = connMap.get(conn); | ||
| 70 | + | ||
| 71 | + // if the websocket closed, continue | ||
| 72 | + if (connData == null) continue; | ||
| 73 | + // get the stream | ||
| 74 | + OnlineStream stream = connData.getStream(); | ||
| 75 | + | ||
| 76 | + // put to decoder list if 1) stream is ready; 2) and | ||
| 77 | + // size not > parallelDecoderNum | ||
| 78 | + if ((rcgOjb.isReady(stream) && connDataList.size() < parallelDecoderNum)) { | ||
| 79 | + | ||
| 80 | + // add to this thread's decoder list | ||
| 81 | + connDataList.add(connData); | ||
| 82 | + // change the handled time for this connection data | ||
| 83 | + connData.setLastHandleTime(LocalDateTime.now()); | ||
| 84 | + } | ||
| 85 | + // break when decoder list size >= parallelDecoderNum | ||
| 86 | + if (connDataList.size() >= parallelDecoderNum) { | ||
| 87 | + break; | ||
| 88 | + } | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + // if decoder data list for this thread >0 | ||
| 92 | + if (connDataList.size() > 0) { | ||
| 93 | + | ||
| 94 | + // create a stream array for parallel decoding | ||
| 95 | + OnlineStream[] arr = new OnlineStream[connDataList.size()]; | ||
| 96 | + for (int i = 0; i < connDataList.size(); i++) { | ||
| 97 | + | ||
| 98 | + arr[i] = connDataList.get(i).getStream(); | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + // parallel decoding | ||
| 102 | + rcgOjb.decodeStreams(arr); | ||
| 103 | + } | ||
| 104 | + | ||
| 105 | + // get result for each connection | ||
| 106 | + for (ConnectionData connData : connDataList) { | ||
| 107 | + | ||
| 108 | + OnlineStream stream = connData.getStream(); | ||
| 109 | + WebSocket webSocket = connData.getWebSocket(); | ||
| 110 | + | ||
| 111 | + String txtResult = rcgOjb.getResult(stream); | ||
| 112 | + | ||
| 113 | + // decode text in utf-8 | ||
| 114 | + byte[] utf8Data = txtResult.getBytes(StandardCharsets.UTF_8); | ||
| 115 | + | ||
| 116 | + boolean isEof = (connData.getEof() == true && !rcgOjb.isReady(stream)); | ||
| 117 | + // result | ||
| 118 | + if (utf8Data.length > 0) { | ||
| 119 | + | ||
| 120 | + String jsonResult = | ||
| 121 | + "{\"text\":\"" + txtResult + "\",\"eof\":" + String.valueOf(isEof) + "\"}"; | ||
| 122 | + | ||
| 123 | + if (webSocket.isOpen()) { | ||
| 124 | + // create a TEXT Frame for send back json result | ||
| 125 | + Draft draft = webSocket.getDraft(); | ||
| 126 | + List<Framedata> frames = null; | ||
| 127 | + frames = draft.createFrames(jsonResult, false); | ||
| 128 | + // send to client | ||
| 129 | + webSocket.sendFrame(frames); | ||
| 130 | + } | ||
| 131 | + } | ||
| 132 | + } | ||
| 133 | + // loop for each connection data in this thread | ||
| 134 | + for (ConnectionData connData : connDataList) { | ||
| 135 | + OnlineStream stream = connData.getStream(); | ||
| 136 | + WebSocket webSocket = connData.getWebSocket(); | ||
| 137 | + // if the stream is still ready, put it to decoder Queue again for next decoding | ||
| 138 | + if (rcgOjb.isReady(stream)) { | ||
| 139 | + decoderQueue.put(webSocket); | ||
| 140 | + } | ||
| 141 | + // the duration between last handled time and now | ||
| 142 | + java.time.Duration duration = | ||
| 143 | + java.time.Duration.between(connData.getLastHandleTime(), LocalDateTime.now()); | ||
| 144 | + // close the websocket if 1) data is done and stream not ready; 2) or data is time out; | ||
| 145 | + // 3) or | ||
| 146 | + // connection is closed | ||
| 147 | + if ((connData.getEof() == true | ||
| 148 | + && !rcgOjb.isReady(stream) | ||
| 149 | + && connData.getQueueSamples().isEmpty()) | ||
| 150 | + || duration.toMillis() > deocderTimeOut | ||
| 151 | + || !connData.getWebSocket().isOpen()) { | ||
| 152 | + | ||
| 153 | + logger.info("close websocket!!!"); | ||
| 154 | + | ||
| 155 | + // delay close web socket as data may still in processing | ||
| 156 | + Timer timer = new Timer(); | ||
| 157 | + timer.schedule( | ||
| 158 | + new TimerTask() { | ||
| 159 | + public void run() { | ||
| 160 | + | ||
| 161 | + webSocket.close(); | ||
| 162 | + } | ||
| 163 | + }, | ||
| 164 | + 5000); // 5 seconds | ||
| 165 | + } | ||
| 166 | + } | ||
| 167 | + | ||
| 168 | + } catch (Exception e) { | ||
| 169 | + e.printStackTrace(); | ||
| 170 | + } | ||
| 171 | + } | ||
| 172 | + } | ||
| 173 | +} |
| 1 | +/* | ||
| 2 | + * // Copyright 2022-2023 by zhaoming | ||
| 3 | + */ | ||
| 4 | +// java StreamThreadHandler | ||
| 5 | +package websocketsrv; | ||
| 6 | + | ||
| 7 | +import com.k2fsa.sherpa.onnx.OnlineStream; | ||
| 8 | +import java.nio.*; | ||
| 9 | +import java.util.*; | ||
| 10 | +import java.util.concurrent.*; | ||
| 11 | +import java.util.concurrent.LinkedBlockingQueue; | ||
| 12 | +import org.java_websocket.WebSocket; | ||
| 13 | +// thread for processing stream | ||
| 14 | + | ||
| 15 | +public class StreamThreadHandler extends Thread { | ||
| 16 | + // Queue between io network io thread pool and stream thread pool, use websocket as the key | ||
| 17 | + private LinkedBlockingQueue<WebSocket> streamQueue; | ||
| 18 | + // Queue waiting for deocdeing, use websocket as the key | ||
| 19 | + private LinkedBlockingQueue<WebSocket> decoderQueue; | ||
| 20 | + // mapping between websocket connection and connection data | ||
| 21 | + private ConcurrentHashMap<WebSocket, ConnectionData> connMap; | ||
| 22 | + | ||
| 23 | + public StreamThreadHandler( | ||
| 24 | + LinkedBlockingQueue<WebSocket> streamQueue, | ||
| 25 | + LinkedBlockingQueue<WebSocket> decoderQueue, | ||
| 26 | + ConcurrentHashMap<WebSocket, ConnectionData> connMap) { | ||
| 27 | + this.streamQueue = streamQueue; | ||
| 28 | + this.decoderQueue = decoderQueue; | ||
| 29 | + this.connMap = connMap; | ||
| 30 | + } | ||
| 31 | + | ||
| 32 | + public void run() { | ||
| 33 | + while (true) { | ||
| 34 | + try { | ||
| 35 | + // fetch one websocket from queue | ||
| 36 | + WebSocket conn = (WebSocket) this.streamQueue.take(); | ||
| 37 | + // get the connection data according to websocket | ||
| 38 | + ConnectionData connData = connMap.get(conn); | ||
| 39 | + OnlineStream stream = connData.getStream(); | ||
| 40 | + | ||
| 41 | + // handle received binary data | ||
| 42 | + if (!connData.getQueueSamples().isEmpty()) { | ||
| 43 | + // loop to put all received binary data to stream | ||
| 44 | + while (!connData.getQueueSamples().isEmpty()) { | ||
| 45 | + | ||
| 46 | + float[] samples = connData.getQueueSamples().poll(); | ||
| 47 | + | ||
| 48 | + stream.acceptWaveform(samples); | ||
| 49 | + } | ||
| 50 | + // if data is finished | ||
| 51 | + if (connData.getEof() == true) { | ||
| 52 | + | ||
| 53 | + stream.inputFinished(); | ||
| 54 | + } | ||
| 55 | + // add this websocket to decoder Queue if not in the Queue | ||
| 56 | + if (!decoderQueue.contains(conn)) { | ||
| 57 | + | ||
| 58 | + decoderQueue.put(conn); | ||
| 59 | + } | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + } catch (Exception e) { | ||
| 63 | + e.printStackTrace(); | ||
| 64 | + } | ||
| 65 | + } | ||
| 66 | + } | ||
| 67 | +} |
| 1 | +/* | ||
| 2 | + * // Copyright 2022-2023 by zhaoming | ||
| 3 | + */ | ||
| 4 | + | ||
| 5 | +package com.k2fsa.sherpa.onnx; | ||
| 6 | + | ||
/** Immutable language-model settings: the model value and the scale applied to it. */
public class OnlineLMConfig {
  private final String model;
  private final float scale;

  /**
   * @param model language model value, stored as given
   * @param scale language model scale, stored as given
   */
  public OnlineLMConfig(String model, float scale) {
    this.scale = scale;
    this.model = model;
  }

  /** Returns the scale passed to the constructor. */
  public float getScale() {
    return this.scale;
  }

  /** Returns the model value passed to the constructor. */
  public String getModel() {
    return this.model;
  }
}
| @@ -65,11 +65,14 @@ public class OnlineRecognizer { | @@ -65,11 +65,14 @@ public class OnlineRecognizer { | ||
| 65 | false); | 65 | false); |
| 66 | FeatureConfig featConfig = | 66 | FeatureConfig featConfig = |
| 67 | new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); | 67 | new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); |
| 68 | - OnlineRecognizerConfig rcgCfg = | 68 | + OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); |
| 69 | + | ||
| 70 | + OnlineRecognizerConfig rcgCfg = | ||
| 69 | new OnlineRecognizerConfig( | 71 | new OnlineRecognizerConfig( |
| 70 | featConfig, | 72 | featConfig, |
| 71 | modelCfg, | 73 | modelCfg, |
| 72 | endCfg, | 74 | endCfg, |
| 75 | + onlineLmConfig, | ||
| 73 | Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), | 76 | Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), |
| 74 | proMap.get("decoding_method").trim(), | 77 | proMap.get("decoding_method").trim(), |
| 75 | Integer.parseInt(proMap.get("max_active_paths").trim())); | 78 | Integer.parseInt(proMap.get("max_active_paths").trim())); |
| @@ -107,11 +110,15 @@ public class OnlineRecognizer { | @@ -107,11 +110,15 @@ public class OnlineRecognizer { | ||
| 107 | false); | 110 | false); |
| 108 | FeatureConfig featConfig = | 111 | FeatureConfig featConfig = |
| 109 | new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); | 112 | new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); |
| 110 | - OnlineRecognizerConfig rcgCfg = | 113 | + |
| 114 | + OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); | ||
| 115 | + | ||
| 116 | + OnlineRecognizerConfig rcgCfg = | ||
| 111 | new OnlineRecognizerConfig( | 117 | new OnlineRecognizerConfig( |
| 112 | featConfig, | 118 | featConfig, |
| 113 | modelCfg, | 119 | modelCfg, |
| 114 | endCfg, | 120 | endCfg, |
| 121 | + onlineLmConfig, | ||
| 115 | Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), | 122 | Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), |
| 116 | proMap.get("decoding_method").trim(), | 123 | proMap.get("decoding_method").trim(), |
| 117 | Integer.parseInt(proMap.get("max_active_paths").trim())); | 124 | Integer.parseInt(proMap.get("max_active_paths").trim())); |
| @@ -137,6 +144,8 @@ public class OnlineRecognizer { | @@ -137,6 +144,8 @@ public class OnlineRecognizer { | ||
| 137 | float rule2MinTrailingSilence, | 144 | float rule2MinTrailingSilence, |
| 138 | float rule3MinUtteranceLength, | 145 | float rule3MinUtteranceLength, |
| 139 | String decodingMethod, | 146 | String decodingMethod, |
| 147 | + String lm_model, | ||
| 148 | + float lm_scale, | ||
| 140 | int maxActivePaths) { | 149 | int maxActivePaths) { |
| 141 | this.sampleRate = sampleRate; | 150 | this.sampleRate = sampleRate; |
| 142 | EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); | 151 | EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); |
| @@ -146,14 +155,10 @@ public class OnlineRecognizer { | @@ -146,14 +155,10 @@ public class OnlineRecognizer { | ||
| 146 | OnlineTransducerModelConfig modelCfg = | 155 | OnlineTransducerModelConfig modelCfg = |
| 147 | new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false); | 156 | new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false); |
| 148 | FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); | 157 | FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); |
| 149 | - OnlineRecognizerConfig rcgCfg = | 158 | + OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale); |
| 159 | + OnlineRecognizerConfig rcgCfg = | ||
| 150 | new OnlineRecognizerConfig( | 160 | new OnlineRecognizerConfig( |
| 151 | - featConfig, | ||
| 152 | - modelCfg, | ||
| 153 | - endCfg, | ||
| 154 | - enableEndpointDetection, | ||
| 155 | - decodingMethod, | ||
| 156 | - maxActivePaths); | 161 | + featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths); |
| 157 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 | 162 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 |
| 158 | this.ptr = createOnlineRecognizer(new Object(), rcgCfg); | 163 | this.ptr = createOnlineRecognizer(new Object(), rcgCfg); |
| 159 | } | 164 | } |
| @@ -241,7 +246,7 @@ public class OnlineRecognizer { | @@ -241,7 +246,7 @@ public class OnlineRecognizer { | ||
| 241 | return stream; | 246 | return stream; |
| 242 | } | 247 | } |
| 243 | 248 | ||
| 244 | - public float[] readWavFile(String fileName) { | 249 | + public static float[] readWavFile(String fileName) { |
| 245 | // read data from the filename | 250 | // read data from the filename |
| 246 | Object[] wavdata = readWave(fileName); | 251 | Object[] wavdata = readWave(fileName); |
| 247 | Object data = wavdata[0]; // data[0] is float data, data[1] sample rate | 252 | Object data = wavdata[0]; // data[0] is float data, data[1] sample rate |
| @@ -281,7 +286,7 @@ public class OnlineRecognizer { | @@ -281,7 +286,7 @@ public class OnlineRecognizer { | ||
| 281 | } | 286 | } |
| 282 | // JNI interface libsherpa-onnx-jni.so | 287 | // JNI interface libsherpa-onnx-jni.so |
| 283 | 288 | ||
| 284 | - private native Object[] readWave(String fileName); | 289 | + private static native Object[] readWave(String fileName); // static |
| 285 | 290 | ||
| 286 | private native String getResult(long ptr, long streamPtr); | 291 | private native String getResult(long ptr, long streamPtr); |
| 287 | 292 |
| @@ -8,25 +8,33 @@ public class OnlineRecognizerConfig { | @@ -8,25 +8,33 @@ public class OnlineRecognizerConfig { | ||
| 8 | private final FeatureConfig featConfig; | 8 | private final FeatureConfig featConfig; |
| 9 | private final OnlineTransducerModelConfig modelConfig; | 9 | private final OnlineTransducerModelConfig modelConfig; |
| 10 | private final EndpointConfig endpointConfig; | 10 | private final EndpointConfig endpointConfig; |
| 11 | + private final OnlineLMConfig lmConfig; | ||
| 11 | private final boolean enableEndpoint; | 12 | private final boolean enableEndpoint; |
| 12 | private final String decodingMethod; | 13 | private final String decodingMethod; |
| 13 | private final int maxActivePaths; | 14 | private final int maxActivePaths; |
| 15 | + | ||
| 14 | 16 | ||
/**
 * Bundles all settings of an online recognizer; every argument is stored as given.
 *
 * @param featConfig feature extraction settings
 * @param modelConfig transducer model settings
 * @param endpointConfig endpoint rule settings
 * @param lmConfig language model settings
 * @param enableEndpoint endpoint flag
 * @param decodingMethod decoding method value
 * @param maxActivePaths max active paths value
 */
public OnlineRecognizerConfig(
    FeatureConfig featConfig,
    OnlineTransducerModelConfig modelConfig,
    EndpointConfig endpointConfig,
    OnlineLMConfig lmConfig,
    boolean enableEndpoint,
    String decodingMethod,
    int maxActivePaths) {
  // nested configuration objects
  this.lmConfig = lmConfig;
  this.endpointConfig = endpointConfig;
  this.modelConfig = modelConfig;
  this.featConfig = featConfig;

  // plain values
  this.maxActivePaths = maxActivePaths;
  this.decodingMethod = decodingMethod;
  this.enableEndpoint = enableEndpoint;
}
| 29 | 33 | ||
| 34 | + public OnlineLMConfig getLmConfig() { | ||
| 35 | + return lmConfig; | ||
| 36 | + } | ||
| 37 | + | ||
| 30 | public FeatureConfig getFeatConfig() { | 38 | public FeatureConfig getFeatConfig() { |
| 31 | return featConfig; | 39 | return featConfig; |
| 32 | } | 40 | } |
-
请 注册 或 登录 后发表评论