Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Manickavela
2024-07-15 12:22:33 +0530
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-07-15 14:52:33 +0800
Commit
11cfd33b10782b20175f0a0139409a327a4dff39
11cfd33b
1 parent
c35200dc
encoder only trt ep for transducer (#1130)
隐藏空白字符变更
内嵌
并排对比
正在显示
4 个修改的文件
包含
31 行增加
和
7 行删除
sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
sherpa-onnx/csrc/online-zipformer2-transducer-model.h
sherpa-onnx/csrc/session.cc
sherpa-onnx/csrc/session.h
sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
查看文件 @
11cfd33
...
...
@@ -33,7 +33,9 @@ namespace sherpa_onnx {
OnlineZipformer2TransducerModel
::
OnlineZipformer2TransducerModel
(
const
OnlineModelConfig
&
config
)
:
env_
(
ORT_LOGGING_LEVEL_WARNING
),
sess_opts_
(
GetSessionOptions
(
config
)),
encoder_sess_opts_
(
GetSessionOptions
(
config
)),
decoder_sess_opts_
(
GetSessionOptions
(
config
,
"decoder"
)),
joiner_sess_opts_
(
GetSessionOptions
(
config
,
"joiner"
)),
config_
(
config
),
allocator_
{}
{
{
...
...
@@ -57,7 +59,9 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
AAssetManager
*
mgr
,
const
OnlineModelConfig
&
config
)
:
env_
(
ORT_LOGGING_LEVEL_WARNING
),
config_
(
config
),
sess_opts_
(
GetSessionOptions
(
config
)),
encoder_sess_opts_
(
GetSessionOptions
(
config
)),
decoder_sess_opts_
(
GetSessionOptions
(
config
)),
joiner_sess_opts_
(
GetSessionOptions
(
config
)),
allocator_
{}
{
{
auto
buf
=
ReadFile
(
mgr
,
config
.
transducer
.
encoder
);
...
...
@@ -79,7 +83,7 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
void
OnlineZipformer2TransducerModel
::
InitEncoder
(
void
*
model_data
,
size_t
model_data_length
)
{
encoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
model_data_length
,
encoder_
sess_opts_
);
GetInputNames
(
encoder_sess_
.
get
(),
&
encoder_input_names_
,
&
encoder_input_names_ptr_
);
...
...
@@ -132,7 +136,7 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data,
void
OnlineZipformer2TransducerModel
::
InitDecoder
(
void
*
model_data
,
size_t
model_data_length
)
{
decoder_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
model_data_length
,
decoder_
sess_opts_
);
GetInputNames
(
decoder_sess_
.
get
(),
&
decoder_input_names_
,
&
decoder_input_names_ptr_
);
...
...
@@ -157,7 +161,7 @@ void OnlineZipformer2TransducerModel::InitDecoder(void *model_data,
void
OnlineZipformer2TransducerModel
::
InitJoiner
(
void
*
model_data
,
size_t
model_data_length
)
{
joiner_sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
model_data_length
,
joiner_
sess_opts_
);
GetInputNames
(
joiner_sess_
.
get
(),
&
joiner_input_names_
,
&
joiner_input_names_ptr_
);
...
...
sherpa-onnx/csrc/online-zipformer2-transducer-model.h
查看文件 @
11cfd33
...
...
@@ -65,7 +65,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
private
:
Ort
::
Env
env_
;
Ort
::
SessionOptions
sess_opts_
;
Ort
::
SessionOptions
encoder_sess_opts_
;
Ort
::
SessionOptions
decoder_sess_opts_
;
Ort
::
SessionOptions
joiner_sess_opts_
;
Ort
::
AllocatorWithDefaultOptions
allocator_
;
std
::
unique_ptr
<
Ort
::
Session
>
encoder_sess_
;
...
...
sherpa-onnx/csrc/session.cc
查看文件 @
11cfd33
...
...
@@ -94,7 +94,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
std
::
to_string
(
trt_config
.
trt_timing_cache_enable
);
auto
trt_dump_subgraphs
=
std
::
to_string
(
trt_config
.
trt_dump_subgraphs
);
std
::
vector
<
TrtPairs
>
trt_options
=
{
{
"device_id"
,
device_id
.
c_str
()},
{
"trt_max_workspace_size"
,
trt_max_workspace_size
.
c_str
()},
...
...
@@ -223,6 +222,21 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
config
.
provider_config
.
provider
,
&
config
.
provider_config
);
}
// Returns the session options for one component of an online transducer
// model. `model_type` names the component being loaded.
//
// When the configured provider is "trt" and the component is the "decoder"
// or the "joiner", the options are built for the "cuda" provider instead, so
// that only the encoder runs with TensorRT. Every other combination uses the
// configured provider unchanged.
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
                                      const std::string &model_type) {
  const auto &provider_config = config.provider_config;

  const bool trt_requested = provider_config.provider == "trt";
  const bool is_decoder_or_joiner =
      model_type == "decoder" || model_type == "joiner";

  // Common path: keep whatever provider the user configured.
  if (!(trt_requested && is_decoder_or_joiner)) {
    return GetSessionOptionsImpl(config.num_threads, provider_config.provider,
                                 &provider_config);
  }

  // TensorRT was requested, but decoder and joiner run with CUDA.
  return GetSessionOptionsImpl(config.num_threads, "cuda", &provider_config);
}
// Session options for an offline model, built from the thread count and the
// provider name stored directly on `config`.
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
  const auto num_threads = config.num_threads;
  const auto &provider = config.provider;
  return GetSessionOptionsImpl(num_threads, provider);
}
...
...
sherpa-onnx/csrc/session.h
查看文件 @
11cfd33
...
...
@@ -24,6 +24,9 @@ namespace sherpa_onnx {
// Session options for an online model, derived from `config`.
Ort
::
SessionOptions
GetSessionOptions
(
const
OnlineModelConfig
&
config
);
// Session options for one component of an online transducer model.
// `model_type` names the component. If the configured provider is "trt" and
// `model_type` is "decoder" or "joiner", the options are built for the
// "cuda" provider instead; otherwise the configured provider is used.
Ort
::
SessionOptions
GetSessionOptions
(
const
OnlineModelConfig
&
config
,
const
std
::
string
&
model_type
);
// Session options for an offline model, built from `config.num_threads` and
// `config.provider`.
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflineModelConfig
&
config
);
// Overload taking an OfflineLMConfig.
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflineLMConfig
&
config
);
...
...
请
注册
或
登录
后发表评论