Committed by
GitHub
encoder only trt ep for transducer (#1130)
正在显示
4 个修改的文件
包含
31 行增加
和
7 行删除
| @@ -33,7 +33,9 @@ namespace sherpa_onnx { | @@ -33,7 +33,9 @@ namespace sherpa_onnx { | ||
| 33 | OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( | 33 | OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( |
| 34 | const OnlineModelConfig &config) | 34 | const OnlineModelConfig &config) |
| 35 | : env_(ORT_LOGGING_LEVEL_WARNING), | 35 | : env_(ORT_LOGGING_LEVEL_WARNING), |
| 36 | - sess_opts_(GetSessionOptions(config)), | 36 | + encoder_sess_opts_(GetSessionOptions(config)), |
| 37 | + decoder_sess_opts_(GetSessionOptions(config, "decoder")), | ||
| 38 | + joiner_sess_opts_(GetSessionOptions(config, "joiner")), | ||
| 37 | config_(config), | 39 | config_(config), |
| 38 | allocator_{} { | 40 | allocator_{} { |
| 39 | { | 41 | { |
| @@ -57,7 +59,9 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( | @@ -57,7 +59,9 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( | ||
| 57 | AAssetManager *mgr, const OnlineModelConfig &config) | 59 | AAssetManager *mgr, const OnlineModelConfig &config) |
| 58 | : env_(ORT_LOGGING_LEVEL_WARNING), | 60 | : env_(ORT_LOGGING_LEVEL_WARNING), |
| 59 | config_(config), | 61 | config_(config), |
| 60 | - sess_opts_(GetSessionOptions(config)), | 62 | + encoder_sess_opts_(GetSessionOptions(config)), |
| 63 | + decoder_sess_opts_(GetSessionOptions(config)), | ||
| 64 | + joiner_sess_opts_(GetSessionOptions(config)), | ||
| 61 | allocator_{} { | 65 | allocator_{} { |
| 62 | { | 66 | { |
| 63 | auto buf = ReadFile(mgr, config.transducer.encoder); | 67 | auto buf = ReadFile(mgr, config.transducer.encoder); |
| @@ -79,7 +83,7 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( | @@ -79,7 +83,7 @@ OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( | ||
| 79 | void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, | 83 | void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, |
| 80 | size_t model_data_length) { | 84 | size_t model_data_length) { |
| 81 | encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, | 85 | encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, |
| 82 | - model_data_length, sess_opts_); | 86 | + model_data_length, encoder_sess_opts_); |
| 83 | 87 | ||
| 84 | GetInputNames(encoder_sess_.get(), &encoder_input_names_, | 88 | GetInputNames(encoder_sess_.get(), &encoder_input_names_, |
| 85 | &encoder_input_names_ptr_); | 89 | &encoder_input_names_ptr_); |
| @@ -132,7 +136,7 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, | @@ -132,7 +136,7 @@ void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, | ||
| 132 | void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, | 136 | void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, |
| 133 | size_t model_data_length) { | 137 | size_t model_data_length) { |
| 134 | decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, | 138 | decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data, |
| 135 | - model_data_length, sess_opts_); | 139 | + model_data_length, decoder_sess_opts_); |
| 136 | 140 | ||
| 137 | GetInputNames(decoder_sess_.get(), &decoder_input_names_, | 141 | GetInputNames(decoder_sess_.get(), &decoder_input_names_, |
| 138 | &decoder_input_names_ptr_); | 142 | &decoder_input_names_ptr_); |
| @@ -157,7 +161,7 @@ void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, | @@ -157,7 +161,7 @@ void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, | ||
| 157 | void OnlineZipformer2TransducerModel::InitJoiner(void *model_data, | 161 | void OnlineZipformer2TransducerModel::InitJoiner(void *model_data, |
| 158 | size_t model_data_length) { | 162 | size_t model_data_length) { |
| 159 | joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data, | 163 | joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data, |
| 160 | - model_data_length, sess_opts_); | 164 | + model_data_length, joiner_sess_opts_); |
| 161 | 165 | ||
| 162 | GetInputNames(joiner_sess_.get(), &joiner_input_names_, | 166 | GetInputNames(joiner_sess_.get(), &joiner_input_names_, |
| 163 | &joiner_input_names_ptr_); | 167 | &joiner_input_names_ptr_); |
| @@ -65,7 +65,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | @@ -65,7 +65,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | ||
| 65 | 65 | ||
| 66 | private: | 66 | private: |
| 67 | Ort::Env env_; | 67 | Ort::Env env_; |
| 68 | - Ort::SessionOptions sess_opts_; | 68 | + Ort::SessionOptions encoder_sess_opts_; |
| 69 | + Ort::SessionOptions decoder_sess_opts_; | ||
| 70 | + Ort::SessionOptions joiner_sess_opts_; | ||
| 71 | + | ||
| 69 | Ort::AllocatorWithDefaultOptions allocator_; | 72 | Ort::AllocatorWithDefaultOptions allocator_; |
| 70 | 73 | ||
| 71 | std::unique_ptr<Ort::Session> encoder_sess_; | 74 | std::unique_ptr<Ort::Session> encoder_sess_; |
| @@ -94,7 +94,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | @@ -94,7 +94,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 94 | std::to_string(trt_config.trt_timing_cache_enable); | 94 | std::to_string(trt_config.trt_timing_cache_enable); |
| 95 | auto trt_dump_subgraphs = | 95 | auto trt_dump_subgraphs = |
| 96 | std::to_string(trt_config.trt_dump_subgraphs); | 96 | std::to_string(trt_config.trt_dump_subgraphs); |
| 97 | - | ||
| 98 | std::vector<TrtPairs> trt_options = { | 97 | std::vector<TrtPairs> trt_options = { |
| 99 | {"device_id", device_id.c_str()}, | 98 | {"device_id", device_id.c_str()}, |
| 100 | {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, | 99 | {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, |
| @@ -223,6 +222,21 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { | @@ -223,6 +222,21 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { | ||
| 223 | config.provider_config.provider, &config.provider_config); | 222 | config.provider_config.provider, &config.provider_config); |
| 224 | } | 223 | } |
| 225 | 224 | ||
| 225 | +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, | ||
| 226 | + const std::string &model_type) { | ||
| 227 | + /* | ||
| 228 | + Transducer models : Only encoder will run with tensorrt, | ||
| 229 | + decoder and joiner will run with cuda | ||
| 230 | + */ | ||
| 231 | + if(config.provider_config.provider == "trt" && | ||
| 232 | + (model_type == "decoder" || model_type == "joiner")) { | ||
| 233 | + return GetSessionOptionsImpl(config.num_threads, | ||
| 234 | + "cuda", &config.provider_config); | ||
| 235 | + } | ||
| 236 | + return GetSessionOptionsImpl(config.num_threads, | ||
| 237 | + config.provider_config.provider, &config.provider_config); | ||
| 238 | +} | ||
| 239 | + | ||
| 226 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { | 240 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { |
| 227 | return GetSessionOptionsImpl(config.num_threads, config.provider); | 241 | return GetSessionOptionsImpl(config.num_threads, config.provider); |
| 228 | } | 242 | } |
| @@ -24,6 +24,9 @@ namespace sherpa_onnx { | @@ -24,6 +24,9 @@ namespace sherpa_onnx { | ||
| 24 | 24 | ||
| 25 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); | 25 | Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); |
| 26 | 26 | ||
| 27 | +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, | ||
| 28 | + const std::string &model_type); | ||
| 29 | + | ||
| 27 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); | 30 | Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); |
| 28 | 31 | ||
| 29 | Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); | 32 | Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); |
-
请 注册 或 登录 后发表评论