Fangjun Kuang
Committed by GitHub

Support paraformer on Android (#264)

@@ -177,7 +177,7 @@ class MainActivity : AppCompatActivity() { @@ -177,7 +177,7 @@ class MainActivity : AppCompatActivity() {
177 // Please change getModelConfig() to add new models 177 // Please change getModelConfig() to add new models
178 // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 178 // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
179 // for a list of available models 179 // for a list of available models
180 - val type = 3 180 + val type = 5
181 println("Select model type ${type}") 181 println("Select model type ${type}")
182 val config = OnlineRecognizerConfig( 182 val config = OnlineRecognizerConfig(
183 featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), 183 featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
@@ -185,8 +185,6 @@ class MainActivity : AppCompatActivity() { @@ -185,8 +185,6 @@ class MainActivity : AppCompatActivity() {
185 lmConfig = getOnlineLMConfig(type = type), 185 lmConfig = getOnlineLMConfig(type = type),
186 endpointConfig = getEndpointConfig(), 186 endpointConfig = getEndpointConfig(),
187 enableEndpoint = true, 187 enableEndpoint = true,
188 - decodingMethod = "modified_beam_search",  
189 - maxActivePaths = 4,  
190 ) 188 )
191 189
192 model = SherpaOnnx( 190 model = SherpaOnnx(
@@ -15,9 +15,19 @@ data class EndpointConfig( @@ -15,9 +15,19 @@ data class EndpointConfig(
15 ) 15 )
16 16
17 data class OnlineTransducerModelConfig( 17 data class OnlineTransducerModelConfig(
18 - var encoder: String,  
19 - var decoder: String,  
20 - var joiner: String, 18 + var encoder: String = "",
  19 + var decoder: String = "",
  20 + var joiner: String = "",
  21 +)
  22 +
  23 +data class OnlineParaformerModelConfig(
  24 + var encoder: String = "",
  25 + var decoder: String = "",
  26 +)
  27 +
  28 +data class OnlineModelConfig(
  29 + var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
  30 + var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
21 var tokens: String, 31 var tokens: String,
22 var numThreads: Int = 1, 32 var numThreads: Int = 1,
23 var debug: Boolean = false, 33 var debug: Boolean = false,
@@ -37,8 +47,8 @@ data class FeatureConfig( @@ -37,8 +47,8 @@ data class FeatureConfig(
37 47
38 data class OnlineRecognizerConfig( 48 data class OnlineRecognizerConfig(
39 var featConfig: FeatureConfig = FeatureConfig(), 49 var featConfig: FeatureConfig = FeatureConfig(),
40 - var modelConfig: OnlineTransducerModelConfig,  
41 - var lmConfig : OnlineLMConfig, 50 + var modelConfig: OnlineModelConfig,
  51 + var lmConfig: OnlineLMConfig,
42 var endpointConfig: EndpointConfig = EndpointConfig(), 52 var endpointConfig: EndpointConfig = EndpointConfig(),
43 var enableEndpoint: Boolean = true, 53 var enableEndpoint: Boolean = true,
44 var decodingMethod: String = "greedy_search", 54 var decodingMethod: String = "greedy_search",
@@ -115,37 +125,47 @@ to add your own. (It should be straightforward to add a new model @@ -115,37 +125,47 @@ to add your own. (It should be straightforward to add a new model
115 by following the code) 125 by following the code)
116 126
117 @param type 127 @param type
118 -0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) 128 +0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
119 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english 129 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
120 130
121 -1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese) 131 +1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
122 132
123 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese 133 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese
124 134
125 -2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English) 135 +2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
126 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english 136 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english
127 137
128 -3 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 138 +3,4 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
129 https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 139 https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
  140 + 3 - int8 encoder
  141 + 4 - float32 encoder
  142 +
  143 +5 - csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
  144 + https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
  145 +
130 */ 146 */
131 -fun getModelConfig(type: Int): OnlineTransducerModelConfig? { 147 +fun getModelConfig(type: Int): OnlineModelConfig? {
132 when (type) { 148 when (type) {
133 0 -> { 149 0 -> {
134 val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" 150 val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
135 - return OnlineTransducerModelConfig(  
136 - encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",  
137 - decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",  
138 - joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", 151 + return OnlineModelConfig(
  152 + transducer = OnlineTransducerModelConfig(
  153 + encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
  154 + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
  155 + joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
  156 + ),
139 tokens = "$modelDir/tokens.txt", 157 tokens = "$modelDir/tokens.txt",
140 modelType = "zipformer", 158 modelType = "zipformer",
141 ) 159 )
142 } 160 }
143 1 -> { 161 1 -> {
144 val modelDir = "sherpa-onnx-lstm-zh-2023-02-20" 162 val modelDir = "sherpa-onnx-lstm-zh-2023-02-20"
145 - return OnlineTransducerModelConfig(  
146 - encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",  
147 - decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",  
148 - joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", 163 + return OnlineModelConfig(
  164 + transducer = OnlineTransducerModelConfig(
  165 + encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",
  166 + decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
  167 + joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
  168 + ),
149 tokens = "$modelDir/tokens.txt", 169 tokens = "$modelDir/tokens.txt",
150 modelType = "lstm", 170 modelType = "lstm",
151 ) 171 )
@@ -153,10 +173,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { @@ -153,10 +173,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
153 173
154 2 -> { 174 2 -> {
155 val modelDir = "sherpa-onnx-lstm-en-2023-02-17" 175 val modelDir = "sherpa-onnx-lstm-en-2023-02-17"
156 - return OnlineTransducerModelConfig(  
157 - encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",  
158 - decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",  
159 - joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", 176 + return OnlineModelConfig(
  177 + transducer = OnlineTransducerModelConfig(
  178 + encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
  179 + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
  180 + joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
  181 + ),
160 tokens = "$modelDir/tokens.txt", 182 tokens = "$modelDir/tokens.txt",
161 modelType = "lstm", 183 modelType = "lstm",
162 ) 184 )
@@ -164,10 +186,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { @@ -164,10 +186,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
164 186
165 3 -> { 187 3 -> {
166 val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615" 188 val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
167 - return OnlineTransducerModelConfig(  
168 - encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",  
169 - decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",  
170 - joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", 189 + return OnlineModelConfig(
  190 + transducer = OnlineTransducerModelConfig(
  191 + encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
  192 + decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
  193 + joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
  194 + ),
171 tokens = "$modelDir/data/lang_char/tokens.txt", 195 tokens = "$modelDir/data/lang_char/tokens.txt",
172 modelType = "zipformer2", 196 modelType = "zipformer2",
173 ) 197 )
@@ -175,14 +199,28 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { @@ -175,14 +199,28 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
175 199
176 4 -> { 200 4 -> {
177 val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615" 201 val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
178 - return OnlineTransducerModelConfig(  
179 - encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",  
180 - decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",  
181 - joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", 202 + return OnlineModelConfig(
  203 + transducer = OnlineTransducerModelConfig(
  204 + encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
  205 + decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
  206 + joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
  207 + ),
182 tokens = "$modelDir/data/lang_char/tokens.txt", 208 tokens = "$modelDir/data/lang_char/tokens.txt",
183 modelType = "zipformer2", 209 modelType = "zipformer2",
184 ) 210 )
185 } 211 }
  212 +
  213 + 5 -> {
  214 + val modelDir = "sherpa-onnx-streaming-paraformer-bilingual-zh-en"
  215 + return OnlineModelConfig(
  216 + paraformer = OnlineParaformerModelConfig(
  217 + encoder = "$modelDir/encoder.int8.onnx",
  218 + decoder = "$modelDir/decoder.int8.onnx",
  219 + ),
  220 + tokens = "$modelDir/tokens.txt",
  221 + modelType = "paraformer",
  222 + )
  223 + }
186 } 224 }
187 return null; 225 return null;
188 } 226 }
@@ -200,7 +238,7 @@ by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn @@ -200,7 +238,7 @@ by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn
200 0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) 238 0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
201 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english 239 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
202 */ 240 */
203 -fun getOnlineLMConfig(type : Int): OnlineLMConfig { 241 +fun getOnlineLMConfig(type: Int): OnlineLMConfig {
204 when (type) { 242 when (type) {
205 0 -> { 243 0 -> {
206 val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" 244 val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
@@ -190,7 +190,11 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { @@ -190,7 +190,11 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
190 OnlineParaformerDecoderResult r; 190 OnlineParaformerDecoderResult r;
191 s->SetParaformerResult(r); 191 s->SetParaformerResult(r);
192 192
193 - // the internal model caches are not reset 193 + s->GetStates().clear();
  194 + s->GetParaformerEncoderOutCache().clear();
  195 + s->GetParaformerAlphaCache().clear();
  196 +
  197 + // s->GetParaformerFeatCache().clear();
194 198
195 // Note: We only update counters. The underlying audio samples 199 // Note: We only update counters. The underlying audio samples
196 // are not discarded. 200 // are not discarded.
@@ -47,7 +47,7 @@ class SherpaOnnx { @@ -47,7 +47,7 @@ class SherpaOnnx {
47 } 47 }
48 48
49 void InputFinished() const { 49 void InputFinished() const {
50 - std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0); 50 + std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
51 stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(), 51 stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
52 tail_padding.size()); 52 tail_padding.size());
53 stream_->InputFinished(); 53 stream_->InputFinished();
@@ -158,48 +158,74 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -158,48 +158,74 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
158 158
159 //---------- model config ---------- 159 //---------- model config ----------
160 fid = env->GetFieldID(cls, "modelConfig", 160 fid = env->GetFieldID(cls, "modelConfig",
  161 + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
  162 + jobject model_config = env->GetObjectField(config, fid);
  163 + jclass model_config_cls = env->GetObjectClass(model_config);
  164 +
  165 + // transducer
  166 + fid = env->GetFieldID(model_config_cls, "transducer",
161 "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); 167 "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
162 - jobject transducer_config = env->GetObjectField(config, fid);  
163 - jclass model_config_cls = env->GetObjectClass(transducer_config); 168 + jobject transducer_config = env->GetObjectField(model_config, fid);
  169 + jclass transducer_config_cls = env->GetObjectClass(transducer_config);
164 170
165 - fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;"); 171 + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
166 s = (jstring)env->GetObjectField(transducer_config, fid); 172 s = (jstring)env->GetObjectField(transducer_config, fid);
167 p = env->GetStringUTFChars(s, nullptr); 173 p = env->GetStringUTFChars(s, nullptr);
168 ans.model_config.transducer.encoder = p; 174 ans.model_config.transducer.encoder = p;
169 env->ReleaseStringUTFChars(s, p); 175 env->ReleaseStringUTFChars(s, p);
170 176
171 - fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;"); 177 + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
172 s = (jstring)env->GetObjectField(transducer_config, fid); 178 s = (jstring)env->GetObjectField(transducer_config, fid);
173 p = env->GetStringUTFChars(s, nullptr); 179 p = env->GetStringUTFChars(s, nullptr);
174 ans.model_config.transducer.decoder = p; 180 ans.model_config.transducer.decoder = p;
175 env->ReleaseStringUTFChars(s, p); 181 env->ReleaseStringUTFChars(s, p);
176 182
177 - fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;"); 183 + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
178 s = (jstring)env->GetObjectField(transducer_config, fid); 184 s = (jstring)env->GetObjectField(transducer_config, fid);
179 p = env->GetStringUTFChars(s, nullptr); 185 p = env->GetStringUTFChars(s, nullptr);
180 ans.model_config.transducer.joiner = p; 186 ans.model_config.transducer.joiner = p;
181 env->ReleaseStringUTFChars(s, p); 187 env->ReleaseStringUTFChars(s, p);
182 188
  189 + // paraformer
  190 + fid = env->GetFieldID(model_config_cls, "paraformer",
  191 + "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
  192 + jobject paraformer_config = env->GetObjectField(model_config, fid);
  193 + jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config);
  194 +
  195 + fid = env->GetFieldID(paraformer_config_config_cls, "encoder",
  196 + "Ljava/lang/String;");
  197 + s = (jstring)env->GetObjectField(paraformer_config, fid);
  198 + p = env->GetStringUTFChars(s, nullptr);
  199 + ans.model_config.paraformer.encoder = p;
  200 + env->ReleaseStringUTFChars(s, p);
  201 +
  202 + fid = env->GetFieldID(paraformer_config_config_cls, "decoder",
  203 + "Ljava/lang/String;");
  204 + s = (jstring)env->GetObjectField(paraformer_config, fid);
  205 + p = env->GetStringUTFChars(s, nullptr);
  206 + ans.model_config.paraformer.decoder = p;
  207 + env->ReleaseStringUTFChars(s, p);
  208 +
183 fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); 209 fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
184 - s = (jstring)env->GetObjectField(transducer_config, fid); 210 + s = (jstring)env->GetObjectField(model_config, fid);
185 p = env->GetStringUTFChars(s, nullptr); 211 p = env->GetStringUTFChars(s, nullptr);
186 ans.model_config.tokens = p; 212 ans.model_config.tokens = p;
187 env->ReleaseStringUTFChars(s, p); 213 env->ReleaseStringUTFChars(s, p);
188 214
189 fid = env->GetFieldID(model_config_cls, "numThreads", "I"); 215 fid = env->GetFieldID(model_config_cls, "numThreads", "I");
190 - ans.model_config.num_threads = env->GetIntField(transducer_config, fid); 216 + ans.model_config.num_threads = env->GetIntField(model_config, fid);
191 217
192 fid = env->GetFieldID(model_config_cls, "debug", "Z"); 218 fid = env->GetFieldID(model_config_cls, "debug", "Z");
193 - ans.model_config.debug = env->GetBooleanField(transducer_config, fid); 219 + ans.model_config.debug = env->GetBooleanField(model_config, fid);
194 220
195 fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); 221 fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
196 - s = (jstring)env->GetObjectField(transducer_config, fid); 222 + s = (jstring)env->GetObjectField(model_config, fid);
197 p = env->GetStringUTFChars(s, nullptr); 223 p = env->GetStringUTFChars(s, nullptr);
198 ans.model_config.provider = p; 224 ans.model_config.provider = p;
199 env->ReleaseStringUTFChars(s, p); 225 env->ReleaseStringUTFChars(s, p);
200 226
201 fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); 227 fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
202 - s = (jstring)env->GetObjectField(transducer_config, fid); 228 + s = (jstring)env->GetObjectField(model_config, fid);
203 p = env->GetStringUTFChars(s, nullptr); 229 p = env->GetStringUTFChars(s, nullptr);
204 ans.model_config.model_type = p; 230 ans.model_config.model_type = p;
205 env->ReleaseStringUTFChars(s, p); 231 env->ReleaseStringUTFChars(s, p);