Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2023-08-14 12:26:15 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2023-08-14 12:26:15 +0800
Commit
35526e26e1ff74602d99bd7426397421e1300bfa
35526e26
1 parent
6038e2aa
Support paraformer on Android (#264)
显示空白字符变更
内嵌
并排对比
正在显示
4 个修改的文件
包含
97 行增加
和
31 行删除
android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt
sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
sherpa-onnx/jni/jni.cc
android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
查看文件 @
35526e2
...
...
@@ -177,7 +177,7 @@ class MainActivity : AppCompatActivity() {
// Please change getModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models
val type =
3
val type =
5
println("Select model type ${type}")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
...
...
@@ -185,8 +185,6 @@ class MainActivity : AppCompatActivity() {
lmConfig = getOnlineLMConfig(type = type),
endpointConfig = getEndpointConfig(),
enableEndpoint = true,
decodingMethod = "modified_beam_search",
maxActivePaths = 4,
)
model = SherpaOnnx(
...
...
android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt
查看文件 @
35526e2
...
...
@@ -15,9 +15,19 @@ data class EndpointConfig(
)
data class OnlineTransducerModelConfig(
var encoder: String,
var decoder: String,
var joiner: String,
var encoder: String = "",
var decoder: String = "",
var joiner: String = "",
)
data class OnlineParaformerModelConfig(
var encoder: String = "",
var decoder: String = "",
)
data class OnlineModelConfig(
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
var tokens: String,
var numThreads: Int = 1,
var debug: Boolean = false,
...
...
@@ -37,8 +47,8 @@ data class FeatureConfig(
data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig,
var lmConfig : OnlineLMConfig,
var modelConfig: OnlineModelConfig,
var lmConfig: OnlineLMConfig,
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
...
...
@@ -115,37 +125,47 @@ to add your own. (It should be straightforward to add a new model
by following the code)
@param type
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english
3 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
3
,4
- pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
3 - int8 encoder
4 - float32 encoder
5 - csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
*/
fun getModelConfig(type: Int): Online
Transducer
ModelConfig? {
fun getModelConfig(type: Int): OnlineModelConfig? {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
)
}
1 -> {
val modelDir = "sherpa-onnx-lstm-zh-2023-02-20"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "lstm",
)
...
...
@@ -153,10 +173,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
2 -> {
val modelDir = "sherpa-onnx-lstm-en-2023-02-17"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "lstm",
)
...
...
@@ -164,10 +186,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
3 -> {
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
),
tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
)
...
...
@@ -175,14 +199,28 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
4 -> {
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
return OnlineTransducerModelConfig(
return OnlineModelConfig(
transducer = OnlineTransducerModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
),
tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
)
}
5 -> {
val modelDir = "sherpa-onnx-streaming-paraformer-bilingual-zh-en"
return OnlineModelConfig(
paraformer = OnlineParaformerModelConfig(
encoder = "$modelDir/encoder.int8.onnx",
decoder = "$modelDir/decoder.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "paraformer",
)
}
}
return null;
}
...
...
@@ -200,7 +238,7 @@ by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
*/
fun getOnlineLMConfig(type
: Int): OnlineLMConfig {
fun getOnlineLMConfig(type: Int): OnlineLMConfig {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
...
...
sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
查看文件 @
35526e2
...
...
@@ -190,7 +190,11 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
OnlineParaformerDecoderResult
r
;
s
->
SetParaformerResult
(
r
);
// the internal model caches are not reset
s
->
GetStates
().
clear
();
s
->
GetParaformerEncoderOutCache
().
clear
();
s
->
GetParaformerAlphaCache
().
clear
();
// s->GetParaformerFeatCache().clear();
// Note: We only update counters. The underlying audio samples
// are not discarded.
...
...
sherpa-onnx/jni/jni.cc
查看文件 @
35526e2
...
...
@@ -47,7 +47,7 @@ class SherpaOnnx {
}
void
InputFinished
()
const
{
std
::
vector
<
float
>
tail_padding
(
input_sample_rate_
*
0.
32
,
0
);
std
::
vector
<
float
>
tail_padding
(
input_sample_rate_
*
0.
6
,
0
);
stream_
->
AcceptWaveform
(
input_sample_rate_
,
tail_padding
.
data
(),
tail_padding
.
size
());
stream_
->
InputFinished
();
...
...
@@ -158,48 +158,74 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
//---------- model config ----------
fid
=
env
->
GetFieldID
(
cls
,
"modelConfig"
,
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"
);
jobject
model_config
=
env
->
GetObjectField
(
config
,
fid
);
jclass
model_config_cls
=
env
->
GetObjectClass
(
model_config
);
// transducer
fid
=
env
->
GetFieldID
(
model_config_cls
,
"transducer"
,
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"
);
jobject
transducer_config
=
env
->
GetObjectField
(
config
,
fid
);
jclass
model_config_cls
=
env
->
GetObjectClass
(
transducer_config
);
jobject
transducer_config
=
env
->
GetObjectField
(
model_config
,
fid
);
jclass
transducer_config_cls
=
env
->
GetObjectClass
(
transducer_config
);
fid
=
env
->
GetFieldID
(
model
_config_cls
,
"encoder"
,
"Ljava/lang/String;"
);
fid
=
env
->
GetFieldID
(
transducer
_config_cls
,
"encoder"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
transducer_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
model_config
.
transducer
.
encoder
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
fid
=
env
->
GetFieldID
(
model
_config_cls
,
"decoder"
,
"Ljava/lang/String;"
);
fid
=
env
->
GetFieldID
(
transducer
_config_cls
,
"decoder"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
transducer_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
model_config
.
transducer
.
decoder
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
fid
=
env
->
GetFieldID
(
model
_config_cls
,
"joiner"
,
"Ljava/lang/String;"
);
fid
=
env
->
GetFieldID
(
transducer
_config_cls
,
"joiner"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
transducer_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
model_config
.
transducer
.
joiner
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
// paraformer
fid
=
env
->
GetFieldID
(
model_config_cls
,
"paraformer"
,
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;"
);
jobject
paraformer_config
=
env
->
GetObjectField
(
model_config
,
fid
);
jclass
paraformer_config_config_cls
=
env
->
GetObjectClass
(
paraformer_config
);
fid
=
env
->
GetFieldID
(
paraformer_config_config_cls
,
"encoder"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
paraformer_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
model_config
.
paraformer
.
encoder
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
fid
=
env
->
GetFieldID
(
paraformer_config_config_cls
,
"decoder"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
paraformer_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
model_config
.
paraformer
.
decoder
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
fid
=
env
->
GetFieldID
(
model_config_cls
,
"tokens"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
transducer
_config
,
fid
);
s
=
(
jstring
)
env
->
GetObjectField
(
model
_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
model_config
.
tokens
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
fid
=
env
->
GetFieldID
(
model_config_cls
,
"numThreads"
,
"I"
);
ans
.
model_config
.
num_threads
=
env
->
GetIntField
(
transducer
_config
,
fid
);
ans
.
model_config
.
num_threads
=
env
->
GetIntField
(
model
_config
,
fid
);
fid
=
env
->
GetFieldID
(
model_config_cls
,
"debug"
,
"Z"
);
ans
.
model_config
.
debug
=
env
->
GetBooleanField
(
transducer
_config
,
fid
);
ans
.
model_config
.
debug
=
env
->
GetBooleanField
(
model
_config
,
fid
);
fid
=
env
->
GetFieldID
(
model_config_cls
,
"provider"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
transducer
_config
,
fid
);
s
=
(
jstring
)
env
->
GetObjectField
(
model
_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
model_config
.
provider
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
fid
=
env
->
GetFieldID
(
model_config_cls
,
"modelType"
,
"Ljava/lang/String;"
);
s
=
(
jstring
)
env
->
GetObjectField
(
transducer
_config
,
fid
);
s
=
(
jstring
)
env
->
GetObjectField
(
model
_config
,
fid
);
p
=
env
->
GetStringUTFChars
(
s
,
nullptr
);
ans
.
model_config
.
model_type
=
p
;
env
->
ReleaseStringUTFChars
(
s
,
p
);
...
...
请
注册
或
登录
后发表评论