Fangjun Kuang
Committed by GitHub

Support paraformer on Android (#264)

... ... @@ -177,7 +177,7 @@ class MainActivity : AppCompatActivity() {
// Please change getModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val type = 3
val type = 5
println("Select model type ${type}")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
... ... @@ -185,8 +185,6 @@ class MainActivity : AppCompatActivity() {
lmConfig = getOnlineLMConfig(type = type),
endpointConfig = getEndpointConfig(),
enableEndpoint = true,
decodingMethod = "modified_beam_search",
maxActivePaths = 4,
)
model = SherpaOnnx(
... ...
... ... @@ -15,9 +15,19 @@ data class EndpointConfig(
)
data class OnlineTransducerModelConfig(
var encoder: String,
var decoder: String,
var joiner: String,
var encoder: String = "",
var decoder: String = "",
var joiner: String = "",
)
data class OnlineParaformerModelConfig(
var encoder: String = "",
var decoder: String = "",
)
data class OnlineModelConfig(
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
var tokens: String,
var numThreads: Int = 1,
var debug: Boolean = false,
... ... @@ -37,8 +47,8 @@ data class FeatureConfig(
data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig,
var lmConfig : OnlineLMConfig,
var modelConfig: OnlineModelConfig,
var lmConfig: OnlineLMConfig,
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
... ... @@ -115,37 +125,47 @@ to add your own. (It should be straightforward to add a new model
by following the code)
@param type
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english
3 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
3,4 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
3 - int8 encoder
4 - float32 encoder
5 - csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
*/
fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
fun getModelConfig(type: Int): OnlineModelConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
1 -> {
val modelDir = "sherpa-onnx-lstm-zh-2023-02-20"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "lstm",
)
... ... @@ -153,10 +173,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
2 -> {
val modelDir = "sherpa-onnx-lstm-en-2023-02-17"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "lstm",
)
... ... @@ -164,10 +186,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
3 -> {
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
),
tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
)
... ... @@ -175,14 +199,28 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
4 -> {
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
),
tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
)
}
5 -> {
val modelDir = "sherpa-onnx-streaming-paraformer-bilingual-zh-en"
return OnlineModelConfig(
paraformer = OnlineParaformerModelConfig(
encoder = "$modelDir/encoder.int8.onnx",
decoder = "$modelDir/decoder.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "paraformer",
)
}
}
return null;
}
... ... @@ -200,7 +238,7 @@ by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
*/
fun getOnlineLMConfig(type : Int): OnlineLMConfig {
fun getOnlineLMConfig(type: Int): OnlineLMConfig {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
... ...
... ... @@ -190,7 +190,11 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
OnlineParaformerDecoderResult r;
s->SetParaformerResult(r);
// the internal model caches are not reset
s->GetStates().clear();
s->GetParaformerEncoderOutCache().clear();
s->GetParaformerAlphaCache().clear();
// s->GetParaformerFeatCache().clear();
// Note: We only update counters. The underlying audio samples
// are not discarded.
... ...
... ... @@ -47,7 +47,7 @@ class SherpaOnnx {
}
void InputFinished() const {
std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0);
std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
tail_padding.size());
stream_->InputFinished();
... ... @@ -158,48 +158,74 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
//---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
// transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(transducer_config);
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;");
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;");
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p);
// paraformer
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_config_cls, "encoder",
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(paraformer_config_config_cls, "decoder",
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(transducer_config, fid);
ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(transducer_config, fid);
ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid);
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
... ...