Committed by
GitHub
Java api update for adding modelType in config class (#228)
正在显示
4 个修改的文件
包含
44 行增加
和
16 行删除
| @@ -4,16 +4,17 @@ feature_dim=80 | @@ -4,16 +4,17 @@ feature_dim=80 | ||
| 4 | rule1_min_trailing_silence=2.4 | 4 | rule1_min_trailing_silence=2.4 |
| 5 | rule2_min_trailing_silence=1.2 | 5 | rule2_min_trailing_silence=1.2 |
| 6 | rule3_min_utterance_length=20 | 6 | rule3_min_utterance_length=20 |
| 7 | -encoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx | ||
| 8 | -decoder=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx | ||
| 9 | -joiner=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx | ||
| 10 | -tokens=/sherpa-onnx/build_old/bin/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt | 7 | +encoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx |
| 8 | +decoder=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx | ||
| 9 | +joiner=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx | ||
| 10 | +tokens=/sherpa/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt | ||
| 11 | num_threads=4 | 11 | num_threads=4 |
| 12 | enable_endpoint_detection=true | 12 | enable_endpoint_detection=true |
| 13 | decoding_method=modified_beam_search | 13 | decoding_method=modified_beam_search |
| 14 | max_active_paths=4 | 14 | max_active_paths=4 |
| 15 | lm_model= | 15 | lm_model= |
| 16 | lm_scale=0.5 | 16 | lm_scale=0.5 |
| 17 | +model_type=zipformer | ||
| 17 | 18 | ||
| 18 | #websocket server config | 19 | #websocket server config |
| 19 | port=8890 | 20 | port=8890 |
| @@ -49,8 +49,9 @@ public class DecodeFile { | @@ -49,8 +49,9 @@ public class DecodeFile { | ||
| 49 | float rule3MinUtteranceLength = 20F; | 49 | float rule3MinUtteranceLength = 20F; |
| 50 | String decodingMethod = "greedy_search"; | 50 | String decodingMethod = "greedy_search"; |
| 51 | int maxActivePaths = 4; | 51 | int maxActivePaths = 4; |
| 52 | - String lm_model=""; | ||
| 53 | - float lm_scale=0.5F; | 52 | + String lm_model = ""; |
| 53 | + float lm_scale = 0.5F; | ||
| 54 | + String modelType = "zipformer"; | ||
| 54 | rcgOjb = | 55 | rcgOjb = |
| 55 | new OnlineRecognizer( | 56 | new OnlineRecognizer( |
| 56 | tokens, | 57 | tokens, |
| @@ -67,7 +68,8 @@ public class DecodeFile { | @@ -67,7 +68,8 @@ public class DecodeFile { | ||
| 67 | decodingMethod, | 68 | decodingMethod, |
| 68 | lm_model, | 69 | lm_model, |
| 69 | lm_scale, | 70 | lm_scale, |
| 70 | - maxActivePaths); | 71 | + maxActivePaths, |
| 72 | + modelType); | ||
| 71 | streamObj = rcgOjb.createStream(); | 73 | streamObj = rcgOjb.createStream(); |
| 72 | } catch (Exception e) { | 74 | } catch (Exception e) { |
| 73 | System.err.println(e); | 75 | System.err.println(e); |
| @@ -39,6 +39,7 @@ public class OnlineRecognizer { | @@ -39,6 +39,7 @@ public class OnlineRecognizer { | ||
| 39 | private long ptr = 0; // this is the asr engine ptrss | 39 | private long ptr = 0; // this is the asr engine ptrss |
| 40 | 40 | ||
| 41 | private int sampleRate = 16000; | 41 | private int sampleRate = 16000; |
| 42 | + | ||
| 42 | // load config file for OnlineRecognizer | 43 | // load config file for OnlineRecognizer |
| 43 | public OnlineRecognizer(String modelCfgPath) { | 44 | public OnlineRecognizer(String modelCfgPath) { |
| 44 | Map<String, String> proMap = this.readProperties(modelCfgPath); | 45 | Map<String, String> proMap = this.readProperties(modelCfgPath); |
| @@ -62,10 +63,13 @@ public class OnlineRecognizer { | @@ -62,10 +63,13 @@ public class OnlineRecognizer { | ||
| 62 | proMap.get("joiner").trim(), | 63 | proMap.get("joiner").trim(), |
| 63 | proMap.get("tokens").trim(), | 64 | proMap.get("tokens").trim(), |
| 64 | Integer.parseInt(proMap.get("num_threads").trim()), | 65 | Integer.parseInt(proMap.get("num_threads").trim()), |
| 65 | - false); | 66 | + false, |
| 67 | + proMap.get("model_type").trim()); | ||
| 66 | FeatureConfig featConfig = | 68 | FeatureConfig featConfig = |
| 67 | new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); | 69 | new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); |
| 68 | - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); | 70 | + OnlineLMConfig onlineLmConfig = |
| 71 | + new OnlineLMConfig( | ||
| 72 | + proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); | ||
| 69 | 73 | ||
| 70 | OnlineRecognizerConfig rcgCfg = | 74 | OnlineRecognizerConfig rcgCfg = |
| 71 | new OnlineRecognizerConfig( | 75 | new OnlineRecognizerConfig( |
| @@ -107,11 +111,14 @@ public class OnlineRecognizer { | @@ -107,11 +111,14 @@ public class OnlineRecognizer { | ||
| 107 | proMap.get("joiner").trim(), | 111 | proMap.get("joiner").trim(), |
| 108 | proMap.get("tokens").trim(), | 112 | proMap.get("tokens").trim(), |
| 109 | Integer.parseInt(proMap.get("num_threads").trim()), | 113 | Integer.parseInt(proMap.get("num_threads").trim()), |
| 110 | - false); | 114 | + false, |
| 115 | + proMap.get("model_type").trim()); | ||
| 111 | FeatureConfig featConfig = | 116 | FeatureConfig featConfig = |
| 112 | new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); | 117 | new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim())); |
| 113 | 118 | ||
| 114 | - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(proMap.get("lm_model").trim(),Float.parseFloat(proMap.get("lm_scale").trim())); | 119 | + OnlineLMConfig onlineLmConfig = |
| 120 | + new OnlineLMConfig( | ||
| 121 | + proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim())); | ||
| 115 | 122 | ||
| 116 | OnlineRecognizerConfig rcgCfg = | 123 | OnlineRecognizerConfig rcgCfg = |
| 117 | new OnlineRecognizerConfig( | 124 | new OnlineRecognizerConfig( |
| @@ -146,19 +153,27 @@ public class OnlineRecognizer { | @@ -146,19 +153,27 @@ public class OnlineRecognizer { | ||
| 146 | String decodingMethod, | 153 | String decodingMethod, |
| 147 | String lm_model, | 154 | String lm_model, |
| 148 | float lm_scale, | 155 | float lm_scale, |
| 149 | - int maxActivePaths) { | 156 | + int maxActivePaths, |
| 157 | + String modelType) { | ||
| 150 | this.sampleRate = sampleRate; | 158 | this.sampleRate = sampleRate; |
| 151 | EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); | 159 | EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F); |
| 152 | EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); | 160 | EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F); |
| 153 | EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); | 161 | EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength); |
| 154 | EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); | 162 | EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3); |
| 155 | OnlineTransducerModelConfig modelCfg = | 163 | OnlineTransducerModelConfig modelCfg = |
| 156 | - new OnlineTransducerModelConfig(encoder, decoder, joiner, tokens, numThreads, false); | 164 | + new OnlineTransducerModelConfig( |
| 165 | + encoder, decoder, joiner, tokens, numThreads, false, modelType); | ||
| 157 | FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); | 166 | FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim); |
| 158 | - OnlineLMConfig onlineLmConfig=new OnlineLMConfig(lm_model,lm_scale); | 167 | + OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale); |
| 159 | OnlineRecognizerConfig rcgCfg = | 168 | OnlineRecognizerConfig rcgCfg = |
| 160 | new OnlineRecognizerConfig( | 169 | new OnlineRecognizerConfig( |
| 161 | - featConfig, modelCfg, endCfg, onlineLmConfig,enableEndpointDetection, decodingMethod, maxActivePaths); | 170 | + featConfig, |
| 171 | + modelCfg, | ||
| 172 | + endCfg, | ||
| 173 | + onlineLmConfig, | ||
| 174 | + enableEndpointDetection, | ||
| 175 | + decodingMethod, | ||
| 176 | + maxActivePaths); | ||
| 162 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 | 177 | // create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9 |
| 163 | this.ptr = createOnlineRecognizer(new Object(), rcgCfg); | 178 | this.ptr = createOnlineRecognizer(new Object(), rcgCfg); |
| 164 | } | 179 | } |
| @@ -284,6 +299,7 @@ public class OnlineRecognizer { | @@ -284,6 +299,7 @@ public class OnlineRecognizer { | ||
| 284 | public void releaseStream(OnlineStream s) { | 299 | public void releaseStream(OnlineStream s) { |
| 285 | s.release(); | 300 | s.release(); |
| 286 | } | 301 | } |
| 302 | + | ||
| 287 | // JNI interface libsherpa-onnx-jni.so | 303 | // JNI interface libsherpa-onnx-jni.so |
| 288 | 304 | ||
| 289 | private static native Object[] readWave(String fileName); // static | 305 | private static native Object[] readWave(String fileName); // static |
| @@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig { | @@ -11,15 +11,24 @@ public class OnlineTransducerModelConfig { | ||
| 11 | private final String tokens; | 11 | private final String tokens; |
| 12 | private final int numThreads; | 12 | private final int numThreads; |
| 13 | private final boolean debug; | 13 | private final boolean debug; |
| 14 | + private final String provider = "cpu"; | ||
| 15 | + private String modelType = ""; | ||
| 14 | 16 | ||
| 15 | public OnlineTransducerModelConfig( | 17 | public OnlineTransducerModelConfig( |
| 16 | - String encoder, String decoder, String joiner, String tokens, int numThreads, boolean debug) { | 18 | + String encoder, |
| 19 | + String decoder, | ||
| 20 | + String joiner, | ||
| 21 | + String tokens, | ||
| 22 | + int numThreads, | ||
| 23 | + boolean debug, | ||
| 24 | + String modelType) { | ||
| 17 | this.encoder = encoder; | 25 | this.encoder = encoder; |
| 18 | this.decoder = decoder; | 26 | this.decoder = decoder; |
| 19 | this.joiner = joiner; | 27 | this.joiner = joiner; |
| 20 | this.tokens = tokens; | 28 | this.tokens = tokens; |
| 21 | this.numThreads = numThreads; | 29 | this.numThreads = numThreads; |
| 22 | this.debug = debug; | 30 | this.debug = debug; |
| 31 | + this.modelType = modelType; | ||
| 23 | } | 32 | } |
| 24 | 33 | ||
| 25 | public String getEncoder() { | 34 | public String getEncoder() { |
-
请 注册 或 登录 后发表评论