zhaomingwork
Committed by GitHub

Java API update: add modelType to the config class (#228)

@@ -4,16 +4,17 @@ feature_dim=80 @@ -4,16 +4,17 @@ feature_dim=80
4 rule1_min_trailing_silence=2.4 4 rule1_min_trailing_silence=2.4
5 rule2_min_trailing_silence=1.2 5 rule2_min_trailing_silence=1.2
6 rule3_min_utterance_length=20 6 rule3_min_utterance_length=20
7 -encoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx  
8 -decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx  
9 -joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx  
10 -tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt 7 +encoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx
  8 +decoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx
  9 +joiner=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx
  10 +tokens=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
11 num_threads=4 11 num_threads=4
12 enable_endpoint_detection=true 12 enable_endpoint_detection=true
13 decoding_method=modified_beam_search 13 decoding_method=modified_beam_search
14 max_active_paths=4 14 max_active_paths=4
15 lm_model= 15 lm_model=
16 lm_scale=0.5 16 lm_scale=0.5
  17 +model_type=zipformer
17 18
18 #websocket server config 19 #websocket server config
19 port=8890 20 port=8890
@@ -49,8 +49,9 @@ public class DecodeFile { @@ -49,8 +49,9 @@ public class DecodeFile {
49 float rule3MinUtteranceLength = 20F; 49 float rule3MinUtteranceLength = 20F;
50 String decodingMethod = "greedy_search"; 50 String decodingMethod = "greedy_search";
51 int maxActivePaths = 4; 51 int maxActivePaths = 4;
52 - String lm_model="";  
53 - float lm_scale=0.5F; 52 + String lm_model = "";
  53 + float lm_scale = 0.5F;
  54 + String modelType = "zipformer";
54 rcgOjb = 55 rcgOjb =
55 new OnlineRecognizer( 56 new OnlineRecognizer(
56 tokens, 57 tokens,
@@ -65,9 +66,10 @@ public class DecodeFile { @@ -65,9 +66,10 @@ public class DecodeFile {
65 rule2MinTrailingSilence, 66 rule2MinTrailingSilence,
66 rule3MinUtteranceLength, 67 rule3MinUtteranceLength,
67 decodingMethod, 68 decodingMethod,
68 - lm_model,  
69 - lm_scale,  
70 - maxActivePaths); 69 + lm_model,
  70 + lm_scale,
  71 + maxActivePaths,
  72 + modelType);
71 streamObj = rcgOjb.createStream(); 73 streamObj = rcgOjb.createStream();
72 } catch (Exception e) { 74 } catch (Exception e) {
73 System.err.println(e); 75 System.err.println(e);
@@ -39,6 +39,7 @@ public class OnlineRecognizer { @@ -39,6 +39,7 @@ public class OnlineRecognizer {
39 private long ptr = 0; // this is the ASR engine pointer 39 private long ptr = 0; // this is the ASR engine pointer
40 40
41 private int sampleRate = 16000; 41 private int sampleRate = 16000;
  42 +
42 // load config file for OnlineRecognizer 43 // load config file for OnlineRecognizer
43 public OnlineRecognizer(String modelCfgPath) { 44 public OnlineRecognizer(String modelCfgPath) {
44 Map<String, String> proMap = this.readProperties(modelCfgPath); 45 Map<String, String> proMap = this.readProperties(modelCfgPath);
@@ -62,17 +63,20 @@ public class OnlineRecognizer { @@ -62,17 +63,20 @@ public class OnlineRecognizer {
62 proMap.get("joiner").trim(), 63 proMap.get("joiner").trim(),
63 proMap.get("tokens").trim(), 64 proMap.get("tokens").trim(),
64 Integer.parseInt(proMap.get("num_threads").trim()), 65 Integer.parseInt(proMap.get("num_threads").trim()),
65 - false); 66 + false,
  67 + proMap.get("model_type").trim());
66 FeatureConfig featConfig = 68 FeatureConfig featConfig =
67 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 69 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
68 - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));  
69 -  
70 - OnlineRecognizerConfig rcgCfg = 70 + OnlineLMConfig onlineLmConfig =
  71 + new OnlineLMConfig(
  72 + proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
  73 +
  74 + OnlineRecognizerConfig rcgCfg =
71 new OnlineRecognizerConfig( 75 new OnlineRecognizerConfig(
72 featConfig, 76 featConfig,
73 modelCfg, 77 modelCfg,
74 endCfg, 78 endCfg,
75 - onlineLmConfig, 79 + onlineLmConfig,
76 Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), 80 Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
77 proMap.get("decoding_method").trim(), 81 proMap.get("decoding_method").trim(),
78 Integer.parseInt(proMap.get("max_active_paths").trim())); 82 Integer.parseInt(proMap.get("max_active_paths").trim()));
@@ -107,18 +111,21 @@ public class OnlineRecognizer { @@ -107,18 +111,21 @@ public class OnlineRecognizer {
107 proMap.get("joiner").trim(), 111 proMap.get("joiner").trim(),
108 proMap.get("tokens").trim(), 112 proMap.get("tokens").trim(),
109 Integer.parseInt(proMap.get("num_threads").trim()), 113 Integer.parseInt(proMap.get("num_threads").trim()),
110 - false); 114 + false,
  115 + proMap.get("model_type").trim());
111 FeatureConfig featConfig = 116 FeatureConfig featConfig =
112 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 117 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
113 -  
114 - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim()));  
115 -  
116 - OnlineRecognizerConfig rcgCfg = 118 +
  119 + OnlineLMConfig onlineLmConfig =
  120 + new OnlineLMConfig(
  121 + proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
  122 +
  123 + OnlineRecognizerConfig rcgCfg =
117 new OnlineRecognizerConfig( 124 new OnlineRecognizerConfig(
118 featConfig, 125 featConfig,
119 modelCfg, 126 modelCfg,
120 endCfg, 127 endCfg,
121 - onlineLmConfig, 128 + onlineLmConfig,
122 Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()), 129 Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
123 proMap.get("decoding_method").trim(), 130 proMap.get("decoding_method").trim(),
124 Integer.parseInt(proMap.get("max_active_paths").trim())); 131 Integer.parseInt(proMap.get("max_active_paths").trim()));
@@ -144,21 +151,29 @@ public class OnlineRecognizer { @@ -144,21 +151,29 @@ public class OnlineRecognizer {
144 float rule2MinTrailingSilence, 151 float rule2MinTrailingSilence,
145 float rule3MinUtteranceLength, 152 float rule3MinUtteranceLength,
146 String decodingMethod, 153 String decodingMethod,
147 - String lm_model,  
148 - float lm_scale,  
149 - int maxActivePaths) { 154 + String lm_model,
  155 + float lm_scale,
  156 + int maxActivePaths,
  157 + String modelType) {
150 this.sampleRate = sampleRate; 158 this.sampleRate = sampleRate;
151 EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); 159 EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
152 EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); 160 EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
153 EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); 161 EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
154 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); 162 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
155 OnlineTransducerModelConfig modelCfg = 163 OnlineTransducerModelConfig modelCfg =
156 - new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false); 164 + new OnlineTransducerModelConfig(
  165 + encoder, decoder, joiner, tokens, numThreads, false, modelType);
157 FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); 166 FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
158 - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale);  
159 - OnlineRecognizerConfig rcgCfg = 167 + OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
  168 + OnlineRecognizerConfig rcgCfg =
160 new OnlineRecognizerConfig( 169 new OnlineRecognizerConfig(
161 - featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths); 170 + featConfig,
  171 + modelCfg,
  172 + endCfg,
  173 + onlineLmConfig,
  174 + enableEndpointDetection,
  175 + decodingMethod,
  176 + maxActivePaths);
162 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 177 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
163 this.ptr = createOnlineRecognizer(new Object(), rcgCfg); 178 this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
164 } 179 }
@@ -284,9 +299,10 @@ public class OnlineRecognizer { @@ -284,9 +299,10 @@ public class OnlineRecognizer {
284 public void releaseStream(OnlineStream s) { 299 public void releaseStream(OnlineStream s) {
285 s.release(); 300 s.release();
286 } 301 }
  302 +
287 // JNI interface libsherpa-onnx-jni.so 303 // JNI interface libsherpa-onnx-jni.so
288 304
289 - private static native Object[] readWave(String fileName); // static 305 + private static native Object[] readWave(String fileName); // static
290 306
291 private native String getResult(long ptr, long streamPtr); 307 private native String getResult(long ptr, long streamPtr);
292 308
@@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig { @@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig {
11 private final String tokens; 11 private final String tokens;
12 private final int numThreads; 12 private final int numThreads;
13 private final boolean debug; 13 private final boolean debug;
  14 + private final String provider = "cpu";
  15 + private String modelType = "";
14 16
15 public OnlineTransducerModelConfig( 17 public OnlineTransducerModelConfig(
16 - String encoder, String decoder, String joiner, String tokens, int numThreads, boolean debug) { 18 + String encoder,
  19 + String decoder,
  20 + String joiner,
  21 + String tokens,
  22 + int numThreads,
  23 + boolean debug,
  24 + String modelType) {
17 this.encoder = encoder; 25 this.encoder = encoder;
18 this.decoder = decoder; 26 this.decoder = decoder;
19 this.joiner = joiner; 27 this.joiner = joiner;
20 this.tokens = tokens; 28 this.tokens = tokens;
21 this.numThreads = numThreads; 29 this.numThreads = numThreads;
22 this.debug = debug; 30 this.debug = debug;
  31 + this.modelType = modelType;
23 } 32 }
24 33
25 public String getEncoder() { 34 public String getEncoder() {