Fangjun Kuang
Committed by GitHub

Fix #608 (#610)

Fix java tests.
... ... @@ -9,10 +9,11 @@ LIB_FILES = \
$(LIB_SRC_DIR)/OnlineLMConfig.java \
$(LIB_SRC_DIR)/OnlineTransducerModelConfig.java \
$(LIB_SRC_DIR)/OnlineParaformerModelConfig.java \
$(LIB_SRC_DIR)/OnlineZipformer2CtcModelConfig.java \
$(LIB_SRC_DIR)/OnlineModelConfig.java \
$(LIB_SRC_DIR)/OnlineRecognizerConfig.java \
$(LIB_SRC_DIR)/OnlineStream.java \
$(LIB_SRC_DIR)/OnlineRecognizer.java \
$(LIB_SRC_DIR)/OnlineRecognizer.java
WEBSOCKET_DIR:= ./src/websocketsrv
WEBSOCKET_FILES = \
... ... @@ -42,10 +43,10 @@ vpath %.java src
buildfile:
$(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_FILE)
$(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_FILE)
buildmic:
$(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_Mic)
$(JAVAC) -cp lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 src/$(EXAMPLE_Mic)
rebuild: clean all
... ... @@ -63,8 +64,8 @@ clean:
mkdir -p $(BUILD_DIR)
mkdir -p ./lib
runfile:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav
runfile: packjar buildfile
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav
runhotwords:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav
... ... @@ -85,8 +86,7 @@ buildlib: $(LIB_FILES:.java=.class)
%.class: %.java
$(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $<
$(JAVAC) -cp $(BUILD_DIR) -d $(BUILD_DIR) -encoding UTF-8 $<
buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
... ... @@ -95,7 +95,7 @@ buildwebsocket: $(WEBSOCKET_FILES:.java=.class)
$(JAVAC) -cp $(BUILD_DIR):lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:lib/Java-WebSocket-1.5.3.jar:../lib/sherpaonnx.jar -d $(BUILD_DIR) -encoding UTF-8 $<
packjar:
jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) .
packjar: buildlib
jar cvfe lib/sherpaonnx.jar . -C $(BUILD_DIR) .
all: clean buildlib packjar buildfile buildmic downjar buildwebsocket
... ...
.idea
java-api.iml
... ...
... ... @@ -5,25 +5,25 @@
package com.k2fsa.sherpa.onnx;
public class EndpointConfig {
private final EndpointRule rule1;
private final EndpointRule rule2;
private final EndpointRule rule3;
private final EndpointRule rule1;
private final EndpointRule rule2;
private final EndpointRule rule3;
public EndpointConfig(EndpointRule rule1, EndpointRule rule2, EndpointRule rule3) {
this.rule1 = rule1;
this.rule2 = rule2;
this.rule3 = rule3;
}
public EndpointConfig(EndpointRule rule1, EndpointRule rule2, EndpointRule rule3) {
this.rule1 = rule1;
this.rule2 = rule2;
this.rule3 = rule3;
}
public EndpointRule getRule1() {
return rule1;
}
public EndpointRule getRule1() {
return rule1;
}
public EndpointRule getRule2() {
return rule2;
}
public EndpointRule getRule2() {
return rule2;
}
public EndpointRule getRule3() {
return rule3;
}
public EndpointRule getRule3() {
return rule3;
}
}
... ...
... ... @@ -5,26 +5,26 @@
package com.k2fsa.sherpa.onnx;
public class EndpointRule {
private final boolean mustContainNonSilence;
private final float minTrailingSilence;
private final float minUtteranceLength;
private final boolean mustContainNonSilence;
private final float minTrailingSilence;
private final float minUtteranceLength;
public EndpointRule(
boolean mustContainNonSilence, float minTrailingSilence, float minUtteranceLength) {
this.mustContainNonSilence = mustContainNonSilence;
this.minTrailingSilence = minTrailingSilence;
this.minUtteranceLength = minUtteranceLength;
}
public EndpointRule(
boolean mustContainNonSilence, float minTrailingSilence, float minUtteranceLength) {
this.mustContainNonSilence = mustContainNonSilence;
this.minTrailingSilence = minTrailingSilence;
this.minUtteranceLength = minUtteranceLength;
}
public float getMinTrailingSilence() {
return minTrailingSilence;
}
public float getMinTrailingSilence() {
return minTrailingSilence;
}
public float getMinUtteranceLength() {
return minUtteranceLength;
}
public float getMinUtteranceLength() {
return minUtteranceLength;
}
public boolean getMustContainNonSilence() {
return mustContainNonSilence;
}
public boolean getMustContainNonSilence() {
return mustContainNonSilence;
}
}
... ...
... ... @@ -5,19 +5,19 @@
package com.k2fsa.sherpa.onnx;
public class FeatureConfig {
private final int sampleRate;
private final int featureDim;
private final int sampleRate;
private final int featureDim;
public FeatureConfig(int sampleRate, int featureDim) {
this.sampleRate = sampleRate;
this.featureDim = featureDim;
}
public FeatureConfig(int sampleRate, int featureDim) {
this.sampleRate = sampleRate;
this.featureDim = featureDim;
}
public int getSampleRate() {
return sampleRate;
}
public int getSampleRate() {
return sampleRate;
}
public int getFeatureDim() {
return featureDim;
}
public int getFeatureDim() {
return featureDim;
}
}
... ...
... ... @@ -5,19 +5,19 @@
package com.k2fsa.sherpa.onnx;
public class OnlineLMConfig {
private final String model;
private final float scale;
private final String model;
private final float scale;
public OnlineLMConfig(String model, float scale) {
this.model = model;
this.scale = scale;
}
public OnlineLMConfig(String model, float scale) {
this.model = model;
this.scale = scale;
}
public String getModel() {
return model;
}
public String getModel() {
return model;
}
public float getScale() {
return scale;
}
public float getScale() {
return scale;
}
}
... ...
... ... @@ -5,47 +5,51 @@
package com.k2fsa.sherpa.onnx;
public class OnlineModelConfig {
private final OnlineParaformerModelConfig paraformer;
private final OnlineTransducerModelConfig transducer;
private final String tokens;
private final int numThreads;
private final boolean debug;
private final String provider = "cpu";
private String modelType = "";
public OnlineModelConfig(
String tokens,
int numThreads,
boolean debug,
String modelType,
OnlineParaformerModelConfig paraformer,
OnlineTransducerModelConfig transducer) {
this.tokens = tokens;
this.numThreads = numThreads;
this.debug = debug;
this.modelType = modelType;
this.paraformer = paraformer;
this.transducer = transducer;
}
public OnlineParaformerModelConfig getParaformer() {
return paraformer;
}
public OnlineTransducerModelConfig getTransducer() {
return transducer;
}
public String getTokens() {
return tokens;
}
public int getNumThreads() {
return numThreads;
}
public boolean getDebug() {
return debug;
}
private final OnlineParaformerModelConfig paraformer;
private final OnlineTransducerModelConfig transducer;
private final OnlineZipformer2CtcModelConfig zipformer2Ctc;
private final String tokens;
private final int numThreads;
private final boolean debug;
private final String provider = "cpu";
private String modelType = "";
public OnlineModelConfig(
String tokens,
int numThreads,
boolean debug,
String modelType,
OnlineParaformerModelConfig paraformer,
OnlineTransducerModelConfig transducer,
OnlineZipformer2CtcModelConfig zipformer2Ctc
) {
this.tokens = tokens;
this.numThreads = numThreads;
this.debug = debug;
this.modelType = modelType;
this.paraformer = paraformer;
this.transducer = transducer;
this.zipformer2Ctc = zipformer2Ctc;
}
public OnlineParaformerModelConfig getParaformer() {
return paraformer;
}
public OnlineTransducerModelConfig getTransducer() {
return transducer;
}
public String getTokens() {
return tokens;
}
public int getNumThreads() {
return numThreads;
}
public boolean getDebug() {
return debug;
}
}
... ...
... ... @@ -5,19 +5,19 @@
package com.k2fsa.sherpa.onnx;
public class OnlineParaformerModelConfig {
private final String encoder;
private final String decoder;
private final String encoder;
private final String decoder;
public OnlineParaformerModelConfig(String encoder, String decoder) {
this.encoder = encoder;
this.decoder = decoder;
}
public OnlineParaformerModelConfig(String encoder, String decoder) {
this.encoder = encoder;
this.decoder = decoder;
}
public String getEncoder() {
return encoder;
}
public String getEncoder() {
return encoder;
}
public String getDecoder() {
return decoder;
}
public String getDecoder() {
return decoder;
}
}
... ...
... ... @@ -32,336 +32,345 @@ usage example:
*/
package com.k2fsa.sherpa.onnx;
import java.io.*;
import java.util.*;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
public class OnlineRecognizer {
private long ptr = 0; // this is the asr engine ptrss
private int sampleRate = 16000;
// load config file for OnlineRecognizer
public OnlineRecognizer(String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig =
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
} catch (Exception e) {
System.err.println(e);
private long ptr = 0; // this is the asr engine ptrss
private int sampleRate = 16000;
// load config file for OnlineRecognizer
public OnlineRecognizer(String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg, zipformer2CtcConfig);
FeatureConfig featConfig =
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
} catch (Exception e) {
System.err.println(e);
}
}
}
// use for android asset_manager ANDROID_API__ >= 9
public OnlineRecognizer(Object assetManager, String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig =
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
} catch (Exception e) {
System.err.println(e);
// use for android asset_manager ANDROID_API__ >= 9
public OnlineRecognizer(Object assetManager, String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg, zipformer2CtcConfig);
FeatureConfig featConfig =
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
} catch (Exception e) {
System.err.println(e);
}
}
}
// set onlineRecognizer by parameter
public OnlineRecognizer(
String tokens,
String encoder,
String decoder,
String joiner,
int numThreads,
int sampleRate,
int featureDim,
boolean enableEndpointDetection,
float rule1MinTrailingSilence,
float rule2MinTrailingSilence,
float rule3MinUtteranceLength,
String decodingMethod,
String lm_model,
float lm_scale,
int maxActivePaths,
String hotwordsFile,
float hotwordsScore,
String modelType) {
this.sampleRate = sampleRate;
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder);
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(encoder, decoder, joiner);
OnlineModelConfig modelCfg =
new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg);
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
enableEndpointDetection,
decodingMethod,
maxActivePaths,
hotwordsFile,
hotwordsScore);
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
}
private Map<String, String> readProperties(String modelCfgPath) {
// read and parse config file
Properties props = new Properties();
Map<String, String> proMap = new HashMap<>();
try {
File file = new File(modelCfgPath);
if (!file.exists()) {
System.out.println("model cfg file not exists!");
System.exit(0);
}
InputStream in = new BufferedInputStream(new FileInputStream(modelCfgPath));
props.load(in);
Enumeration en = props.propertyNames();
while (en.hasMoreElements()) {
String key = (String) en.nextElement();
String Property = props.getProperty(key);
proMap.put(key, Property);
}
} catch (Exception e) {
e.printStackTrace();
// set onlineRecognizer by parameter
public OnlineRecognizer(
String tokens,
String encoder,
String decoder,
String joiner,
int numThreads,
int sampleRate,
int featureDim,
boolean enableEndpointDetection,
float rule1MinTrailingSilence,
float rule2MinTrailingSilence,
float rule3MinUtteranceLength,
String decodingMethod,
String lm_model,
float lm_scale,
int maxActivePaths,
String hotwordsFile,
float hotwordsScore,
String modelType) {
this.sampleRate = sampleRate;
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
EndpointRule rule2 = new EndpointRule(true, rule2MinTrailingSilence, 0.0F);
EndpointRule rule3 = new EndpointRule(false, 0.0F, rule3MinUtteranceLength);
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(encoder, decoder);
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(encoder, decoder, joiner);
OnlineZipformer2CtcModelConfig zipformer2CtcConfig = new OnlineZipformer2CtcModelConfig("");
OnlineModelConfig modelCfg =
new OnlineModelConfig(tokens, numThreads, false, modelType, modelParaCfg, modelTranCfg, zipformer2CtcConfig);
FeatureConfig featConfig = new FeatureConfig(sampleRate, featureDim);
OnlineLMConfig onlineLmConfig = new OnlineLMConfig(lm_model, lm_scale);
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
featConfig,
modelCfg,
endCfg,
onlineLmConfig,
enableEndpointDetection,
decodingMethod,
maxActivePaths,
hotwordsFile,
hotwordsScore);
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
}
return proMap;
}
public void decodeStream(OnlineStream s) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
// when feeded samples to engine, call DecodeStream to let it process
decodeStream(this.ptr, streamPtr);
}
public void decodeStreams(OnlineStream[] ssOjb) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
// decode for multiple streams
long[] ss = new long[ssOjb.length];
for (int i = 0; i < ssOjb.length; i++) {
ss[i] = ssOjb[i].getPtr();
if (ss[i] == 0) throw new Exception("null exception for stream ptr");
public static float[] readWavFile(String fileName) {
// read data from the filename
Object[] wavdata = readWave(fileName);
Object data = wavdata[0]; // data[0] is float data, data[1] sample rate
float[] floatData = (float[]) data;
return floatData;
}
decodeStreams(this.ptr, ss);
}
public boolean isReady(OnlineStream s) throws Exception {
// whether the engine is ready for decode
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return isReady(this.ptr, streamPtr);
}
// load the libsherpa-onnx-jni.so lib
public static void loadSoLib(String soPath) {
// load libsherpa-onnx-jni.so lib from the path
public String getResult(OnlineStream s) throws Exception {
// get text from the engine
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return getResult(this.ptr, streamPtr);
}
System.out.println("so lib path=" + soPath + "\n");
System.load(soPath.trim());
System.out.println("load so lib succeed\n");
}
public boolean isEndpoint(OnlineStream s) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return isEndpoint(this.ptr, streamPtr);
}
public static void setSoPath(String soPath) {
OnlineRecognizer.loadSoLib(soPath);
OnlineStream.loadSoLib(soPath);
}
public void reSet(OnlineStream s) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
reSet(this.ptr, streamPtr);
}
private static native Object[] readWave(String fileName); // static
private Map<String, String> readProperties(String modelCfgPath) {
// read and parse config file
Properties props = new Properties();
Map<String, String> proMap = new HashMap<>();
try {
File file = new File(modelCfgPath);
if (!file.exists()) {
System.out.println("model cfg file not exists!");
System.exit(0);
}
InputStream in = new BufferedInputStream(new FileInputStream(modelCfgPath));
props.load(in);
Enumeration en = props.propertyNames();
while (en.hasMoreElements()) {
String key = (String) en.nextElement();
String Property = props.getProperty(key);
proMap.put(key, Property);
}
public OnlineStream createStream() throws Exception {
// create one stream for data to feed in
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = createStream(this.ptr);
OnlineStream stream = new OnlineStream(streamPtr, this.sampleRate);
return stream;
}
} catch (Exception e) {
e.printStackTrace();
}
return proMap;
}
public static float[] readWavFile(String fileName) {
// read data from the filename
Object[] wavdata = readWave(fileName);
Object data = wavdata[0]; // data[0] is float data, data[1] sample rate
public void decodeStream(OnlineStream s) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
// when feeded samples to engine, call DecodeStream to let it process
decodeStream(this.ptr, streamPtr);
}
float[] floatData = (float[]) data;
public void decodeStreams(OnlineStream[] ssOjb) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
// decode for multiple streams
long[] ss = new long[ssOjb.length];
for (int i = 0; i < ssOjb.length; i++) {
ss[i] = ssOjb[i].getPtr();
if (ss[i] == 0) throw new Exception("null exception for stream ptr");
}
decodeStreams(this.ptr, ss);
}
return floatData;
}
public boolean isReady(OnlineStream s) throws Exception {
// whether the engine is ready for decode
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return isReady(this.ptr, streamPtr);
}
// load the libsherpa-onnx-jni.so lib
public static void loadSoLib(String soPath) {
// load libsherpa-onnx-jni.so lib from the path
public String getResult(OnlineStream s) throws Exception {
// get text from the engine
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return getResult(this.ptr, streamPtr);
}
System.out.println("so lib path=" + soPath + "\n");
System.load(soPath.trim());
System.out.println("load so lib succeed\n");
}
public boolean isEndpoint(OnlineStream s) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
return isEndpoint(this.ptr, streamPtr);
}
public static void setSoPath(String soPath) {
OnlineRecognizer.loadSoLib(soPath);
OnlineStream.loadSoLib(soPath);
}
public void reSet(OnlineStream s) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = s.getPtr();
if (streamPtr == 0) throw new Exception("null exception for stream ptr");
reSet(this.ptr, streamPtr);
}
protected void finalize() throws Throwable {
release();
}
public OnlineStream createStream() throws Exception {
// create one stream for data to feed in
if (this.ptr == 0) throw new Exception("null exception for recognizer ptr");
long streamPtr = createStream(this.ptr);
OnlineStream stream = new OnlineStream(streamPtr, this.sampleRate);
return stream;
}
// recognizer release, you'd better call it manually if not use anymore
public void release() {
if (this.ptr == 0) return;
deleteOnlineRecognizer(this.ptr);
this.ptr = 0;
}
protected void finalize() throws Throwable {
release();
}
// stream release, you'd better call it manually if not use anymore
public void releaseStream(OnlineStream s) {
s.release();
}
// recognizer release, you'd better call it manually if not use anymore
public void release() {
if (this.ptr == 0) return;
deleteOnlineRecognizer(this.ptr);
this.ptr = 0;
}
// JNI interface libsherpa-onnx-jni.so
// JNI interface libsherpa-onnx-jni.so
private static native Object[] readWave(String fileName); // static
// stream release, you'd better call it manually if not use anymore
public void releaseStream(OnlineStream s) {
s.release();
}
private native String getResult(long ptr, long streamPtr);
private native String getResult(long ptr, long streamPtr);
private native void decodeStream(long ptr, long streamPtr);
private native void decodeStream(long ptr, long streamPtr);
private native void decodeStreams(long ptr, long[] ssPtr);
private native void decodeStreams(long ptr, long[] ssPtr);
private native boolean isReady(long ptr, long streamPtr);
private native boolean isReady(long ptr, long streamPtr);
// first parameter keep for android asset_manager ANDROID_API__ >= 9
private native long createOnlineRecognizer(Object asset, OnlineRecognizerConfig config);
// first parameter keep for android asset_manager ANDROID_API__ >= 9
private native long createOnlineRecognizer(Object asset, OnlineRecognizerConfig config);
private native long createStream(long ptr);
private native long createStream(long ptr);
private native void deleteOnlineRecognizer(long ptr);
private native void deleteOnlineRecognizer(long ptr);
private native boolean isEndpoint(long ptr, long streamPtr);
private native boolean isEndpoint(long ptr, long streamPtr);
private native void reSet(long ptr, long streamPtr);
private native void reSet(long ptr, long streamPtr);
}
... ...
... ... @@ -5,62 +5,62 @@
package com.k2fsa.sherpa.onnx;
public class OnlineRecognizerConfig {
private final FeatureConfig featConfig;
private final OnlineModelConfig modelConfig;
private final EndpointConfig endpointConfig;
private final OnlineLMConfig lmConfig;
private final boolean enableEndpoint;
private final String decodingMethod;
private final int maxActivePaths;
private final String hotwordsFile;
private final float hotwordsScore;
private final FeatureConfig featConfig;
private final OnlineModelConfig modelConfig;
private final EndpointConfig endpointConfig;
private final OnlineLMConfig lmConfig;
private final boolean enableEndpoint;
private final String decodingMethod;
private final int maxActivePaths;
private final String hotwordsFile;
private final float hotwordsScore;
public OnlineRecognizerConfig(
FeatureConfig featConfig,
OnlineModelConfig modelConfig,
EndpointConfig endpointConfig,
OnlineLMConfig lmConfig,
boolean enableEndpoint,
String decodingMethod,
int maxActivePaths,
String hotwordsFile,
float hotwordsScore) {
this.featConfig = featConfig;
this.modelConfig = modelConfig;
this.endpointConfig = endpointConfig;
this.lmConfig = lmConfig;
this.enableEndpoint = enableEndpoint;
this.decodingMethod = decodingMethod;
this.maxActivePaths = maxActivePaths;
this.hotwordsFile = hotwordsFile;
this.hotwordsScore = hotwordsScore;
}
public OnlineRecognizerConfig(
FeatureConfig featConfig,
OnlineModelConfig modelConfig,
EndpointConfig endpointConfig,
OnlineLMConfig lmConfig,
boolean enableEndpoint,
String decodingMethod,
int maxActivePaths,
String hotwordsFile,
float hotwordsScore) {
this.featConfig = featConfig;
this.modelConfig = modelConfig;
this.endpointConfig = endpointConfig;
this.lmConfig = lmConfig;
this.enableEndpoint = enableEndpoint;
this.decodingMethod = decodingMethod;
this.maxActivePaths = maxActivePaths;
this.hotwordsFile = hotwordsFile;
this.hotwordsScore = hotwordsScore;
}
public OnlineLMConfig getLmConfig() {
return lmConfig;
}
public OnlineLMConfig getLmConfig() {
return lmConfig;
}
public FeatureConfig getFeatConfig() {
return featConfig;
}
public FeatureConfig getFeatConfig() {
return featConfig;
}
public OnlineModelConfig getModelConfig() {
return modelConfig;
}
public OnlineModelConfig getModelConfig() {
return modelConfig;
}
public EndpointConfig getEndpointConfig() {
return endpointConfig;
}
public EndpointConfig getEndpointConfig() {
return endpointConfig;
}
public boolean isEnableEndpoint() {
return enableEndpoint;
}
public boolean isEnableEndpoint() {
return enableEndpoint;
}
public String getDecodingMethod() {
return decodingMethod;
}
public String getDecodingMethod() {
return decodingMethod;
}
public int getMaxActivePaths() {
return maxActivePaths;
}
public int getMaxActivePaths() {
return maxActivePaths;
}
}
... ...
... ... @@ -4,83 +4,81 @@
// Stream is used for feeding data to the asr engine
package com.k2fsa.sherpa.onnx;
import java.io.*;
import java.util.*;
public class OnlineStream {
private long ptr = 0; // this is the stream ptr
private long ptr = 0; // this is the stream ptr
private int sampleRate = 16000;
private int sampleRate = 16000;
// assign ptr to this stream in construction
public OnlineStream(long ptr, int sampleRate) {
this.ptr = ptr;
this.sampleRate = sampleRate;
}
// assign ptr to this stream in construction
public OnlineStream(long ptr, int sampleRate) {
this.ptr = ptr;
this.sampleRate = sampleRate;
}
public long getPtr() {
return ptr;
}
public static void loadSoLib(String soPath) {
// load .so lib from the path
System.load(soPath.trim()); // ("sherpa-onnx-jni-java");
}
public void acceptWaveform(float[] samples) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
public long getPtr() {
return ptr;
}
// feed wave data to asr engine
acceptWaveform(this.ptr, this.sampleRate, samples);
}
public void acceptWaveform(float[] samples) throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
public void inputFinished() {
// add some tail padding
int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate
float tailPaddings[] = new float[padLen]; // default value is 0
acceptWaveform(this.ptr, this.sampleRate, tailPaddings);
// feed wave data to asr engine
acceptWaveform(this.ptr, this.sampleRate, samples);
}
// tell the engine all data are feeded
inputFinished(this.ptr);
}
public void inputFinished() {
// add some tail padding
int padLen = (int) (this.sampleRate * 0.3); // 0.3 seconds at 16 kHz sample rate
float[] tailPaddings = new float[padLen]; // default value is 0
acceptWaveform(this.ptr, this.sampleRate, tailPaddings);
public static void loadSoLib(String soPath) {
// load .so lib from the path
System.load(soPath.trim()); // ("sherpa-onnx-jni-java");
}
// tell the engine all data are feeded
inputFinished(this.ptr);
}
public void release() {
// stream object must be release after used
if (this.ptr == 0) return;
deleteStream(this.ptr);
this.ptr = 0;
}
public void release() {
// stream object must be release after used
if (this.ptr == 0) return;
deleteStream(this.ptr);
this.ptr = 0;
}
protected void finalize() throws Throwable {
release();
}
protected void finalize() throws Throwable {
release();
}
public boolean isLastFrame() throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
return isLastFrame(this.ptr);
}
public boolean isLastFrame() throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
return isLastFrame(this.ptr);
}
public void reSet() throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
reSet(this.ptr);
}
public void reSet() throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
reSet(this.ptr);
}
public int featureDim() throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
return featureDim(this.ptr);
}
public int featureDim() throws Exception {
if (this.ptr == 0) throw new Exception("null exception for stream ptr");
return featureDim(this.ptr);
}
// JNI interface libsherpa-onnx-jni.so
private native void acceptWaveform(long ptr, int sampleRate, float[] samples);
// JNI interface libsherpa-onnx-jni.so
private native void acceptWaveform(long ptr, int sampleRate, float[] samples);
private native void inputFinished(long ptr);
private native void inputFinished(long ptr);
private native void deleteStream(long ptr);
private native void deleteStream(long ptr);
private native int numFramesReady(long ptr);
private native int numFramesReady(long ptr);
private native boolean isLastFrame(long ptr);
private native boolean isLastFrame(long ptr);
private native void reSet(long ptr);
private native void reSet(long ptr);
private native int featureDim(long ptr);
private native int featureDim(long ptr);
}
... ...
... ... @@ -5,25 +5,25 @@
package com.k2fsa.sherpa.onnx;
public class OnlineTransducerModelConfig {
private final String encoder;
private final String decoder;
private final String joiner;
private final String encoder;
private final String decoder;
private final String joiner;
public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) {
this.encoder = encoder;
this.decoder = decoder;
this.joiner = joiner;
}
public OnlineTransducerModelConfig(String encoder, String decoder, String joiner) {
this.encoder = encoder;
this.decoder = decoder;
this.joiner = joiner;
}
public String getEncoder() {
return encoder;
}
public String getEncoder() {
return encoder;
}
public String getDecoder() {
return decoder;
}
public String getDecoder() {
return decoder;
}
public String getJoiner() {
return joiner;
}
public String getJoiner() {
return joiner;
}
}
... ...
package com.k2fsa.sherpa.onnx;
public class OnlineZipformer2CtcModelConfig {
private final String model;
public OnlineZipformer2CtcModelConfig(String model) {
this.model = model;
}
public String getModel() {
return model;
}
}
... ...
... ... @@ -1522,8 +1522,8 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxOffline_delete(
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
JNIEnv *env, jobject /*obj*/,
jlong ptr, jboolean recreate, jstring keywords) {
JNIEnv *env, jobject /*obj*/, jlong ptr, jboolean recreate,
jstring keywords) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
model->Reset(recreate, p_keywords);
... ...