Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2025-06-03 20:28:57 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2025-06-03 20:28:57 +0800
Commit
1fabc6c79a69f225edac7715428d9bd4677f12ab
1fabc6c7
1 parent
818b3f6d
Fix rknn for multi-threads (#2274)
显示空白字符变更
内嵌
并排对比
正在显示
2 个修改的文件
包含
43 行增加
和
18 行删除
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc
查看文件 @
1fabc6c
...
...
@@ -86,8 +86,7 @@ class OnlineZipformerCtcModelRknn::Impl {
}
std
::
pair
<
std
::
vector
<
float
>
,
std
::
vector
<
std
::
vector
<
uint8_t
>>>
Run
(
std
::
vector
<
float
>
features
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
states
)
const
{
std
::
vector
<
float
>
features
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
states
)
{
std
::
vector
<
rknn_input
>
inputs
(
input_attrs_
.
size
());
for
(
int32_t
i
=
0
;
i
<
static_cast
<
int32_t
>
(
inputs
.
size
());
++
i
)
{
...
...
@@ -147,13 +146,17 @@ class OnlineZipformerCtcModelRknn::Impl {
}
}
auto
ret
=
rknn_inputs_set
(
ctx_
,
inputs
.
size
(),
inputs
.
data
());
rknn_context
ctx
=
0
;
auto
ret
=
rknn_dup_context
(
&
ctx_
,
&
ctx
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to duplicate the ctx"
);
ret
=
rknn_inputs_set
(
ctx
,
inputs
.
size
(),
inputs
.
data
());
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to set inputs"
);
ret
=
rknn_run
(
ctx
_
,
nullptr
);
ret
=
rknn_run
(
ctx
,
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to run the model"
);
ret
=
rknn_outputs_get
(
ctx
_
,
outputs
.
size
(),
outputs
.
data
(),
nullptr
);
ret
=
rknn_outputs_get
(
ctx
,
outputs
.
size
(),
outputs
.
data
(),
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to get model output"
);
for
(
int32_t
i
=
0
;
i
<
next_states
.
size
();
++
i
)
{
...
...
@@ -174,6 +177,8 @@ class OnlineZipformerCtcModelRknn::Impl {
}
}
rknn_destroy
(
ctx
);
return
{
std
::
move
(
out
),
std
::
move
(
next_states
)};
}
...
...
sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
查看文件 @
1fabc6c
...
...
@@ -120,8 +120,7 @@ class OnlineZipformerTransducerModelRknn::Impl {
}
std
::
pair
<
std
::
vector
<
float
>
,
std
::
vector
<
std
::
vector
<
uint8_t
>>>
RunEncoder
(
std
::
vector
<
float
>
features
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
states
)
const
{
std
::
vector
<
float
>
features
,
std
::
vector
<
std
::
vector
<
uint8_t
>>
states
)
{
std
::
vector
<
rknn_input
>
inputs
(
encoder_input_attrs_
.
size
());
for
(
int32_t
i
=
0
;
i
<
static_cast
<
int32_t
>
(
inputs
.
size
());
++
i
)
{
...
...
@@ -181,14 +180,21 @@ class OnlineZipformerTransducerModelRknn::Impl {
}
}
auto
ret
=
rknn_inputs_set
(
encoder_ctx_
,
inputs
.
size
(),
inputs
.
data
());
rknn_context
encoder_ctx
=
0
;
// https://github.com/rockchip-linux/rknpu2/blob/master/runtime/RK3588/Linux/librknn_api/include/rknn_api.h#L444C1-L444C75
// rknn_dup_context(rknn_context* context_in, rknn_context* context_out);
auto
ret
=
rknn_dup_context
(
&
encoder_ctx_
,
&
encoder_ctx
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to duplicate the encoder ctx"
);
ret
=
rknn_inputs_set
(
encoder_ctx
,
inputs
.
size
(),
inputs
.
data
());
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to set encoder inputs"
);
ret
=
rknn_run
(
encoder_ctx
_
,
nullptr
);
ret
=
rknn_run
(
encoder_ctx
,
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to run encoder"
);
ret
=
rknn_outputs_get
(
encoder_ctx
_
,
outputs
.
size
(),
outputs
.
data
(),
nullptr
);
rknn_outputs_get
(
encoder_ctx
,
outputs
.
size
(),
outputs
.
data
(),
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to get encoder output"
);
for
(
int32_t
i
=
0
;
i
<
next_states
.
size
();
++
i
)
{
...
...
@@ -209,10 +215,12 @@ class OnlineZipformerTransducerModelRknn::Impl {
}
}
rknn_destroy
(
encoder_ctx
);
return
{
std
::
move
(
encoder_out
),
std
::
move
(
next_states
)};
}
std
::
vector
<
float
>
RunDecoder
(
std
::
vector
<
int64_t
>
decoder_input
)
const
{
std
::
vector
<
float
>
RunDecoder
(
std
::
vector
<
int64_t
>
decoder_input
)
{
auto
&
attr
=
decoder_input_attrs_
[
0
];
rknn_input
input
;
...
...
@@ -230,20 +238,26 @@ class OnlineZipformerTransducerModelRknn::Impl {
output
.
size
=
decoder_out
.
size
()
*
sizeof
(
float
);
output
.
buf
=
decoder_out
.
data
();
auto
ret
=
rknn_inputs_set
(
decoder_ctx_
,
1
,
&
input
);
rknn_context
decoder_ctx
=
0
;
auto
ret
=
rknn_dup_context
(
&
decoder_ctx_
,
&
decoder_ctx
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to duplicate the decoder ctx"
);
ret
=
rknn_inputs_set
(
decoder_ctx
,
1
,
&
input
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to set decoder inputs"
);
ret
=
rknn_run
(
decoder_ctx
_
,
nullptr
);
ret
=
rknn_run
(
decoder_ctx
,
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to run decoder"
);
ret
=
rknn_outputs_get
(
decoder_ctx
_
,
1
,
&
output
,
nullptr
);
ret
=
rknn_outputs_get
(
decoder_ctx
,
1
,
&
output
,
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to get decoder output"
);
rknn_destroy
(
decoder_ctx
);
return
decoder_out
;
}
std
::
vector
<
float
>
RunJoiner
(
const
float
*
encoder_out
,
const
float
*
decoder_out
)
const
{
const
float
*
decoder_out
)
{
std
::
vector
<
rknn_input
>
inputs
(
2
);
inputs
[
0
].
index
=
0
;
inputs
[
0
].
type
=
RKNN_TENSOR_FLOAT32
;
...
...
@@ -265,15 +279,21 @@ class OnlineZipformerTransducerModelRknn::Impl {
output
.
size
=
joiner_out
.
size
()
*
sizeof
(
float
);
output
.
buf
=
joiner_out
.
data
();
auto
ret
=
rknn_inputs_set
(
joiner_ctx_
,
inputs
.
size
(),
inputs
.
data
());
rknn_context
joiner_ctx
=
0
;
auto
ret
=
rknn_dup_context
(
&
joiner_ctx_
,
&
joiner_ctx
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to duplicate the joiner ctx"
);
ret
=
rknn_inputs_set
(
joiner_ctx
,
inputs
.
size
(),
inputs
.
data
());
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to set joiner inputs"
);
ret
=
rknn_run
(
joiner_ctx
_
,
nullptr
);
ret
=
rknn_run
(
joiner_ctx
,
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to run joiner"
);
ret
=
rknn_outputs_get
(
joiner_ctx
_
,
1
,
&
output
,
nullptr
);
ret
=
rknn_outputs_get
(
joiner_ctx
,
1
,
&
output
,
nullptr
);
SHERPA_ONNX_RKNN_CHECK
(
ret
,
"Failed to get joiner output"
);
rknn_destroy
(
joiner_ctx
);
return
joiner_out
;
}
...
...
请
注册
或
登录
后发表评论