Wei Kang
Committed by GitHub

Add java api for hotwords (#319)

* Add java api

* support websocket

* Fix kotlin
@@ -53,6 +53,8 @@ data class OnlineRecognizerConfig( @@ -53,6 +53,8 @@ data class OnlineRecognizerConfig(
53 var enableEndpoint: Boolean = true, 53 var enableEndpoint: Boolean = true,
54 var decodingMethod: String = "greedy_search", 54 var decodingMethod: String = "greedy_search",
55 var maxActivePaths: Int = 4, 55 var maxActivePaths: Int = 4,
  56 + var hotwordsFile: String = "",
  57 + var hotwordsScore: Float = 1.5f,
56 ) 58 )
57 59
58 class SherpaOnnx( 60 class SherpaOnnx(
1 -  
2 ENTRY_POINT = ./ 1 ENTRY_POINT = ./
3 2
4 LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx 3 LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx
@@ -65,18 +64,22 @@ clean: @@ -65,18 +64,22 @@ clean:
65 mkdir -p ./lib 64 mkdir -p ./lib
66 65
67 runfile: 66 runfile:
  67 + java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav
68 68
69 - java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile 69 +runhotwords:
  70 + java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav
70 71
71 runmic: 72 runmic:
72 -  
73 java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic 73 java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic
74 74
75 runsrv: 75 runsrv:
76 - java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer ../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg 76 + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer $(shell pwd)/../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg
77 77
78 runclient: 78 runclient:
79 - java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient ../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32 79 + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
  80 +
  81 +runclienthotwords:
  82 + java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./hotwords.wav 32
80 83
81 buildlib: $(LIB_FILES:.java=.class) 84 buildlib: $(LIB_FILES:.java=.class)
82 85
@@ -12,6 +12,8 @@ num_threads=4 @@ -12,6 +12,8 @@ num_threads=4
12 enable_endpoint_detection=true 12 enable_endpoint_detection=true
13 decoding_method=modified_beam_search 13 decoding_method=modified_beam_search
14 max_active_paths=4 14 max_active_paths=4
  15 +hotwords_file=
  16 +hotwords_score=1.5
15 lm_model= 17 lm_model=
16 lm_scale=0.5 18 lm_scale=0.5
17 model_type=zipformer 19 model_type=zipformer
@@ -36,6 +36,8 @@ if [ ! -d $repo ];then @@ -36,6 +36,8 @@ if [ ! -d $repo ];then
36 git lfs pull --include "*.onnx" 36 git lfs pull --include "*.onnx"
37 ls -lh *.onnx 37 ls -lh *.onnx
38 popd 38 popd
  39 + ln -s $repo/test_wavs/0.wav hotwords.wav
  40 +
39 fi 41 fi
40 42
41 log $(pwd) 43 log $(pwd)
@@ -64,3 +66,9 @@ cd ../java-api-examples @@ -64,3 +66,9 @@ cd ../java-api-examples
64 make all 66 make all
65 67
66 make runfile 68 make runfile
  69 +
  70 +echo "礼 拜 二" > hotwords.txt
  71 +
  72 +sed -i 's/hotwords_file=/hotwords_file=hotwords.txt/g' modeltest.cfg
  73 +
  74 +make runhotwords
@@ -49,6 +49,8 @@ public class DecodeFile { @@ -49,6 +49,8 @@ public class DecodeFile {
49 float rule3MinUtteranceLength = 20F; 49 float rule3MinUtteranceLength = 20F;
50 String decodingMethod = "greedy_search"; 50 String decodingMethod = "greedy_search";
51 int maxActivePaths = 4; 51 int maxActivePaths = 4;
  52 + String hotwordsFile = "";
  53 + float hotwordsScore = 1.5F;
52 String lm_model = ""; 54 String lm_model = "";
53 float lm_scale = 0.5F; 55 float lm_scale = 0.5F;
54 String modelType = "zipformer"; 56 String modelType = "zipformer";
@@ -69,6 +71,8 @@ public class DecodeFile { @@ -69,6 +71,8 @@ public class DecodeFile {
69 lm_model, 71 lm_model,
70 lm_scale, 72 lm_scale,
71 maxActivePaths, 73 maxActivePaths,
  74 + hotwordsFile,
  75 + hotwordsScore,
72 modelType); 76 modelType);
73 streamObj = rcgOjb.createStream(); 77 streamObj = rcgOjb.createStream();
74 } catch (Exception e) { 78 } catch (Exception e) {
@@ -158,7 +162,7 @@ public class DecodeFile { @@ -158,7 +162,7 @@ public class DecodeFile {
158 try { 162 try {
159 String appDir = System.getProperty("user.dir"); 163 String appDir = System.getProperty("user.dir");
160 System.out.println("appdir=" + appDir); 164 System.out.println("appdir=" + appDir);
161 - String fileName = appDir + "/test.wav"; 165 + String fileName = appDir + "/" + args[0];
162 String cfgPath = appDir + "/modeltest.cfg"; 166 String cfgPath = appDir + "/modeltest.cfg";
163 String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so"; 167 String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so";
164 OnlineRecognizer.setSoPath(soPath); 168 OnlineRecognizer.setSoPath(soPath);
@@ -140,8 +140,6 @@ public class AsrWebsocketServer extends WebSocketServer { @@ -140,8 +140,6 @@ public class AsrWebsocketServer extends WebSocketServer {
140 } 140 }
141 } 141 }
142 142
143 -  
144 -  
145 public boolean streamQueueFind(WebSocket conn) { 143 public boolean streamQueueFind(WebSocket conn) {
146 return streamQueue.contains(conn); 144 return streamQueue.contains(conn);
147 } 145 }
@@ -151,16 +149,16 @@ public class AsrWebsocketServer extends WebSocketServer { @@ -151,16 +149,16 @@ public class AsrWebsocketServer extends WebSocketServer {
151 149
152 rcgOjb = new OnlineRecognizer(cfgPath); 150 rcgOjb = new OnlineRecognizer(cfgPath);
153 // size of stream thread pool 151 // size of stream thread pool
154 - int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num")); 152 + int streamThreadNum = Integer.valueOf(cfgMap.getOrDefault("stream_thread_num", "16"));
155 // size of decoder thread pool 153 // size of decoder thread pool
156 - int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num")); 154 + int decoderThreadNum = Integer.valueOf(cfgMap.getOrDefault("decoder_thread_num", "16"));
157 155
158 // time(ms) idle for decoder thread when no job 156 // time(ms) idle for decoder thread when no job
159 - int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle")); 157 + int decoderTimeIdle = Integer.valueOf(cfgMap.getOrDefault("decoder_time_idle", "200"));
160 // size of streams for parallel decoding 158 // size of streams for parallel decoding
161 - int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num")); 159 + int parallelDecoderNum = Integer.valueOf(cfgMap.getOrDefault("parallel_decoder_num", "16"));
162 // time(ms) out for connection data 160 // time(ms) out for connection data
163 - int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out")); 161 + int deocderTimeOut = Integer.valueOf(cfgMap.getOrDefault("deocder_time_out", "30000"));
164 162
165 // create stream threads 163 // create stream threads
166 for (int i = 0; i < streamThreadNum; i++) { 164 for (int i = 0; i < streamThreadNum; i++) {
@@ -218,13 +216,13 @@ public class AsrWebsocketServer extends WebSocketServer { @@ -218,13 +216,13 @@ public class AsrWebsocketServer extends WebSocketServer {
218 216
219 String soPath = args[0]; 217 String soPath = args[0];
220 String cfgPath = args[1]; 218 String cfgPath = args[1];
221 - 219 +
222 OnlineRecognizer.setSoPath(soPath); 220 OnlineRecognizer.setSoPath(soPath);
223 logger.info("readProperties"); 221 logger.info("readProperties");
224 Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath); 222 Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
225 - int port = Integer.valueOf(cfgMap.get("port")); 223 + int port = Integer.valueOf(cfgMap.getOrDefault("port", "8890"));
226 224
227 - int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num")); 225 + int connectionThreadNum = Integer.valueOf(cfgMap.getOrDefault("connection_thread_num", "16"));
228 AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum); 226 AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
229 logger.info("initModelWithCfg"); 227 logger.info("initModelWithCfg");
230 s.initModelWithCfg(cfgMap, cfgPath); 228 s.initModelWithCfg(cfgMap, cfgPath);
@@ -44,38 +44,48 @@ public class OnlineRecognizer { @@ -44,38 +44,48 @@ public class OnlineRecognizer {
44 public OnlineRecognizer(String modelCfgPath) { 44 public OnlineRecognizer(String modelCfgPath) {
45 Map<String, String> proMap = this.readProperties(modelCfgPath); 45 Map<String, String> proMap = this.readProperties(modelCfgPath);
46 try { 46 try {
47 - int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim()); 47 + int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
48 this.sampleRate = sampleRate; 48 this.sampleRate = sampleRate;
49 EndpointRule rule1 = 49 EndpointRule rule1 =
50 new EndpointRule( 50 new EndpointRule(
51 - false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F); 51 + false,
  52 + Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
  53 + 0.0F);
52 EndpointRule rule2 = 54 EndpointRule rule2 =
53 new EndpointRule( 55 new EndpointRule(
54 - true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F); 56 + true,
  57 + Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
  58 + 0.0F);
55 EndpointRule rule3 = 59 EndpointRule rule3 =
56 new EndpointRule( 60 new EndpointRule(
57 - false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); 61 + false,
  62 + 0.0F,
  63 + Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
58 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); 64 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
59 65
60 - OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(proMap.get("encoder").trim(), proMap.get("decoder").trim()); 66 + OnlineParaformerModelConfig modelParaCfg =
  67 + new OnlineParaformerModelConfig(
  68 + proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
61 OnlineTransducerModelConfig modelTranCfg = 69 OnlineTransducerModelConfig modelTranCfg =
62 new OnlineTransducerModelConfig( 70 new OnlineTransducerModelConfig(
63 - proMap.get("encoder").trim(),  
64 - proMap.get("decoder").trim(),  
65 - proMap.get("joiner").trim()); 71 + proMap.getOrDefault("encoder", "").trim(),
  72 + proMap.getOrDefault("decoder", "").trim(),
  73 + proMap.getOrDefault("joiner", "").trim());
66 OnlineModelConfig modelCfg = 74 OnlineModelConfig modelCfg =
67 new OnlineModelConfig( 75 new OnlineModelConfig(
68 - proMap.get("tokens").trim(),  
69 - Integer.parseInt(proMap.get("num_threads").trim()), 76 + proMap.getOrDefault("tokens", "").trim(),
  77 + Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
70 false, 78 false,
71 - proMap.get("model_type").trim(), 79 + proMap.getOrDefault("model_type", "zipformer").trim(),
72 modelParaCfg, 80 modelParaCfg,
73 modelTranCfg); 81 modelTranCfg);
74 FeatureConfig featConfig = 82 FeatureConfig featConfig =
75 - new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 83 + new FeatureConfig(
  84 + sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
76 OnlineLMConfig onlineLmConfig = 85 OnlineLMConfig onlineLmConfig =
77 new OnlineLMConfig( 86 new OnlineLMConfig(
78 - proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); 87 + proMap.getOrDefault("lm_model", "").trim(),
  88 + Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
79 89
80 OnlineRecognizerConfig rcgCfg = 90 OnlineRecognizerConfig rcgCfg =
81 new OnlineRecognizerConfig( 91 new OnlineRecognizerConfig(
@@ -83,9 +93,11 @@ public class OnlineRecognizer { @@ -83,9 +93,11 @@ public class OnlineRecognizer {
83 modelCfg, 93 modelCfg,
84 endCfg, 94 endCfg,
85 onlineLmConfig, 95 onlineLmConfig,
86 - Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),  
87 - proMap.get("decoding_method").trim(),  
88 - Integer.parseInt(proMap.get("max_active_paths").trim())); 96 + Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
  97 + proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
  98 + Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
  99 + proMap.getOrDefault("hotwords_file", "").trim(),
  100 + Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
89 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 101 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
90 this.ptr = createOnlineRecognizer(new Object(), rcgCfg); 102 this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
91 103
@@ -98,41 +110,49 @@ public class OnlineRecognizer { @@ -98,41 +110,49 @@ public class OnlineRecognizer {
98 public OnlineRecognizer(Object assetManager, String modelCfgPath) { 110 public OnlineRecognizer(Object assetManager, String modelCfgPath) {
99 Map<String, String> proMap = this.readProperties(modelCfgPath); 111 Map<String, String> proMap = this.readProperties(modelCfgPath);
100 try { 112 try {
101 - int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim()); 113 + int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
102 this.sampleRate = sampleRate; 114 this.sampleRate = sampleRate;
103 EndpointRule rule1 = 115 EndpointRule rule1 =
104 new EndpointRule( 116 new EndpointRule(
105 - false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F); 117 + false,
  118 + Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
  119 + 0.0F);
106 EndpointRule rule2 = 120 EndpointRule rule2 =
107 new EndpointRule( 121 new EndpointRule(
108 - true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F); 122 + true,
  123 + Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
  124 + 0.0F);
109 EndpointRule rule3 = 125 EndpointRule rule3 =
110 new EndpointRule( 126 new EndpointRule(
111 - false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); 127 + false,
  128 + 0.0F,
  129 + Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
112 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); 130 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
113 OnlineParaformerModelConfig modelParaCfg = 131 OnlineParaformerModelConfig modelParaCfg =
114 new OnlineParaformerModelConfig( 132 new OnlineParaformerModelConfig(
115 - proMap.get("encoder").trim(), proMap.get("decoder").trim()); 133 + proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
116 OnlineTransducerModelConfig modelTranCfg = 134 OnlineTransducerModelConfig modelTranCfg =
117 new OnlineTransducerModelConfig( 135 new OnlineTransducerModelConfig(
118 - proMap.get("encoder").trim(),  
119 - proMap.get("decoder").trim(),  
120 - proMap.get("joiner").trim()); 136 + proMap.getOrDefault("encoder", "").trim(),
  137 + proMap.getOrDefault("decoder", "").trim(),
  138 + proMap.getOrDefault("joiner", "").trim());
121 139
122 OnlineModelConfig modelCfg = 140 OnlineModelConfig modelCfg =
123 new OnlineModelConfig( 141 new OnlineModelConfig(
124 - proMap.get("tokens").trim(),  
125 - Integer.parseInt(proMap.get("num_threads").trim()), 142 + proMap.getOrDefault("tokens", "").trim(),
  143 + Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
126 false, 144 false,
127 - proMap.get("model_type").trim(), 145 + proMap.getOrDefault("model_type", "zipformer").trim(),
128 modelParaCfg, 146 modelParaCfg,
129 modelTranCfg); 147 modelTranCfg);
130 FeatureConfig featConfig = 148 FeatureConfig featConfig =
131 - new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 149 + new FeatureConfig(
  150 + sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
132 151
133 OnlineLMConfig onlineLmConfig = 152 OnlineLMConfig onlineLmConfig =
134 new OnlineLMConfig( 153 new OnlineLMConfig(
135 - proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); 154 + proMap.getOrDefault("lm_model", "").trim(),
  155 + Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
136 156
137 OnlineRecognizerConfig rcgCfg = 157 OnlineRecognizerConfig rcgCfg =
138 new OnlineRecognizerConfig( 158 new OnlineRecognizerConfig(
@@ -140,9 +160,11 @@ public class OnlineRecognizer { @@ -140,9 +160,11 @@ public class OnlineRecognizer {
140 modelCfg, 160 modelCfg,
141 endCfg, 161 endCfg,
142 onlineLmConfig, 162 onlineLmConfig,
143 - Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),  
144 - proMap.get("decoding_method").trim(),  
145 - Integer.parseInt(proMap.get("max_active_paths").trim())); 163 + Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
  164 + proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
  165 + Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
  166 + proMap.getOrDefault("hotwords_file", "").trim(),
  167 + Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
146 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 168 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
147 this.ptr = createOnlineRecognizer(assetManager, rcgCfg); 169 this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
148 170
@@ -168,6 +190,8 @@ public class OnlineRecognizer { @@ -168,6 +190,8 @@ public class OnlineRecognizer {
168 String lm_model, 190 String lm_model,
169 float lm_scale, 191 float lm_scale,
170 int maxActivePaths, 192 int maxActivePaths,
  193 + String hotwordsFile,
  194 + float hotwordsScore,
171 String modelType) { 195 String modelType) {
172 this.sampleRate = sampleRate; 196 this.sampleRate = sampleRate;
173 EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); 197 EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
@@ -189,7 +213,9 @@ public class OnlineRecognizer { @@ -189,7 +213,9 @@ public class OnlineRecognizer {
189 onlineLmConfig, 213 onlineLmConfig,
190 enableEndpointDetection, 214 enableEndpointDetection,
191 decodingMethod, 215 decodingMethod,
192 - maxActivePaths); 216 + maxActivePaths,
  217 + hotwordsFile,
  218 + hotwordsScore);
193 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 219 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
194 this.ptr = createOnlineRecognizer(new Object(), rcgCfg); 220 this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
195 } 221 }
@@ -211,7 +237,6 @@ public class OnlineRecognizer { @@ -211,7 +237,6 @@ public class OnlineRecognizer {
211 String key = (String) en.nextElement(); 237 String key = (String) en.nextElement();
212 String Property = props.getProperty(key); 238 String Property = props.getProperty(key);
213 proMap.put(key, Property); 239 proMap.put(key, Property);
214 - // System.out.println(key+"="+Property);  
215 } 240 }
216 241
217 } catch (Exception e) { 242 } catch (Exception e) {
@@ -12,6 +12,8 @@ public class OnlineRecognizerConfig { @@ -12,6 +12,8 @@ public class OnlineRecognizerConfig {
12 private final boolean enableEndpoint; 12 private final boolean enableEndpoint;
13 private final String decodingMethod; 13 private final String decodingMethod;
14 private final int maxActivePaths; 14 private final int maxActivePaths;
  15 + private final String hotwordsFile;
  16 + private final float hotwordsScore;
15 17
16 public OnlineRecognizerConfig( 18 public OnlineRecognizerConfig(
17 FeatureConfig featConfig, 19 FeatureConfig featConfig,
@@ -20,7 +22,9 @@ public class OnlineRecognizerConfig { @@ -20,7 +22,9 @@ public class OnlineRecognizerConfig {
20 OnlineLMConfig lmConfig, 22 OnlineLMConfig lmConfig,
21 boolean enableEndpoint, 23 boolean enableEndpoint,
22 String decodingMethod, 24 String decodingMethod,
23 - int maxActivePaths) { 25 + int maxActivePaths,
  26 + String hotwordsFile,
  27 + float hotwordsScore) {
24 this.featConfig = featConfig; 28 this.featConfig = featConfig;
25 this.modelConfig = modelConfig; 29 this.modelConfig = modelConfig;
26 this.endpointConfig = endpointConfig; 30 this.endpointConfig = endpointConfig;
@@ -28,6 +32,8 @@ public class OnlineRecognizerConfig { @@ -28,6 +32,8 @@ public class OnlineRecognizerConfig {
28 this.enableEndpoint = enableEndpoint; 32 this.enableEndpoint = enableEndpoint;
29 this.decodingMethod = decodingMethod; 33 this.decodingMethod = decodingMethod;
30 this.maxActivePaths = maxActivePaths; 34 this.maxActivePaths = maxActivePaths;
  35 + this.hotwordsFile = hotwordsFile;
  36 + this.hotwordsScore = hotwordsScore;
31 } 37 }
32 38
33 public OnlineLMConfig getLmConfig() { 39 public OnlineLMConfig getLmConfig() {
@@ -125,6 +125,15 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -125,6 +125,15 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
125 fid = env->GetFieldID(cls, "maxActivePaths", "I"); 125 fid = env->GetFieldID(cls, "maxActivePaths", "I");
126 ans.max_active_paths = env->GetIntField(config, fid); 126 ans.max_active_paths = env->GetIntField(config, fid);
127 127
  128 + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
  129 + s = (jstring)env->GetObjectField(config, fid);
  130 + p = env->GetStringUTFChars(s, nullptr);
  131 + ans.hotwords_file = p;
  132 + env->ReleaseStringUTFChars(s, p);
  133 +
  134 + fid = env->GetFieldID(cls, "hotwordsScore", "F");
  135 + ans.hotwords_score = env->GetFloatField(config, fid);
  136 +
128 //---------- feat config ---------- 137 //---------- feat config ----------
129 fid = env->GetFieldID(cls, "featConfig", 138 fid = env->GetFieldID(cls, "featConfig",
130 "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); 139 "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
@@ -293,6 +302,15 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { @@ -293,6 +302,15 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
293 fid = env->GetFieldID(cls, "maxActivePaths", "I"); 302 fid = env->GetFieldID(cls, "maxActivePaths", "I");
294 ans.max_active_paths = env->GetIntField(config, fid); 303 ans.max_active_paths = env->GetIntField(config, fid);
295 304
  305 + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
  306 + s = (jstring)env->GetObjectField(config, fid);
  307 + p = env->GetStringUTFChars(s, nullptr);
  308 + ans.hotwords_file = p;
  309 + env->ReleaseStringUTFChars(s, p);
  310 +
  311 + fid = env->GetFieldID(cls, "hotwordsScore", "F");
  312 + ans.hotwords_score = env->GetFloatField(config, fid);
  313 +
296 //---------- feat config ---------- 314 //---------- feat config ----------
297 fid = env->GetFieldID(cls, "featConfig", 315 fid = env->GetFieldID(cls, "featConfig",
298 "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); 316 "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");