zhaomingwork
Committed by GitHub

Java api update for adding modelType in config class (#228)

@@ -4,16 +4,17 @@ feature_dim=80 @@ -4,16 +4,17 @@ feature_dim=80
4 rule1_min_trailing_silence=2.4 4 rule1_min_trailing_silence=2.4
5 rule2_min_trailing_silence=1.2 5 rule2_min_trailing_silence=1.2
6 rule3_min_utterance_length=20 6 rule3_min_utterance_length=20
7 -encoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx  
8 -decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx  
9 -joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx  
10 -tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt 7 +encoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx
  8 +decoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx
  9 +joiner=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx
  10 +tokens=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt
11 num_threads=4 11 num_threads=4
12 enable_endpoint_detection=true 12 enable_endpoint_detection=true
13 decoding_method=modified_beam_search 13 decoding_method=modified_beam_search
14 max_active_paths=4 14 max_active_paths=4
15 lm_model= 15 lm_model=
16 lm_scale=0.5 16 lm_scale=0.5
  17 +model_type=zipformer
17 18
18 #websocket server config 19 #websocket server config
19 port=8890 20 port=8890
@@ -49,8 +49,9 @@ public class DecodeFile { @@ -49,8 +49,9 @@ public class DecodeFile {
49 float rule3MinUtteranceLength = 20F; 49 float rule3MinUtteranceLength = 20F;
50 String decodingMethod = "greedy_search"; 50 String decodingMethod = "greedy_search";
51 int maxActivePaths = 4; 51 int maxActivePaths = 4;
52 - String lm_model="";  
53 - float lm_scale=0.5F; 52 + String lm_model = "";
  53 + float lm_scale = 0.5F;
  54 + String modelType = "zipformer";
54 rcgOjb = 55 rcgOjb =
55 new OnlineRecognizer( 56 new OnlineRecognizer(
56 tokens, 57 tokens,
@@ -67,7 +68,8 @@ public class DecodeFile { @@ -67,7 +68,8 @@ public class DecodeFile {
67 decodingMethod, 68 decodingMethod,
68 lm_model, 69 lm_model,
69 lm_scale, 70 lm_scale,
70 - maxActivePaths); 71 + maxActivePaths,
  72 + modelType);
71 streamObj = rcgOjb.createStream(); 73 streamObj = rcgOjb.createStream();
72 } catch (Exception e) { 74 } catch (Exception e) {
73 System.err.println(e); 75 System.err.println(e);
@@ -39,6 +39,7 @@ public class OnlineRecognizer { @@ -39,6 +39,7 @@ public class OnlineRecognizer {
39 private long ptr = 0; // this is the asr engine ptrss 39 private long ptr = 0; // this is the asr engine ptrss
40 40
41 private int sampleRate = 16000; 41 private int sampleRate = 16000;
  42 +
42 // load config file for OnlineRecognizer 43 // load config file for OnlineRecognizer
43 public OnlineRecognizer(String modelCfgPath) { 44 public OnlineRecognizer(String modelCfgPath) {
44 Map<String, String> proMap = this.readProperties(modelCfgPath); 45 Map<String, String> proMap = this.readProperties(modelCfgPath);
@@ -62,10 +63,13 @@ public class OnlineRecognizer { @@ -62,10 +63,13 @@ public class OnlineRecognizer {
62 proMap.get("joiner").trim(), 63 proMap.get("joiner").trim(),
63 proMap.get("tokens").trim(), 64 proMap.get("tokens").trim(),
64 Integer.parseInt(proMap.get("num_threads").trim()), 65 Integer.parseInt(proMap.get("num_threads").trim()),
65 - false); 66 + false,
  67 + proMap.get("model_type").trim());
66 FeatureConfig featConfig = 68 FeatureConfig featConfig =
67 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 69 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
68 - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); 70 + OnlineLMConfig onlineLmConfig =
  71 + new OnlineLMConfig(
  72 + proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
69 73
70 OnlineRecognizerConfig rcgCfg = 74 OnlineRecognizerConfig rcgCfg =
71 new OnlineRecognizerConfig( 75 new OnlineRecognizerConfig(
@@ -107,11 +111,14 @@ public class OnlineRecognizer { @@ -107,11 +111,14 @@ public class OnlineRecognizer {
107 proMap.get("joiner").trim(), 111 proMap.get("joiner").trim(),
108 proMap.get("tokens").trim(), 112 proMap.get("tokens").trim(),
109 Integer.parseInt(proMap.get("num_threads").trim()), 113 Integer.parseInt(proMap.get("num_threads").trim()),
110 - false); 114 + false,
  115 + proMap.get("model_type").trim());
111 FeatureConfig featConfig = 116 FeatureConfig featConfig =
112 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); 117 new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
113 118
114 - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); 119 + OnlineLMConfig onlineLmConfig =
  120 + new OnlineLMConfig(
  121 + proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
115 122
116 OnlineRecognizerConfig rcgCfg = 123 OnlineRecognizerConfig rcgCfg =
117 new OnlineRecognizerConfig( 124 new OnlineRecognizerConfig(
@@ -146,19 +153,27 @@ public class OnlineRecognizer { @@ -146,19 +153,27 @@ public class OnlineRecognizer {
146 String decodingMethod, 153 String decodingMethod,
147 String lm_model, 154 String lm_model,
148 float lm_scale, 155 float lm_scale,
149 - int maxActivePaths) { 156 + int maxActivePaths,
  157 + String modelType) {
150 this.sampleRate = sampleRate; 158 this.sampleRate = sampleRate;
151 EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); 159 EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
152 EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); 160 EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
153 EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); 161 EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
154 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); 162 EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
155 OnlineTransducerModelConfig modelCfg = 163 OnlineTransducerModelConfig modelCfg =
156 - new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false); 164 + new OnlineTransducerModelConfig(
  165 + encoder, decoder, joiner, tokens, numThreads, false, modelType);
157 FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); 166 FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
158 - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale); 167 + OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
159 OnlineRecognizerConfig rcgCfg = 168 OnlineRecognizerConfig rcgCfg =
160 new OnlineRecognizerConfig( 169 new OnlineRecognizerConfig(
161 - featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths); 170 + featConfig,
  171 + modelCfg,
  172 + endCfg,
  173 + onlineLmConfig,
  174 + enableEndpointDetection,
  175 + decodingMethod,
  176 + maxActivePaths);
162 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 177 // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
163 this.ptr = createOnlineRecognizer(new Object(), rcgCfg); 178 this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
164 } 179 }
@@ -284,6 +299,7 @@ public class OnlineRecognizer { @@ -284,6 +299,7 @@ public class OnlineRecognizer {
284 public void releaseStream(OnlineStream s) { 299 public void releaseStream(OnlineStream s) {
285 s.release(); 300 s.release();
286 } 301 }
  302 +
287 // JNI interface libsherpa-onnx-jni.so 303 // JNI interface libsherpa-onnx-jni.so
288 304
289 private static native Object[] readWave(String fileName); // static 305 private static native Object[] readWave(String fileName); // static
@@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig { @@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig {
11 private final String tokens; 11 private final String tokens;
12 private final int numThreads; 12 private final int numThreads;
13 private final boolean debug; 13 private final boolean debug;
  14 + private final String provider = "cpu";
  15 + private String modelType = "";
14 16
15 public OnlineTransducerModelConfig( 17 public OnlineTransducerModelConfig(
16 - String encoder, String decoder, String joiner, String tokens, int numThreads, boolean debug) { 18 + String encoder,
  19 + String decoder,
  20 + String joiner,
  21 + String tokens,
  22 + int numThreads,
  23 + boolean debug,
  24 + String modelType) {
17 this.encoder = encoder; 25 this.encoder = encoder;
18 this.decoder = decoder; 26 this.decoder = decoder;
19 this.joiner = joiner; 27 this.joiner = joiner;
20 this.tokens = tokens; 28 this.tokens = tokens;
21 this.numThreads = numThreads; 29 this.numThreads = numThreads;
22 this.debug = debug; 30 this.debug = debug;
  31 + this.modelType = modelType;
23 } 32 }
24 33
25 public String getEncoder() { 34 public String getEncoder() {