Committed by
GitHub
updating trt workspace int64 (#1094)
Signed-off-by: Manix <manickavela1998@gmail.com>
正在显示
3 个修改的文件
包含
4 行增加
和
4 行删除
| @@ -60,7 +60,7 @@ void TensorrtConfig::Register(ParseOptions *po) { | @@ -60,7 +60,7 @@ void TensorrtConfig::Register(ParseOptions *po) { | ||
| 60 | 60 | ||
| 61 | bool TensorrtConfig::Validate() const { | 61 | bool TensorrtConfig::Validate() const { |
| 62 | if (trt_max_workspace_size < 0) { | 62 | if (trt_max_workspace_size < 0) { |
| 63 | - SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.", | 63 | + SHERPA_ONNX_LOGE("trt_max_workspace_size: %lld is not valid.", |
| 64 | trt_max_workspace_size); | 64 | trt_max_workspace_size); |
| 65 | return false; | 65 | return false; |
| 66 | } | 66 | } |
| @@ -27,7 +27,7 @@ struct CudaConfig { | @@ -27,7 +27,7 @@ struct CudaConfig { | ||
| 27 | }; | 27 | }; |
| 28 | 28 | ||
| 29 | struct TensorrtConfig { | 29 | struct TensorrtConfig { |
| 30 | - int32_t trt_max_workspace_size = 2147483647; | 30 | + int64_t trt_max_workspace_size = 2147483647; |
| 31 | int32_t trt_max_partition_iterations = 10; | 31 | int32_t trt_max_partition_iterations = 10; |
| 32 | int32_t trt_min_subgraph_size = 5; | 32 | int32_t trt_min_subgraph_size = 5; |
| 33 | bool trt_fp16_enable = true; | 33 | bool trt_fp16_enable = true; |
| @@ -39,7 +39,7 @@ struct TensorrtConfig { | @@ -39,7 +39,7 @@ struct TensorrtConfig { | ||
| 39 | bool trt_dump_subgraphs = false; | 39 | bool trt_dump_subgraphs = false; |
| 40 | 40 | ||
| 41 | TensorrtConfig() = default; | 41 | TensorrtConfig() = default; |
| 42 | - TensorrtConfig(int32_t trt_max_workspace_size, | 42 | + TensorrtConfig(int64_t trt_max_workspace_size, |
| 43 | int32_t trt_max_partition_iterations, | 43 | int32_t trt_max_partition_iterations, |
| 44 | int32_t trt_min_subgraph_size, | 44 | int32_t trt_min_subgraph_size, |
| 45 | bool trt_fp16_enable, | 45 | bool trt_fp16_enable, |
| @@ -14,7 +14,7 @@ void PybindTensorrtConfig(py::module *m) { | @@ -14,7 +14,7 @@ void PybindTensorrtConfig(py::module *m) { | ||
| 14 | using PyClass = TensorrtConfig; | 14 | using PyClass = TensorrtConfig; |
| 15 | py::class_<PyClass>(*m, "TensorrtConfig") | 15 | py::class_<PyClass>(*m, "TensorrtConfig") |
| 16 | .def(py::init<>()) | 16 | .def(py::init<>()) |
| 17 | - .def(py::init([](int32_t trt_max_workspace_size, | 17 | + .def(py::init([](int64_t trt_max_workspace_size, |
| 18 | int32_t trt_max_partition_iterations, | 18 | int32_t trt_max_partition_iterations, |
| 19 | int32_t trt_min_subgraph_size, | 19 | int32_t trt_min_subgraph_size, |
| 20 | bool trt_fp16_enable, | 20 | bool trt_fp16_enable, |
-
请 注册 或 登录 后发表评论