Committed by
GitHub
Fix model_type for jni, c# and iOS. (#216)
正在显示
7 个修改的文件
包含
38 行增加
和
8 行删除
| @@ -21,6 +21,8 @@ data class OnlineTransducerModelConfig( | @@ -21,6 +21,8 @@ data class OnlineTransducerModelConfig( | ||
| 21 | var tokens: String, | 21 | var tokens: String, |
| 22 | var numThreads: Int = 1, | 22 | var numThreads: Int = 1, |
| 23 | var debug: Boolean = false, | 23 | var debug: Boolean = false, |
| 24 | + var provider: String = "cpu", | ||
| 25 | + var modelType: String = "", | ||
| 24 | ) | 26 | ) |
| 25 | 27 | ||
| 26 | data class OnlineLMConfig( | 28 | data class OnlineLMConfig( |
| @@ -135,6 +137,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | @@ -135,6 +137,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | ||
| 135 | decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", | 137 | decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", |
| 136 | joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", | 138 | joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", |
| 137 | tokens = "$modelDir/tokens.txt", | 139 | tokens = "$modelDir/tokens.txt", |
| 140 | + modelType = "zipformer", | ||
| 138 | ) | 141 | ) |
| 139 | } | 142 | } |
| 140 | 1 -> { | 143 | 1 -> { |
| @@ -144,6 +147,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | @@ -144,6 +147,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | ||
| 144 | decoder = "$modelDir/decoder-epoch-11-avg-1.onnx", | 147 | decoder = "$modelDir/decoder-epoch-11-avg-1.onnx", |
| 145 | joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", | 148 | joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", |
| 146 | tokens = "$modelDir/tokens.txt", | 149 | tokens = "$modelDir/tokens.txt", |
| 150 | + modelType = "lstm", | ||
| 147 | ) | 151 | ) |
| 148 | } | 152 | } |
| 149 | 153 | ||
| @@ -154,6 +158,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | @@ -154,6 +158,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | ||
| 154 | decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", | 158 | decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", |
| 155 | joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", | 159 | joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", |
| 156 | tokens = "$modelDir/tokens.txt", | 160 | tokens = "$modelDir/tokens.txt", |
| 161 | + modelType = "lstm", | ||
| 157 | ) | 162 | ) |
| 158 | } | 163 | } |
| 159 | 164 | ||
| @@ -164,6 +169,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | @@ -164,6 +169,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | ||
| 164 | decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", | 169 | decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", |
| 165 | joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", | 170 | joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", |
| 166 | tokens = "$modelDir/data/lang_char/tokens.txt", | 171 | tokens = "$modelDir/data/lang_char/tokens.txt", |
| 172 | + modelType = "zipformer2", | ||
| 167 | ) | 173 | ) |
| 168 | } | 174 | } |
| 169 | 175 | ||
| @@ -174,6 +180,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | @@ -174,6 +180,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { | ||
| 174 | decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", | 180 | decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", |
| 175 | joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", | 181 | joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", |
| 176 | tokens = "$modelDir/data/lang_char/tokens.txt", | 182 | tokens = "$modelDir/data/lang_char/tokens.txt", |
| 183 | + modelType = "zipformer2", | ||
| 177 | ) | 184 | ) |
| 178 | } | 185 | } |
| 179 | } | 186 | } |
| @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode | @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode | ||
| 26 | decoder: decoder, | 26 | decoder: decoder, |
| 27 | joiner: joiner, | 27 | joiner: joiner, |
| 28 | tokens: tokens, | 28 | tokens: tokens, |
| 29 | - numThreads: 2 | 29 | + numThreads: 2, |
| 30 | + modelType: "zipformer" | ||
| 30 | ) | 31 | ) |
| 31 | } | 32 | } |
| 32 | 33 | ||
| @@ -41,7 +42,8 @@ func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig { | @@ -41,7 +42,8 @@ func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig { | ||
| 41 | decoder: decoder, | 42 | decoder: decoder, |
| 42 | joiner: joiner, | 43 | joiner: joiner, |
| 43 | tokens: tokens, | 44 | tokens: tokens, |
| 44 | - numThreads: 2 | 45 | + numThreads: 2, |
| 46 | + modelType: "zipformer2" | ||
| 45 | ) | 47 | ) |
| 46 | } | 48 | } |
| 47 | 49 | ||
| @@ -56,7 +58,8 @@ func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig { | @@ -56,7 +58,8 @@ func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig { | ||
| 56 | decoder: decoder, | 58 | decoder: decoder, |
| 57 | joiner: joiner, | 59 | joiner: joiner, |
| 58 | tokens: tokens, | 60 | tokens: tokens, |
| 59 | - numThreads: 2 | 61 | + numThreads: 2, |
| 62 | + modelType: "zipformer2" | ||
| 60 | ) | 63 | ) |
| 61 | } | 64 | } |
| 62 | 65 | ||
| @@ -71,7 +74,8 @@ func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig { | @@ -71,7 +74,8 @@ func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig { | ||
| 71 | decoder: decoder, | 74 | decoder: decoder, |
| 72 | joiner: joiner, | 75 | joiner: joiner, |
| 73 | tokens: tokens, | 76 | tokens: tokens, |
| 74 | - numThreads: 2 | 77 | + numThreads: 2, |
| 78 | + modelType: "zipformer2" | ||
| 75 | ) | 79 | ) |
| 76 | } | 80 | } |
| 77 | 81 |
| @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode | @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode | ||
| 26 | decoder: decoder, | 26 | decoder: decoder, |
| 27 | joiner: joiner, | 27 | joiner: joiner, |
| 28 | tokens: tokens, | 28 | tokens: tokens, |
| 29 | - numThreads: 2 | 29 | + numThreads: 2, |
| 30 | + modelType: "zipformer" | ||
| 30 | ) | 31 | ) |
| 31 | } | 32 | } |
| 32 | 33 |
| @@ -26,6 +26,7 @@ namespace SherpaOnnx | @@ -26,6 +26,7 @@ namespace SherpaOnnx | ||
| 26 | NumThreads = 1; | 26 | NumThreads = 1; |
| 27 | Provider = "cpu"; | 27 | Provider = "cpu"; |
| 28 | Debug = 0; | 28 | Debug = 0; |
| 29 | + ModelType = ""; | ||
| 29 | } | 30 | } |
| 30 | [MarshalAs(UnmanagedType.LPStr)] | 31 | [MarshalAs(UnmanagedType.LPStr)] |
| 31 | public string Encoder; | 32 | public string Encoder; |
| @@ -47,6 +48,9 @@ namespace SherpaOnnx | @@ -47,6 +48,9 @@ namespace SherpaOnnx | ||
| 47 | 48 | ||
| 48 | /// true to print debug information of the model | 49 | /// true to print debug information of the model |
| 49 | public int Debug; | 50 | public int Debug; |
| 51 | + | ||
| 52 | + [MarshalAs(UnmanagedType.LPStr)] | ||
| 53 | + public string ModelType; | ||
| 50 | } | 54 | } |
| 51 | 55 | ||
| 52 | /// It expects 16 kHz 16-bit single channel wave format. | 56 | /// It expects 16 kHz 16-bit single channel wave format. |
| @@ -53,8 +53,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig { | @@ -53,8 +53,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig { | ||
| 53 | const char *tokens; | 53 | const char *tokens; |
| 54 | int32_t num_threads; | 54 | int32_t num_threads; |
| 55 | const char *provider; | 55 | const char *provider; |
| 56 | - const char *model_type; | ||
| 57 | int32_t debug; // true to print debug information of the model | 56 | int32_t debug; // true to print debug information of the model |
| 57 | + const char *model_type; | ||
| 58 | } SherpaOnnxOnlineTransducerModelConfig; | 58 | } SherpaOnnxOnlineTransducerModelConfig; |
| 59 | 59 | ||
| 60 | /// It expects 16 kHz 16-bit single channel wave format. | 60 | /// It expects 16 kHz 16-bit single channel wave format. |
| @@ -187,6 +187,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | @@ -187,6 +187,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { | ||
| 187 | fid = env->GetFieldID(model_config_cls, "debug", "Z"); | 187 | fid = env->GetFieldID(model_config_cls, "debug", "Z"); |
| 188 | ans.model_config.debug = env->GetBooleanField(model_config, fid); | 188 | ans.model_config.debug = env->GetBooleanField(model_config, fid); |
| 189 | 189 | ||
| 190 | + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); | ||
| 191 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 192 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 193 | + ans.model_config.provider = p; | ||
| 194 | + env->ReleaseStringUTFChars(s, p); | ||
| 195 | + | ||
| 196 | + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); | ||
| 197 | + s = (jstring)env->GetObjectField(model_config, fid); | ||
| 198 | + p = env->GetStringUTFChars(s, nullptr); | ||
| 199 | + ans.model_config.model_type = p; | ||
| 200 | + env->ReleaseStringUTFChars(s, p); | ||
| 201 | + | ||
| 190 | //---------- rnn lm model config ---------- | 202 | //---------- rnn lm model config ---------- |
| 191 | fid = env->GetFieldID(cls, "lmConfig", | 203 | fid = env->GetFieldID(cls, "lmConfig", |
| 192 | "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); | 204 | "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); |
| @@ -36,7 +36,8 @@ func sherpaOnnxOnlineTransducerModelConfig( | @@ -36,7 +36,8 @@ func sherpaOnnxOnlineTransducerModelConfig( | ||
| 36 | tokens: String, | 36 | tokens: String, |
| 37 | numThreads: Int = 2, | 37 | numThreads: Int = 2, |
| 38 | provider: String = "cpu", | 38 | provider: String = "cpu", |
| 39 | - debug: Int = 0 | 39 | + debug: Int = 0, |
| 40 | + modelType: String = "" | ||
| 40 | ) -> SherpaOnnxOnlineTransducerModelConfig { | 41 | ) -> SherpaOnnxOnlineTransducerModelConfig { |
| 41 | return SherpaOnnxOnlineTransducerModelConfig( | 42 | return SherpaOnnxOnlineTransducerModelConfig( |
| 42 | encoder: toCPointer(encoder), | 43 | encoder: toCPointer(encoder), |
| @@ -45,7 +46,8 @@ func sherpaOnnxOnlineTransducerModelConfig( | @@ -45,7 +46,8 @@ func sherpaOnnxOnlineTransducerModelConfig( | ||
| 45 | tokens: toCPointer(tokens), | 46 | tokens: toCPointer(tokens), |
| 46 | num_threads: Int32(numThreads), | 47 | num_threads: Int32(numThreads), |
| 47 | provider: toCPointer(provider), | 48 | provider: toCPointer(provider), |
| 48 | - debug: Int32(debug) | 49 | + debug: Int32(debug), |
| 50 | + model_type: toCPointer(modelType) | ||
| 49 | ) | 51 | ) |
| 50 | } | 52 | } |
| 51 | 53 |
-
请 注册 或 登录 后发表评论