Fangjun Kuang
Committed by GitHub

Fix model_type for JNI, C# and iOS. (#216)

... ... @@ -21,6 +21,8 @@ data class OnlineTransducerModelConfig(
var tokens: String,
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
var modelType: String = "",
)
data class OnlineLMConfig(
... ... @@ -135,6 +137,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
1 -> {
... ... @@ -144,6 +147,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
tokens = "$modelDir/tokens.txt",
modelType = "lstm",
)
}
... ... @@ -154,6 +158,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
tokens = "$modelDir/tokens.txt",
modelType = "lstm",
)
}
... ... @@ -164,6 +169,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
)
}
... ... @@ -174,6 +180,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
)
}
}
... ...
... ... @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
decoder: decoder,
joiner: joiner,
tokens: tokens,
numThreads: 2
numThreads: 2,
modelType: "zipformer"
)
}
... ... @@ -41,7 +42,8 @@ func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig {
decoder: decoder,
joiner: joiner,
tokens: tokens,
numThreads: 2
numThreads: 2,
modelType: "zipformer2"
)
}
... ... @@ -56,7 +58,8 @@ func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig {
decoder: decoder,
joiner: joiner,
tokens: tokens,
numThreads: 2
numThreads: 2,
modelType: "zipformer2"
)
}
... ... @@ -71,7 +74,8 @@ func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig {
decoder: decoder,
joiner: joiner,
tokens: tokens,
numThreads: 2
numThreads: 2,
modelType: "zipformer2"
)
}
... ...
... ... @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
decoder: decoder,
joiner: joiner,
tokens: tokens,
numThreads: 2
numThreads: 2,
modelType: "zipformer"
)
}
... ...
... ... @@ -26,6 +26,7 @@ namespace SherpaOnnx
NumThreads = 1;
Provider = "cpu";
Debug = 0;
ModelType = "";
}
[MarshalAs(UnmanagedType.LPStr)]
public string Encoder;
... ... @@ -47,6 +48,9 @@ namespace SherpaOnnx
/// true to print debug information of the model
public int Debug;
[MarshalAs(UnmanagedType.LPStr)]
public string ModelType;
}
/// It expects 16 kHz 16-bit single channel wave format.
... ...
... ... @@ -53,8 +53,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig {
const char *tokens;
int32_t num_threads;
const char *provider;
const char *model_type;
int32_t debug; // true to print debug information of the model
const char *model_type;
} SherpaOnnxOnlineTransducerModelConfig;
/// It expects 16 kHz 16-bit single channel wave format.
... ...
... ... @@ -187,6 +187,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
//---------- rnn lm model config ----------
fid = env->GetFieldID(cls, "lmConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
... ...
... ... @@ -36,7 +36,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
tokens: String,
numThreads: Int = 2,
provider: String = "cpu",
debug: Int = 0
debug: Int = 0,
modelType: String = ""
) -> SherpaOnnxOnlineTransducerModelConfig {
return SherpaOnnxOnlineTransducerModelConfig(
encoder: toCPointer(encoder),
... ... @@ -45,7 +46,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
tokens: toCPointer(tokens),
num_threads: Int32(numThreads),
provider: toCPointer(provider),
debug: Int32(debug)
debug: Int32(debug),
model_type: toCPointer(modelType)
)
}
... ...