zhaomingwork
Committed by GitHub

update java for paraformer (#276)

@@ -220,12 +220,13 @@ public class AsrWebsocketServer extends WebSocketServer { @@ -220,12 +220,13 @@ public class AsrWebsocketServer extends WebSocketServer {
220 String cfgPath = args[1]; 220 String cfgPath = args[1];
221 221
222 OnlineRecognizer.setSoPath(soPath); 222 OnlineRecognizer.setSoPath(soPath);
223 - 223 + logger.info("readProperties");
224 Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath); 224 Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
225 int port = Integer.valueOf(cfgMap.get("port")); 225 int port = Integer.valueOf(cfgMap.get("port"));
226 226
227 int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num")); 227 int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
228 AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum); 228 AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
  229 + logger.info("initModelWithCfg");
229 s.initModelWithCfg(cfgMap, cfgPath); 230 s.initModelWithCfg(cfgMap, cfgPath);
230 logger.info("Server started on port: " + s.getPort()); 231 logger.info("Server started on port: " + s.getPort());
231 s.start(); 232 s.start();
  1 +/*
  2 + * // Copyright 2022-2023 by zhaoming
  3 + */
  4 +
  5 +package com.k2fsa.sherpa.onnx;
  6 +
  7 +public class OnlineModelConfig {
  8 + private final OnlineParaformerModelConfig paraformer;
  9 + private final OnlineTransducerModelConfig transducer;
  10 + private final String tokens;
  11 + private final int numThreads;
  12 + private final boolean debug;
  13 + private final String provider = "cpu";
  14 + private String modelType = "";
  15 +
  16 + public OnlineModelConfig(
  17 + String tokens,
  18 + int numThreads,
  19 + boolean debug,
  20 + String modelType,
  21 + OnlineParaformerModelConfig paraformer,
  22 + OnlineTransducerModelConfig transducer) {
  23 +
  24 + this.tokens = tokens;
  25 + this.numThreads = numThreads;
  26 + this.debug = debug;
  27 + this.modelType = modelType;
  28 + this.paraformer = paraformer;
  29 + this.transducer = transducer;
  30 + }
  31 +
  32 + public OnlineParaformerModelConfig getParaformer() {
  33 + return paraformer;
  34 + }
  35 +
  36 + public OnlineTransducerModelConfig getTransducer() {
  37 + return transducer;
  38 + }
  39 +
  40 + public String getTokens() {
  41 + return tokens;
  42 + }
  43 +
  44 + public int getNumThreads() {
  45 + return numThreads;
  46 + }
  47 +
  48 + public boolean getDebug() {
  49 + return debug;
  50 + }
  51 +}
  1 +/*
  2 + * // Copyright 2022-2023 by zhaoming
  3 + */
  4 +
  5 +package com.k2fsa.sherpa.onnx;
  6 +
  7 +public class OnlineParaformerModelConfig {
  8 + private final String encoder;
  9 + private final String decoder;
  10 +
  11 + public OnlineParaformerModelConfig(String encoder, String decoder) {
  12 + this.encoder = encoder;
  13 + this.decoder = decoder;
  14 + }
  15 +
  16 + public String getEncoder() {
  17 + return encoder;
  18 + }
  19 +
  20 + public String getDecoder() {
  21 + return decoder;
  22 + }
  23 +}
@@ -56,15 +56,21 @@ public class OnlineRecognizer { @@ -56,15 +56,21 @@ public class OnlineRecognizer {
56 new EndpointRule( 56 new EndpointRule(
57 false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); 57 false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
58 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); 58 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
59 - OnlineTransducerModelConfig modelCfg = 59 +
  60 + OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig("", "");
  61 + OnlineTransducerModelConfig modelTranCfg =
60 new OnlineTransducerModelConfig( 62 new OnlineTransducerModelConfig(
61 proMap.get("encoder").trim(), 63 proMap.get("encoder").trim(),
62 proMap.get("decoder").trim(), 64 proMap.get("decoder").trim(),
63 - proMap.get("joiner").trim(), 65 + proMap.get("joiner").trim());
  66 + OnlineModelConfig modelCfg =
  67 + new OnlineModelConfig(
64 proMap.get("tokens").trim(), 68 proMap.get("tokens").trim(),
65 Integer.parseInt(proMap.get("num_threads").trim()), 69 Integer.parseInt(proMap.get("num_threads").trim()),
66 false, 70 false,
67 - proMap.get("model_type").trim()); 71 + proMap.get("model_type").trim(),
  72 + modelParaCfg,
  73 + modelTranCfg);
68 FeatureConfig featConfig = 74 FeatureConfig featConfig =
69 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 75 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
70 OnlineLMConfig onlineLmConfig = 76 OnlineLMConfig onlineLmConfig =
@@ -104,15 +110,23 @@ public class OnlineRecognizer { @@ -104,15 +110,23 @@ public class OnlineRecognizer {
104 new EndpointRule( 110 new EndpointRule(
105 false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim())); 111 false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
106 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); 112 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
107 - OnlineTransducerModelConfig modelCfg = 113 + OnlineParaformerModelConfig modelParaCfg =
  114 + new OnlineParaformerModelConfig(
  115 + proMap.get("encoder").trim(), proMap.get("decoder").trim());
  116 + OnlineTransducerModelConfig modelTranCfg =
108 new OnlineTransducerModelConfig( 117 new OnlineTransducerModelConfig(
109 proMap.get("encoder").trim(), 118 proMap.get("encoder").trim(),
110 proMap.get("decoder").trim(), 119 proMap.get("decoder").trim(),
111 - proMap.get("joiner").trim(), 120 + proMap.get("joiner").trim());
  121 +
  122 + OnlineModelConfig modelCfg =
  123 + new OnlineModelConfig(
112 proMap.get("tokens").trim(), 124 proMap.get("tokens").trim(),
113 Integer.parseInt(proMap.get("num_threads").trim()), 125 Integer.parseInt(proMap.get("num_threads").trim()),
114 false, 126 false,
115 - proMap.get("model_type").trim()); 127 + proMap.get("model_type").trim(),
  128 + modelParaCfg,
  129 + modelTranCfg);
116 FeatureConfig featConfig = 130 FeatureConfig featConfig =
117 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 131 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
118 132
@@ -160,9 +174,11 @@ public class OnlineRecognizer { @@ -160,9 +174,11 @@ public class OnlineRecognizer {
160 EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); 174 EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
161 EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); 175 EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
162 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); 176 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
163 - OnlineTransducerModelConfig modelCfg =  
164 - new OnlineTransducerModelConfig(  
165 - encoder, decoder, joiner, tokens, numThreads, false, modelType); 177 + OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder);
  178 + OnlineTransducerModelConfig modelTranCfg =
  179 + new OnlineTransducerModelConfig(encoder, decoder, joiner);
  180 + OnlineModelConfig modelCfg =
  181 + new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg);
166 FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); 182 FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
167 OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); 183 OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
168 OnlineRecognizerConfig rcgCfg = 184 OnlineRecognizerConfig rcgCfg =
@@ -277,6 +293,7 @@ public class OnlineRecognizer { @@ -277,6 +293,7 @@ public class OnlineRecognizer {
277 293
278 System.out.println("so lib path=" + soPath + "\n"); 294 System.out.println("so lib path=" + soPath + "\n");
279 System.load(soPath.trim()); 295 System.load(soPath.trim());
  296 + System.out.println("load so lib succeed\n");
280 } 297 }
281 298
282 public static void setSoPath(String soPath) { 299 public static void setSoPath(String soPath) {
@@ -6,17 +6,16 @@ package com.k2fsa.sherpa.onnx; @@ -6,17 +6,16 @@ package com.k2fsa.sherpa.onnx;
6 6
7 public class OnlineRecognizerConfig { 7 public class OnlineRecognizerConfig {
8 private final FeatureConfig featConfig; 8 private final FeatureConfig featConfig;
9 - private final OnlineTransducerModelConfig modelConfig; 9 + private final OnlineModelConfig modelConfig;
10 private final EndpointConfig endpointConfig; 10 private final EndpointConfig endpointConfig;
11 private final OnlineLMConfig lmConfig; 11 private final OnlineLMConfig lmConfig;
12 private final boolean enableEndpoint; 12 private final boolean enableEndpoint;
13 private final String decodingMethod; 13 private final String decodingMethod;
14 private final int maxActivePaths; 14 private final int maxActivePaths;
15 15
16 -  
17 public OnlineRecognizerConfig( 16 public OnlineRecognizerConfig(
18 FeatureConfig featConfig, 17 FeatureConfig featConfig,
19 - OnlineTransducerModelConfig modelConfig, 18 + OnlineModelConfig modelConfig,
20 EndpointConfig endpointConfig, 19 EndpointConfig endpointConfig,
21 OnlineLMConfig lmConfig, 20 OnlineLMConfig lmConfig,
22 boolean enableEndpoint, 21 boolean enableEndpoint,
@@ -39,7 +38,7 @@ public class OnlineRecognizerConfig { @@ -39,7 +38,7 @@ public class OnlineRecognizerConfig {
39 return featConfig; 38 return featConfig;
40 } 39 }
41 40
42 - public OnlineTransducerModelConfig getModelConfig() { 41 + public OnlineModelConfig getModelConfig() {
43 return modelConfig; 42 return modelConfig;
44 } 43 }
45 44
@@ -8,27 +8,11 @@ public class OnlineTransducerModelConfig { @@ -8,27 +8,11 @@ public class OnlineTransducerModelConfig {
8 private final String encoder; 8 private final String encoder;
9 private final String decoder; 9 private final String decoder;
10 private final String joiner; 10 private final String joiner;
11 - private final String tokens;  
12 - private final int numThreads;  
13 - private final boolean debug;  
14 - private final String provider = "cpu";  
15 - private String modelType = "";  
16 11
17 - public OnlineTransducerModelConfig(  
18 - String encoder,  
19 - String decoder,  
20 - String joiner,  
21 - String tokens,  
22 - int numThreads,  
23 - boolean debug,  
24 - String modelType) { 12 + public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) {
25 this.encoder = encoder; 13 this.encoder = encoder;
26 this.decoder = decoder; 14 this.decoder = decoder;
27 this.joiner = joiner; 15 this.joiner = joiner;
28 - this.tokens = tokens;  
29 - this.numThreads = numThreads;  
30 - this.debug = debug;  
31 - this.modelType = modelType;  
32 } 16 }
33 17
34 public String getEncoder() { 18 public String getEncoder() {
@@ -42,16 +26,4 @@ public class OnlineTransducerModelConfig { @@ -42,16 +26,4 @@ public class OnlineTransducerModelConfig {
42 public String getJoiner() { 26 public String getJoiner() {
43 return joiner; 27 return joiner;
44 } 28 }
45 -  
46 - public String getTokens() {  
47 - return tokens;  
48 - }  
49 -  
50 - public int getNumThreads() {  
51 - return numThreads;  
52 - }  
53 -  
54 - public boolean getDebug() {  
55 - return debug;  
56 - }  
57 } 29 }