Fangjun Kuang
Committed by GitHub

Fix model_type for jni, c# and iOS. (#216)

@@ -21,6 +21,8 @@ data class OnlineTransducerModelConfig( @@ -21,6 +21,8 @@ data class OnlineTransducerModelConfig(
21 var tokens: String, 21 var tokens: String,
22 var numThreads: Int = 1, 22 var numThreads: Int = 1,
23 var debug: Boolean = false, 23 var debug: Boolean = false,
  24 + var provider: String = "cpu",
  25 + var modelType: String = "",
24 ) 26 )
25 27
26 data class OnlineLMConfig( 28 data class OnlineLMConfig(
@@ -135,6 +137,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { @@ -135,6 +137,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
135 decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", 137 decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
136 joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", 138 joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
137 tokens = "$modelDir/tokens.txt", 139 tokens = "$modelDir/tokens.txt",
  140 + modelType = "zipformer",
138 ) 141 )
139 } 142 }
140 1 -> { 143 1 -> {
@@ -144,6 +147,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { @@ -144,6 +147,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
144 decoder = "$modelDir/decoder-epoch-11-avg-1.onnx", 147 decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
145 joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", 148 joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
146 tokens = "$modelDir/tokens.txt", 149 tokens = "$modelDir/tokens.txt",
  150 + modelType = "lstm",
147 ) 151 )
148 } 152 }
149 153
@@ -154,6 +158,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { @@ -154,6 +158,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
154 decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", 158 decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
155 joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", 159 joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
156 tokens = "$modelDir/tokens.txt", 160 tokens = "$modelDir/tokens.txt",
  161 + modelType = "lstm",
157 ) 162 )
158 } 163 }
159 164
@@ -164,6 +169,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { @@ -164,6 +169,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
164 decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", 169 decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
165 joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", 170 joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
166 tokens = "$modelDir/data/lang_char/tokens.txt", 171 tokens = "$modelDir/data/lang_char/tokens.txt",
  172 + modelType = "zipformer2",
167 ) 173 )
168 } 174 }
169 175
@@ -174,6 +180,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { @@ -174,6 +180,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
174 decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", 180 decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
175 joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", 181 joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
176 tokens = "$modelDir/data/lang_char/tokens.txt", 182 tokens = "$modelDir/data/lang_char/tokens.txt",
  183 + modelType = "zipformer2",
177 ) 184 )
178 } 185 }
179 } 186 }
@@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
26 decoder: decoder, 26 decoder: decoder,
27 joiner: joiner, 27 joiner: joiner,
28 tokens: tokens, 28 tokens: tokens,
29 - numThreads: 2 29 + numThreads: 2,
  30 + modelType: "zipformer"
30 ) 31 )
31 } 32 }
32 33
@@ -41,7 +42,8 @@ func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig { @@ -41,7 +42,8 @@ func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig {
41 decoder: decoder, 42 decoder: decoder,
42 joiner: joiner, 43 joiner: joiner,
43 tokens: tokens, 44 tokens: tokens,
44 - numThreads: 2 45 + numThreads: 2,
  46 + modelType: "zipformer2"
45 ) 47 )
46 } 48 }
47 49
@@ -56,7 +58,8 @@ func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig { @@ -56,7 +58,8 @@ func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig {
56 decoder: decoder, 58 decoder: decoder,
57 joiner: joiner, 59 joiner: joiner,
58 tokens: tokens, 60 tokens: tokens,
59 - numThreads: 2 61 + numThreads: 2,
  62 + modelType: "zipformer2"
60 ) 63 )
61 } 64 }
62 65
@@ -71,7 +74,8 @@ func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig { @@ -71,7 +74,8 @@ func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig {
71 decoder: decoder, 74 decoder: decoder,
72 joiner: joiner, 75 joiner: joiner,
73 tokens: tokens, 76 tokens: tokens,
74 - numThreads: 2 77 + numThreads: 2,
  78 + modelType: "zipformer2"
75 ) 79 )
76 } 80 }
77 81
@@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
26 decoder: decoder, 26 decoder: decoder,
27 joiner: joiner, 27 joiner: joiner,
28 tokens: tokens, 28 tokens: tokens,
29 - numThreads: 2 29 + numThreads: 2,
  30 + modelType: "zipformer"
30 ) 31 )
31 } 32 }
32 33
@@ -26,6 +26,7 @@ namespace SherpaOnnx @@ -26,6 +26,7 @@ namespace SherpaOnnx
26 NumThreads = 1; 26 NumThreads = 1;
27 Provider = "cpu"; 27 Provider = "cpu";
28 Debug = 0; 28 Debug = 0;
  29 + ModelType = "";
29 } 30 }
30 [MarshalAs(UnmanagedType.LPStr)] 31 [MarshalAs(UnmanagedType.LPStr)]
31 public string Encoder; 32 public string Encoder;
@@ -47,6 +48,9 @@ namespace SherpaOnnx @@ -47,6 +48,9 @@ namespace SherpaOnnx
47 48
48 /// true to print debug information of the model 49 /// true to print debug information of the model
49 public int Debug; 50 public int Debug;
  51 +
  52 + [MarshalAs(UnmanagedType.LPStr)]
  53 + public string ModelType;
50 } 54 }
51 55
52 /// It expects 16 kHz 16-bit single channel wave format. 56 /// It expects 16 kHz 16-bit single channel wave format.
@@ -53,8 +53,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig { @@ -53,8 +53,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig {
53 const char *tokens; 53 const char *tokens;
54 int32_t num_threads; 54 int32_t num_threads;
55 const char *provider; 55 const char *provider;
56 - const char *model_type;  
57 int32_t debug; // true to print debug information of the model 56 int32_t debug; // true to print debug information of the model
  57 + const char *model_type;
58 } SherpaOnnxOnlineTransducerModelConfig; 58 } SherpaOnnxOnlineTransducerModelConfig;
59 59
60 /// It expects 16 kHz 16-bit single channel wave format. 60 /// It expects 16 kHz 16-bit single channel wave format.
@@ -187,6 +187,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -187,6 +187,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
187 fid = env->GetFieldID(model_config_cls, "debug", "Z"); 187 fid = env->GetFieldID(model_config_cls, "debug", "Z");
188 ans.model_config.debug = env->GetBooleanField(model_config, fid); 188 ans.model_config.debug = env->GetBooleanField(model_config, fid);
189 189
  190 + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
  191 + s = (jstring)env->GetObjectField(model_config, fid);
  192 + p = env->GetStringUTFChars(s, nullptr);
  193 + ans.model_config.provider = p;
  194 + env->ReleaseStringUTFChars(s, p);
  195 +
  196 + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
  197 + s = (jstring)env->GetObjectField(model_config, fid);
  198 + p = env->GetStringUTFChars(s, nullptr);
  199 + ans.model_config.model_type = p;
  200 + env->ReleaseStringUTFChars(s, p);
  201 +
190 //---------- rnn lm model config ---------- 202 //---------- rnn lm model config ----------
191 fid = env->GetFieldID(cls, "lmConfig", 203 fid = env->GetFieldID(cls, "lmConfig",
192 "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); 204 "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
@@ -36,7 +36,8 @@ func sherpaOnnxOnlineTransducerModelConfig( @@ -36,7 +36,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
36 tokens: String, 36 tokens: String,
37 numThreads: Int = 2, 37 numThreads: Int = 2,
38 provider: String = "cpu", 38 provider: String = "cpu",
39 - debug: Int = 0 39 + debug: Int = 0,
  40 + modelType: String = ""
40 ) -> SherpaOnnxOnlineTransducerModelConfig { 41 ) -> SherpaOnnxOnlineTransducerModelConfig {
41 return SherpaOnnxOnlineTransducerModelConfig( 42 return SherpaOnnxOnlineTransducerModelConfig(
42 encoder: toCPointer(encoder), 43 encoder: toCPointer(encoder),
@@ -45,7 +46,8 @@ func sherpaOnnxOnlineTransducerModelConfig( @@ -45,7 +46,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
45 tokens: toCPointer(tokens), 46 tokens: toCPointer(tokens),
46 num_threads: Int32(numThreads), 47 num_threads: Int32(numThreads),
47 provider: toCPointer(provider), 48 provider: toCPointer(provider),
48 - debug: Int32(debug) 49 + debug: Int32(debug),
  50 + model_type: toCPointer(modelType)
49 ) 51 )
50 } 52 }
51 53