Committed by
GitHub
Add Java API for hotwords (#319)
* Add Java API * Support WebSocket * Fix Kotlin
正在显示
9 个修改的文件
包含
116 行增加
和
50 行删除
| @@ -53,6 +53,8 @@ data class OnlineRecognizerConfig( | @@ -53,6 +53,8 @@ data class OnlineRecognizerConfig( | ||
| 53 | var enableEndpoint: Boolean = true, | 53 | var enableEndpoint: Boolean = true, |
| 54 | var decodingMethod: String = "greedy_search", | 54 | var decodingMethod: String = "greedy_search", |
| 55 | var maxActivePaths: Int = 4, | 55 | var maxActivePaths: Int = 4, |
| 56 | + var hotwordsFile: String = "", | ||
| 57 | + var hotwordsScore: Float = 1.5f, | ||
| 56 | ) | 58 | ) |
| 57 | 59 | ||
| 58 | class SherpaOnnx( | 60 | class SherpaOnnx( |
| 1 | - | ||
| 2 | ENTRY_POINT = ./ | 1 | ENTRY_POINT = ./ |
| 3 | 2 | ||
| 4 | LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx | 3 | LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx |
| @@ -65,18 +64,22 @@ clean: | @@ -65,18 +64,22 @@ clean: | ||
| 65 | mkdir -p ./lib | 64 | mkdir -p ./lib |
| 66 | 65 | ||
| 67 | runfile: | 66 | runfile: |
| 67 | + java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav | ||
| 68 | 68 | ||
| 69 | - java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile | 69 | +runhotwords: |
| 70 | + java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav | ||
| 70 | 71 | ||
| 71 | runmic: | 72 | runmic: |
| 72 | - | ||
| 73 | java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic | 73 | java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic |
| 74 | 74 | ||
| 75 | runsrv: | 75 | runsrv: |
| 76 | - java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer ../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg | 76 | + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer $(shell pwd)/../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg |
| 77 | 77 | ||
| 78 | runclient: | 78 | runclient: |
| 79 | - java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient ../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32 | 79 | + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32 |
| 80 | + | ||
| 81 | +runclienthotwords: | ||
| 82 | + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./hotwords.wav 32 | ||
| 80 | 83 | ||
| 81 | buildlib: $(LIB_FILES:.java=.class) | 84 | buildlib: $(LIB_FILES:.java=.class) |
| 82 | 85 |
| @@ -12,6 +12,8 @@ num_threads=4 | @@ -12,6 +12,8 @@ num_threads=4 | ||
| 12 | enable_endpoint_detection=true | 12 | enable_endpoint_detection=true |
| 13 | decoding_method=modified_beam_search | 13 | decoding_method=modified_beam_search |
| 14 | max_active_paths=4 | 14 | max_active_paths=4 |
| 15 | +hotwords_file= | ||
| 16 | +hotwords_score=1.5 | ||
| 15 | lm_model= | 17 | lm_model= |
| 16 | lm_scale=0.5 | 18 | lm_scale=0.5 |
| 17 | model_type=zipformer | 19 | model_type=zipformer |
| @@ -36,6 +36,8 @@ if [ ! -d $repo ];then | @@ -36,6 +36,8 @@ if [ ! -d $repo ];then | ||
| 36 | git lfs pull --include "*.onnx" | 36 | git lfs pull --include "*.onnx" |
| 37 | ls -lh *.onnx | 37 | ls -lh *.onnx |
| 38 | popd | 38 | popd |
| 39 | + ln -s $repo/test_wavs/0.wav hotwords.wav | ||
| 40 | + | ||
| 39 | fi | 41 | fi |
| 40 | 42 | ||
| 41 | log $(pwd) | 43 | log $(pwd) |
| @@ -64,3 +66,9 @@ cd ../java-api-examples | @@ -64,3 +66,9 @@ cd ../java-api-examples | ||
| 64 | make all | 66 | make all |
| 65 | 67 | ||
| 66 | make runfile | 68 | make runfile |
| 69 | + | ||
| 70 | +echo "礼 拜 二" > hotwords.txt | ||
| 71 | + | ||
| 72 | +sed -i 's/hotwords_file=/hotwords_file=hotwords.txt/g' modeltest.cfg | ||
| 73 | + | ||
| 74 | +make runhotwords |
| @@ -49,6 +49,8 @@ public class DecodeFile { | @@ -49,6 +49,8 @@ public class DecodeFile { | ||
| 49 | float rule3MinUtteranceLength = 20F; | 49 | float rule3MinUtteranceLength = 20F; |
| 50 | String decodingMethod = "greedy_search"; | 50 | String decodingMethod = "greedy_search"; |
| 51 | int maxActivePaths = 4; | 51 | int maxActivePaths = 4; |
| 52 | + String hotwordsFile = ""; | ||
| 53 | + float hotwordsScore = 1.5F; | ||
| 52 | String lm_model = ""; | 54 | String lm_model = ""; |
| 53 | float lm_scale = 0.5F; | 55 | float lm_scale = 0.5F; |
| 54 | String modelType = "zipformer"; | 56 | String modelType = "zipformer"; |
| @@ -69,6 +71,8 @@ public class DecodeFile { | @@ -69,6 +71,8 @@ public class DecodeFile { | ||
| 69 | lm_model, | 71 | lm_model, |
| 70 | lm_scale, | 72 | lm_scale, |
| 71 | maxActivePaths, | 73 | maxActivePaths, |
| 74 | + hotwordsFile, | ||
| 75 | + hotwordsScore, | ||
| 72 | modelType); | 76 | modelType); |
| 73 | streamObj = rcgOjb.createStream(); | 77 | streamObj = rcgOjb.createStream(); |
| 74 | } catch (Exception e) { | 78 | } catch (Exception e) { |
| @@ -158,7 +162,7 @@ public class DecodeFile { | @@ -158,7 +162,7 @@ public class DecodeFile { | ||
| 158 | try { | 162 | try { |
| 159 | String appDir = System.getProperty("user.dir"); | 163 | String appDir = System.getProperty("user.dir"); |
| 160 | System.out.println("appdir=" + appDir); | 164 | System.out.println("appdir=" + appDir); |
| 161 | - String fileName = appDir + "/test.wav"; | 165 | + String fileName = appDir + "/" + args[0]; |
| 162 | String cfgPath = appDir + "/modeltest.cfg"; | 166 | String cfgPath = appDir + "/modeltest.cfg"; |
| 163 | String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so"; | 167 | String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so"; |
| 164 | OnlineRecognizer.setSoPath(soPath); | 168 | OnlineRecognizer.setSoPath(soPath); |
| @@ -140,8 +140,6 @@ public class AsrWebsocketServer extends WebSocketServer { | @@ -140,8 +140,6 @@ public class AsrWebsocketServer extends WebSocketServer { | ||
| 140 | } | 140 | } |
| 141 | } | 141 | } |
| 142 | 142 | ||
| 143 | - | ||
| 144 | - | ||
| 145 | public boolean streamQueueFind(WebSocket conn) { | 143 | public boolean streamQueueFind(WebSocket conn) { |
| 146 | return streamQueue.contains(conn); | 144 | return streamQueue.contains(conn); |
| 147 | } | 145 | } |
| @@ -151,16 +149,16 @@ public class AsrWebsocketServer extends WebSocketServer { | @@ -151,16 +149,16 @@ public class AsrWebsocketServer extends WebSocketServer { | ||
| 151 | 149 | ||
| 152 | rcgOjb = new OnlineRecognizer(cfgPath); | 150 | rcgOjb = new OnlineRecognizer(cfgPath); |
| 153 | // size of stream thread pool | 151 | // size of stream thread pool |
| 154 | - int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num")); | 152 | + int streamThreadNum = Integer.valueOf(cfgMap.getOrDefault("stream_thread_num", "16")); |
| 155 | // size of decoder thread pool | 153 | // size of decoder thread pool |
| 156 | - int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num")); | 154 | + int decoderThreadNum = Integer.valueOf(cfgMap.getOrDefault("decoder_thread_num", "16")); |
| 157 | 155 | ||
| 158 | // time(ms) idle for decoder thread when no job | 156 | // time(ms) idle for decoder thread when no job |
| 159 | - int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle")); | 157 | + int decoderTimeIdle = Integer.valueOf(cfgMap.getOrDefault("decoder_time_idle", "200")); |
| 160 | // size of streams for parallel decoding | 158 | // size of streams for parallel decoding |
| 161 | - int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num")); | 159 | + int parallelDecoderNum = Integer.valueOf(cfgMap.getOrDefault("parallel_decoder_num", "16")); |
| 162 | // time(ms) out for connection data | 160 | // time(ms) out for connection data |
| 163 | - int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out")); | 161 | + int deocderTimeOut = Integer.valueOf(cfgMap.getOrDefault("deocder_time_out", "30000")); |
| 164 | 162 | ||
| 165 | // create stream threads | 163 | // create stream threads |
| 166 | for (int i = 0; i < streamThreadNum; i++) { | 164 | for (int i = 0; i < streamThreadNum; i++) { |
| @@ -222,9 +220,9 @@ public class AsrWebsocketServer extends WebSocketServer { | @@ -222,9 +220,9 @@ public class AsrWebsocketServer extends WebSocketServer { | ||
| 222 | OnlineRecognizer.setSoPath(soPath); | 220 | OnlineRecognizer.setSoPath(soPath); |
| 223 | logger.info("readProperties"); | 221 | logger.info("readProperties"); |
| 224 | Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath); | 222 | Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath); |
| 225 | - int port = Integer.valueOf(cfgMap.get("port")); | 223 | + int port = Integer.valueOf(cfgMap.getOrDefault("port", "8890")); |
| 226 | 224 | ||
| 227 | - int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num")); | 225 | + int connectionThreadNum = Integer.valueOf(cfgMap.getOrDefault("connection_thread_num", "16")); |
| 228 | AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum); | 226 | AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum); |
| 229 | logger.info("initModelWithCfg"); | 227 | logger.info("initModelWithCfg"); |
| 230 | s.initModelWithCfg(cfgMap, cfgPath); | 228 | s.initModelWithCfg(cfgMap, cfgPath); |
| @@ -44,38 +44,48 @@ public class OnlineRecognizer { | @@ -44,38 +44,48 @@ public class OnlineRecognizer { | ||
| 44 | public OnlineRecognizer(String modelCfgPath) { | 44 | public OnlineRecognizer(String modelCfgPath) { |
| 45 | Map<String, String> proMap = this.readProperties(modelCfgPath); | 45 | Map<String, String> proMap = this.readProperties(modelCfgPath); |
| 46 | try { | 46 | try { |
| 47 | - int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim()); | 47 | + int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); |
| 48 | this.sampleRate = sampleRate; | 48 | this.sampleRate = sampleRate; |
| 49 | EndpointRule rule1 = | 49 | EndpointRule rule1 = |
| 50 | new EndpointRule( | 50 | new EndpointRule( |
| 51 | - false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F); | 51 | + false, |
| 52 | + Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), | ||
| 53 | + 0.0F); | ||
| 52 | EndpointRule rule2 = | 54 | EndpointRule rule2 = |
| 53 | new EndpointRule( | 55 | new EndpointRule( |
| 54 | - true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F); | 56 | + true, |
| 57 | + Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), | ||
| 58 | + 0.0F); | ||
| 55 | EndpointRule rule3 = | 59 | EndpointRule rule3 = |
| 56 | new EndpointRule( | 60 | new EndpointRule( |
| 57 | - false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); | 61 | + false, |
| 62 | + 0.0F, | ||
| 63 | + Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); | ||
| 58 | EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); | 64 | EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); |
| 59 | 65 | ||
| 60 | - OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(proMap.get("encoder").trim(), proMap.get("decoder").trim()); | 66 | + OnlineParaformerModelConfig modelParaCfg = |
| 67 | + new OnlineParaformerModelConfig( | ||
| 68 | + proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); | ||
| 61 | OnlineTransducerModelConfig modelTranCfg = | 69 | OnlineTransducerModelConfig modelTranCfg = |
| 62 | new OnlineTransducerModelConfig( | 70 | new OnlineTransducerModelConfig( |
| 63 | - proMap.get("encoder").trim(), | ||
| 64 | - proMap.get("decoder").trim(), | ||
| 65 | - proMap.get("joiner").trim()); | 71 | + proMap.getOrDefault("encoder", "").trim(), |
| 72 | + proMap.getOrDefault("decoder", "").trim(), | ||
| 73 | + proMap.getOrDefault("joiner", "").trim()); | ||
| 66 | OnlineModelConfig modelCfg = | 74 | OnlineModelConfig modelCfg = |
| 67 | new OnlineModelConfig( | 75 | new OnlineModelConfig( |
| 68 | - proMap.get("tokens").trim(), | ||
| 69 | - Integer.parseInt(proMap.get("num_threads").trim()), | 76 | + proMap.getOrDefault("tokens", "").trim(), |
| 77 | + Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), | ||
| 70 | false, | 78 | false, |
| 71 | - proMap.get("model_type").trim(), | 79 | + proMap.getOrDefault("model_type", "zipformer").trim(), |
| 72 | modelParaCfg, | 80 | modelParaCfg, |
| 73 | modelTranCfg); | 81 | modelTranCfg); |
| 74 | FeatureConfig featConfig = | 82 | FeatureConfig featConfig = |
| 75 | - new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); | 83 | + new FeatureConfig( |
| 84 | + sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); | ||
| 76 | OnlineLMConfig onlineLmConfig = | 85 | OnlineLMConfig onlineLmConfig = |
| 77 | new OnlineLMConfig( | 86 | new OnlineLMConfig( |
| 78 | - proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); | 87 | + proMap.getOrDefault("lm_model", "").trim(), |
| 88 | + Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); | ||
| 79 | 89 | ||
| 80 | OnlineRecognizerConfig rcgCfg = | 90 | OnlineRecognizerConfig rcgCfg = |
| 81 | new OnlineRecognizerConfig( | 91 | new OnlineRecognizerConfig( |
| @@ -83,9 +93,11 @@ public class OnlineRecognizer { | @@ -83,9 +93,11 @@ public class OnlineRecognizer { | ||
| 83 | modelCfg, | 93 | modelCfg, |
| 84 | endCfg, | 94 | endCfg, |
| 85 | onlineLmConfig, | 95 | onlineLmConfig, |
| 86 | - Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), | ||
| 87 | - proMap.get("decoding_method").trim(), | ||
| 88 | - Integer.parseInt(proMap.get("max_active_paths").trim())); | 96 | + Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), |
| 97 | + proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), | ||
| 98 | + Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), | ||
| 99 | + proMap.getOrDefault("hotwords_file", "").trim(), | ||
| 100 | + Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); | ||
| 89 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 | 101 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 |
| 90 | this.ptr = createOnlineRecognizer(new Object(), rcgCfg); | 102 | this.ptr = createOnlineRecognizer(new Object(), rcgCfg); |
| 91 | 103 | ||
| @@ -98,41 +110,49 @@ public class OnlineRecognizer { | @@ -98,41 +110,49 @@ public class OnlineRecognizer { | ||
| 98 | public OnlineRecognizer(Object assetManager, String modelCfgPath) { | 110 | public OnlineRecognizer(Object assetManager, String modelCfgPath) { |
| 99 | Map<String, String> proMap = this.readProperties(modelCfgPath); | 111 | Map<String, String> proMap = this.readProperties(modelCfgPath); |
| 100 | try { | 112 | try { |
| 101 | - int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim()); | 113 | + int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim()); |
| 102 | this.sampleRate = sampleRate; | 114 | this.sampleRate = sampleRate; |
| 103 | EndpointRule rule1 = | 115 | EndpointRule rule1 = |
| 104 | new EndpointRule( | 116 | new EndpointRule( |
| 105 | - false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F); | 117 | + false, |
| 118 | + Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()), | ||
| 119 | + 0.0F); | ||
| 106 | EndpointRule rule2 = | 120 | EndpointRule rule2 = |
| 107 | new EndpointRule( | 121 | new EndpointRule( |
| 108 | - true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F); | 122 | + true, |
| 123 | + Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()), | ||
| 124 | + 0.0F); | ||
| 109 | EndpointRule rule3 = | 125 | EndpointRule rule3 = |
| 110 | new EndpointRule( | 126 | new EndpointRule( |
| 111 | - false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); | 127 | + false, |
| 128 | + 0.0F, | ||
| 129 | + Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim())); | ||
| 112 | EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); | 130 | EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); |
| 113 | OnlineParaformerModelConfig modelParaCfg = | 131 | OnlineParaformerModelConfig modelParaCfg = |
| 114 | new OnlineParaformerModelConfig( | 132 | new OnlineParaformerModelConfig( |
| 115 | - proMap.get("encoder").trim(), proMap.get("decoder").trim()); | 133 | + proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim()); |
| 116 | OnlineTransducerModelConfig modelTranCfg = | 134 | OnlineTransducerModelConfig modelTranCfg = |
| 117 | new OnlineTransducerModelConfig( | 135 | new OnlineTransducerModelConfig( |
| 118 | - proMap.get("encoder").trim(), | ||
| 119 | - proMap.get("decoder").trim(), | ||
| 120 | - proMap.get("joiner").trim()); | 136 | + proMap.getOrDefault("encoder", "").trim(), |
| 137 | + proMap.getOrDefault("decoder", "").trim(), | ||
| 138 | + proMap.getOrDefault("joiner", "").trim()); | ||
| 121 | 139 | ||
| 122 | OnlineModelConfig modelCfg = | 140 | OnlineModelConfig modelCfg = |
| 123 | new OnlineModelConfig( | 141 | new OnlineModelConfig( |
| 124 | - proMap.get("tokens").trim(), | ||
| 125 | - Integer.parseInt(proMap.get("num_threads").trim()), | 142 | + proMap.getOrDefault("tokens", "").trim(), |
| 143 | + Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()), | ||
| 126 | false, | 144 | false, |
| 127 | - proMap.get("model_type").trim(), | 145 | + proMap.getOrDefault("model_type", "zipformer").trim(), |
| 128 | modelParaCfg, | 146 | modelParaCfg, |
| 129 | modelTranCfg); | 147 | modelTranCfg); |
| 130 | FeatureConfig featConfig = | 148 | FeatureConfig featConfig = |
| 131 | - new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); | 149 | + new FeatureConfig( |
| 150 | + sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim())); | ||
| 132 | 151 | ||
| 133 | OnlineLMConfig onlineLmConfig = | 152 | OnlineLMConfig onlineLmConfig = |
| 134 | new OnlineLMConfig( | 153 | new OnlineLMConfig( |
| 135 | - proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); | 154 | + proMap.getOrDefault("lm_model", "").trim(), |
| 155 | + Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim())); | ||
| 136 | 156 | ||
| 137 | OnlineRecognizerConfig rcgCfg = | 157 | OnlineRecognizerConfig rcgCfg = |
| 138 | new OnlineRecognizerConfig( | 158 | new OnlineRecognizerConfig( |
| @@ -140,9 +160,11 @@ public class OnlineRecognizer { | @@ -140,9 +160,11 @@ public class OnlineRecognizer { | ||
| 140 | modelCfg, | 160 | modelCfg, |
| 141 | endCfg, | 161 | endCfg, |
| 142 | onlineLmConfig, | 162 | onlineLmConfig, |
| 143 | - Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), | ||
| 144 | - proMap.get("decoding_method").trim(), | ||
| 145 | - Integer.parseInt(proMap.get("max_active_paths").trim())); | 163 | + Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()), |
| 164 | + proMap.getOrDefault("decoding_method", "modified_beam_search").trim(), | ||
| 165 | + Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()), | ||
| 166 | + proMap.getOrDefault("hotwords_file", "").trim(), | ||
| 167 | + Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim())); | ||
| 146 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 | 168 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 |
| 147 | this.ptr = createOnlineRecognizer(assetManager, rcgCfg); | 169 | this.ptr = createOnlineRecognizer(assetManager, rcgCfg); |
| 148 | 170 | ||
| @@ -168,6 +190,8 @@ public class OnlineRecognizer { | @@ -168,6 +190,8 @@ public class OnlineRecognizer { | ||
| 168 | String lm_model, | 190 | String lm_model, |
| 169 | float lm_scale, | 191 | float lm_scale, |
| 170 | int maxActivePaths, | 192 | int maxActivePaths, |
| 193 | + String hotwordsFile, | ||
| 194 | + float hotwordsScore, | ||
| 171 | String modelType) { | 195 | String modelType) { |
| 172 | this.sampleRate = sampleRate; | 196 | this.sampleRate = sampleRate; |
| 173 | EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); | 197 | EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); |
| @@ -189,7 +213,9 @@ public class OnlineRecognizer { | @@ -189,7 +213,9 @@ public class OnlineRecognizer { | ||
| 189 | onlineLmConfig, | 213 | onlineLmConfig, |
| 190 | enableEndpointDetection, | 214 | enableEndpointDetection, |
| 191 | decodingMethod, | 215 | decodingMethod, |
| 192 | - maxActivePaths); | 216 | + maxActivePaths, |
| 217 | + hotwordsFile, | ||
| 218 | + hotwordsScore); | ||
| 193 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 | 219 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 |
| 194 | this.ptr = createOnlineRecognizer(new Object(), rcgCfg); | 220 | this.ptr = createOnlineRecognizer(new Object(), rcgCfg); |
| 195 | } | 221 | } |
| @@ -211,7 +237,6 @@ public class OnlineRecognizer { | @@ -211,7 +237,6 @@ public class OnlineRecognizer { | ||
| 211 | String key = (String) en.nextElement(); | 237 | String key = (String) en.nextElement(); |
| 212 | String Property = props.getProperty(key); | 238 | String Property = props.getProperty(key); |
| 213 | proMap.put(key, Property); | 239 | proMap.put(key, Property); |
| 214 | - // System.out.println(key+"="+Property); | ||
| 215 | } | 240 | } |
| 216 | 241 | ||
| 217 | } catch (Exception e) { | 242 | } catch (Exception e) { |
| @@ -12,6 +12,8 @@ public class OnlineRecognizerConfig { | @@ -12,6 +12,8 @@ public class OnlineRecognizerConfig { | ||
| 12 | private final boolean enableEndpoint; | 12 | private final boolean enableEndpoint; |
| 13 | private final String decodingMethod; | 13 | private final String decodingMethod; |
| 14 | private final int maxActivePaths; | 14 | private final int maxActivePaths; |
| 15 | + private final String hotwordsFile; | ||
| 16 | + private final float hotwordsScore; | ||
| 15 | 17 | ||
| 16 | public OnlineRecognizerConfig( | 18 | public OnlineRecognizerConfig( |
| 17 | FeatureConfig featConfig, | 19 | FeatureConfig featConfig, |
| @@ -20,7 +22,9 @@ public class OnlineRecognizerConfig { | @@ -20,7 +22,9 @@ public class OnlineRecognizerConfig { | ||
| 20 | OnlineLMConfig lmConfig, | 22 | OnlineLMConfig lmConfig, |
| 21 | boolean enableEndpoint, | 23 | boolean enableEndpoint, |
| 22 | String decodingMethod, | 24 | String decodingMethod, |
| 23 | - int maxActivePaths) { | 25 | + int maxActivePaths, |
| 26 | + String hotwordsFile, | ||
| 27 | + float hotwordsScore) { | ||
| 24 | this.featConfig = featConfig; | 28 | this.featConfig = featConfig; |
| 25 | this.modelConfig = modelConfig; | 29 | this.modelConfig = modelConfig; |
| 26 | this.endpointConfig = endpointConfig; | 30 | this.endpointConfig = endpointConfig; |
| @@ -28,6 +32,8 @@ public class OnlineRecognizerConfig { | @@ -28,6 +32,8 @@ public class OnlineRecognizerConfig { | ||
| 28 | this.enableEndpoint = enableEndpoint; | 32 | this.enableEndpoint = enableEndpoint; |
| 29 | this.decodingMethod = decodingMethod; | 33 | this.decodingMethod = decodingMethod; |
| 30 | this.maxActivePaths = maxActivePaths; | 34 | this.maxActivePaths = maxActivePaths; |
| 35 | + this.hotwordsFile = hotwordsFile; | ||
| 36 | + this.hotwordsScore = hotwordsScore; | ||
| 31 | } | 37 | } |
| 32 | 38 | ||
| 33 | public OnlineLMConfig getLmConfig() { | 39 | public OnlineLMConfig getLmConfig() { |
| @@ -125,6 +125,15 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | @@ -125,6 +125,15 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | ||
| 125 | fid = env->GetFieldID(cls, "maxActivePaths", "I"); | 125 | fid = env->GetFieldID(cls, "maxActivePaths", "I"); |
| 126 | ans.max_active_paths = env->GetIntField(config, fid); | 126 | ans.max_active_paths = env->GetIntField(config, fid); |
| 127 | 127 | ||
| 128 | + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); | ||
| 129 | + s = (jstring)env->GetObjectField(config, fid); | ||
| 130 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 131 | + ans.hotwords_file = p; | ||
| 132 | + env->ReleaseStringUTFChars(s, p); | ||
| 133 | + | ||
| 134 | + fid = env->GetFieldID(cls, "hotwordsScore", "F"); | ||
| 135 | + ans.hotwords_score = env->GetFloatField(config, fid); | ||
| 136 | + | ||
| 128 | //---------- feat config ---------- | 137 | //---------- feat config ---------- |
| 129 | fid = env->GetFieldID(cls, "featConfig", | 138 | fid = env->GetFieldID(cls, "featConfig", |
| 130 | "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); | 139 | "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); |
| @@ -293,6 +302,15 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { | @@ -293,6 +302,15 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { | ||
| 293 | fid = env->GetFieldID(cls, "maxActivePaths", "I"); | 302 | fid = env->GetFieldID(cls, "maxActivePaths", "I"); |
| 294 | ans.max_active_paths = env->GetIntField(config, fid); | 303 | ans.max_active_paths = env->GetIntField(config, fid); |
| 295 | 304 | ||
| 305 | + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); | ||
| 306 | + s = (jstring)env->GetObjectField(config, fid); | ||
| 307 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 308 | + ans.hotwords_file = p; | ||
| 309 | + env->ReleaseStringUTFChars(s, p); | ||
| 310 | + | ||
| 311 | + fid = env->GetFieldID(cls, "hotwordsScore", "F"); | ||
| 312 | + ans.hotwords_score = env->GetFloatField(config, fid); | ||
| 313 | + | ||
| 296 | //---------- feat config ---------- | 314 | //---------- feat config ---------- |
| 297 | fid = env->GetFieldID(cls, "featConfig", | 315 | fid = env->GetFieldID(cls, "featConfig", |
| 298 | "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); | 316 | "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); |
-
请 注册 或 登录 后发表评论