Wei Kang
Committed by GitHub

Add Java API for hotwords (#319)

* Add Java API

* Support WebSocket

* Fix Kotlin
... ... @@ -53,6 +53,8 @@ data class OnlineRecognizerConfig(
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
var maxActivePaths: Int = 4,
var hotwordsFile: String = "",
var hotwordsScore: Float = 1.5f,
)
class SherpaOnnx(
... ...
ENTRY_POINT = ./
LIB_SRC_DIR := ../sherpa-onnx/java-api/src/com/k2fsa/sherpa/onnx
... ... @@ -65,18 +64,22 @@ clean:
mkdir -p ./lib
runfile:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile test.wav
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile
runhotwords:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeFile hotwords.wav
runmic:
java -cp ./lib/sherpaonnx.jar:build $(RUNJFLAGS) DecodeMic
runsrv:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer ../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketServer $(shell pwd)/../build/lib/libsherpa-onnx-jni.so ./modeltest.cfg
runclient:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient ../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./test.wav 32
runclienthotwords:
java -cp $(BUILD_DIR):lib/Java-WebSocket-1.5.3.jar:lib/slf4j-simple-1.7.25.jar:lib/slf4j-api-1.7.25.jar:../lib/sherpaonnx.jar $(RUNJFLAGS) websocketsrv.AsrWebsocketClient $(shell pwd)/../build/lib/libsherpa-onnx-jni.so 127.0.0.1 8890 ./hotwords.wav 32
buildlib: $(LIB_FILES:.java=.class)
... ...
... ... @@ -12,6 +12,8 @@ num_threads=4
enable_endpoint_detection=true
decoding_method=modified_beam_search
max_active_paths=4
hotwords_file=
hotwords_score=1.5
lm_model=
lm_scale=0.5
model_type=zipformer
... ...
... ... @@ -36,6 +36,8 @@ if [ ! -d $repo ];then
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd
ln -s $repo/test_wavs/0.wav hotwords.wav
fi
log $(pwd)
... ... @@ -64,3 +66,9 @@ cd ../java-api-examples
make all
make runfile
echo "礼 拜 二" > hotwords.txt
sed -i 's/hotwords_file=/hotwords_file=hotwords.txt/g' modeltest.cfg
make runhotwords
... ...
... ... @@ -49,6 +49,8 @@ public class DecodeFile {
float rule3MinUtteranceLength = 20F;
String decodingMethod = "greedy_search";
int maxActivePaths = 4;
String hotwordsFile = "";
float hotwordsScore = 1.5F;
String lm_model = "";
float lm_scale = 0.5F;
String modelType = "zipformer";
... ... @@ -69,6 +71,8 @@ public class DecodeFile {
lm_model,
lm_scale,
maxActivePaths,
hotwordsFile,
hotwordsScore,
modelType);
streamObj = rcgOjb.createStream();
} catch (Exception e) {
... ... @@ -158,7 +162,7 @@ public class DecodeFile {
try {
String appDir = System.getProperty("user.dir");
System.out.println("appdir=" + appDir);
String fileName = appDir + "/test.wav";
String fileName = appDir + "/" + args[0];
String cfgPath = appDir + "/modeltest.cfg";
String soPath = appDir + "/../build/lib/libsherpa-onnx-jni.so";
OnlineRecognizer.setSoPath(soPath);
... ...
... ... @@ -140,8 +140,6 @@ public class AsrWebsocketServer extends WebSocketServer {
}
}
/**
 * Reports whether the given connection is currently present in the stream queue.
 *
 * @param conn the WebSocket connection to look up
 * @return {@code true} if {@code streamQueue} contains {@code conn}, {@code false} otherwise
 */
public boolean streamQueueFind(WebSocket conn) {
  // Delegate the membership check to the queue itself.
  final boolean queued = streamQueue.contains(conn);
  return queued;
}
... ... @@ -151,16 +149,16 @@ public class AsrWebsocketServer extends WebSocketServer {
rcgOjb = new OnlineRecognizer(cfgPath);
// size of stream thread pool
int streamThreadNum = Integer.valueOf(cfgMap.get("stream_thread_num"));
int streamThreadNum = Integer.valueOf(cfgMap.getOrDefault("stream_thread_num", "16"));
// size of decoder thread pool
int decoderThreadNum = Integer.valueOf(cfgMap.get("decoder_thread_num"));
int decoderThreadNum = Integer.valueOf(cfgMap.getOrDefault("decoder_thread_num", "16"));
// time(ms) idle for decoder thread when no job
int decoderTimeIdle = Integer.valueOf(cfgMap.get("decoder_time_idle"));
int decoderTimeIdle = Integer.valueOf(cfgMap.getOrDefault("decoder_time_idle", "200"));
// size of streams for parallel decoding
int parallelDecoderNum = Integer.valueOf(cfgMap.get("parallel_decoder_num"));
int parallelDecoderNum = Integer.valueOf(cfgMap.getOrDefault("parallel_decoder_num", "16"));
// time(ms) out for connection data
int deocderTimeOut = Integer.valueOf(cfgMap.get("deocder_time_out"));
int deocderTimeOut = Integer.valueOf(cfgMap.getOrDefault("deocder_time_out", "30000"));
// create stream threads
for (int i = 0; i < streamThreadNum; i++) {
... ... @@ -218,13 +216,13 @@ public class AsrWebsocketServer extends WebSocketServer {
String soPath = args[0];
String cfgPath = args[1];
OnlineRecognizer.setSoPath(soPath);
logger.info("readProperties");
Map<String, String> cfgMap = AsrWebsocketServer.readProperties(cfgPath);
int port = Integer.valueOf(cfgMap.get("port"));
int port = Integer.valueOf(cfgMap.getOrDefault("port", "8890"));
int connectionThreadNum = Integer.valueOf(cfgMap.get("connection_thread_num"));
int connectionThreadNum = Integer.valueOf(cfgMap.getOrDefault("connection_thread_num", "16"));
AsrWebsocketServer s = new AsrWebsocketServer(port, connectionThreadNum);
logger.info("initModelWithCfg");
s.initModelWithCfg(cfgMap, cfgPath);
... ...
... ... @@ -44,38 +44,48 @@ public class OnlineRecognizer {
public OnlineRecognizer(String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg = new OnlineParaformerModelConfig(proMap.get("encoder").trim(), proMap.get("decoder").trim());
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.get("encoder").trim(),
proMap.get("decoder").trim(),
proMap.get("joiner").trim());
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()),
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.get("model_type").trim(),
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
... ... @@ -83,9 +93,11 @@ public class OnlineRecognizer {
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
proMap.get("decoding_method").trim(),
Integer.parseInt(proMap.get("max_active_paths").trim()));
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
... ... @@ -98,41 +110,49 @@ public class OnlineRecognizer {
public OnlineRecognizer(Object assetManager, String modelCfgPath) {
Map<String, String> proMap = this.readProperties(modelCfgPath);
try {
int sampleRate = Integer.parseInt(proMap.get("sample_rate").trim());
int sampleRate = Integer.parseInt(proMap.getOrDefault("sample_rate", "16000").trim());
this.sampleRate = sampleRate;
EndpointRule rule1 =
new EndpointRule(
false, Float.parseFloat(proMap.get("rule1_min_trailing_silence").trim()), 0.0F);
false,
Float.parseFloat(proMap.getOrDefault("rule1_min_trailing_silence", "2.4").trim()),
0.0F);
EndpointRule rule2 =
new EndpointRule(
true, Float.parseFloat(proMap.get("rule2_min_trailing_silence").trim()), 0.0F);
true,
Float.parseFloat(proMap.getOrDefault("rule2_min_trailing_silence", "1.2").trim()),
0.0F);
EndpointRule rule3 =
new EndpointRule(
false, 0.0F, Float.parseFloat(proMap.get("rule3_min_utterance_length").trim()));
false,
0.0F,
Float.parseFloat(proMap.getOrDefault("rule3_min_utterance_length", "20").trim()));
EndpointConfig endCfg = new EndpointConfig(rule1, rule2, rule3);
OnlineParaformerModelConfig modelParaCfg =
new OnlineParaformerModelConfig(
proMap.get("encoder").trim(), proMap.get("decoder").trim());
proMap.getOrDefault("encoder", "").trim(), proMap.getOrDefault("decoder", "").trim());
OnlineTransducerModelConfig modelTranCfg =
new OnlineTransducerModelConfig(
proMap.get("encoder").trim(),
proMap.get("decoder").trim(),
proMap.get("joiner").trim());
proMap.getOrDefault("encoder", "").trim(),
proMap.getOrDefault("decoder", "").trim(),
proMap.getOrDefault("joiner", "").trim());
OnlineModelConfig modelCfg =
new OnlineModelConfig(
proMap.get("tokens").trim(),
Integer.parseInt(proMap.get("num_threads").trim()),
proMap.getOrDefault("tokens", "").trim(),
Integer.parseInt(proMap.getOrDefault("num_threads", "4").trim()),
false,
proMap.get("model_type").trim(),
proMap.getOrDefault("model_type", "zipformer").trim(),
modelParaCfg,
modelTranCfg);
FeatureConfig featConfig =
new FeatureConfig(sampleRate, Integer.parseInt(proMap.get("feature_dim").trim()));
new FeatureConfig(
sampleRate, Integer.parseInt(proMap.getOrDefault("feature_dim", "80").trim()));
OnlineLMConfig onlineLmConfig =
new OnlineLMConfig(
proMap.get("lm_model").trim(), Float.parseFloat(proMap.get("lm_scale").trim()));
proMap.getOrDefault("lm_model", "").trim(),
Float.parseFloat(proMap.getOrDefault("lm_scale", "0.5").trim()));
OnlineRecognizerConfig rcgCfg =
new OnlineRecognizerConfig(
... ... @@ -140,9 +160,11 @@ public class OnlineRecognizer {
modelCfg,
endCfg,
onlineLmConfig,
Boolean.parseBoolean(proMap.get("enable_endpoint_detection").trim()),
proMap.get("decoding_method").trim(),
Integer.parseInt(proMap.get("max_active_paths").trim()));
Boolean.parseBoolean(proMap.getOrDefault("enable_endpoint_detection", "true").trim()),
proMap.getOrDefault("decoding_method", "modified_beam_search").trim(),
Integer.parseInt(proMap.getOrDefault("max_active_paths", "4").trim()),
proMap.getOrDefault("hotwords_file", "").trim(),
Float.parseFloat(proMap.getOrDefault("hotwords_score", "1.5").trim()));
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(assetManager, rcgCfg);
... ... @@ -168,6 +190,8 @@ public class OnlineRecognizer {
String lm_model,
float lm_scale,
int maxActivePaths,
String hotwordsFile,
float hotwordsScore,
String modelType) {
this.sampleRate = sampleRate;
EndpointRule rule1 = new EndpointRule(false, rule1MinTrailingSilence, 0.0F);
... ... @@ -189,7 +213,9 @@ public class OnlineRecognizer {
onlineLmConfig,
enableEndpointDetection,
decodingMethod,
maxActivePaths);
maxActivePaths,
hotwordsFile,
hotwordsScore);
// create a new Recognizer, first parameter kept for android asset_manager ANDROID_API__ >= 9
this.ptr = createOnlineRecognizer(new Object(), rcgCfg);
}
... ... @@ -211,7 +237,6 @@ public class OnlineRecognizer {
String key = (String) en.nextElement();
String Property = props.getProperty(key);
proMap.put(key, Property);
// System.out.println(key+"="+Property);
}
} catch (Exception e) {
... ...
... ... @@ -12,6 +12,8 @@ public class OnlineRecognizerConfig {
private final boolean enableEndpoint;
private final String decodingMethod;
private final int maxActivePaths;
private final String hotwordsFile;
private final float hotwordsScore;
public OnlineRecognizerConfig(
FeatureConfig featConfig,
... ... @@ -20,7 +22,9 @@ public class OnlineRecognizerConfig {
OnlineLMConfig lmConfig,
boolean enableEndpoint,
String decodingMethod,
int maxActivePaths) {
int maxActivePaths,
String hotwordsFile,
float hotwordsScore) {
this.featConfig = featConfig;
this.modelConfig = modelConfig;
this.endpointConfig = endpointConfig;
... ... @@ -28,6 +32,8 @@ public class OnlineRecognizerConfig {
this.enableEndpoint = enableEndpoint;
this.decodingMethod = decodingMethod;
this.maxActivePaths = maxActivePaths;
this.hotwordsFile = hotwordsFile;
this.hotwordsScore = hotwordsScore;
}
public OnlineLMConfig getLmConfig() {
... ...
... ... @@ -125,6 +125,15 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.hotwords_file = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
... ... @@ -293,6 +302,15 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.hotwords_file = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
... ...