Manix
Committed by GitHub

updating trt workspace int64 (#1094)

Signed-off-by: Manix <manickavela1998@gmail.com>
@@ -60,7 +60,7 @@ void TensorrtConfig::Register(ParseOptions *po) {
 
 bool TensorrtConfig::Validate() const {
   if (trt_max_workspace_size < 0) {
-    SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.",
+    SHERPA_ONNX_LOGE("trt_max_workspace_size: %lld is not valid.",
                      trt_max_workspace_size);
     return false;
   }
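For reference, "%lld" assumes int64_t is long long; that holds on Windows and most 32-bit targets, but on LP64 Linux/macOS int64_t is long, so "%lld" can draw -Wformat warnings there. A minimal standalone sketch of the portable alternative, PRId64 from <cinttypes> (illustration only, not part of this patch):

// Sketch: portably printing an int64_t, using PRId64 instead of "%lld".
#include <cinttypes>
#include <cstdint>
#include <cstdio>

int main() {
  int64_t trt_max_workspace_size = 4LL * 1024 * 1024 * 1024;  // 4 GiB
  std::printf("trt_max_workspace_size: %" PRId64 "\n", trt_max_workspace_size);
  return 0;
}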
@@ -27,7 +27,7 @@ struct CudaConfig {
 };
 
 struct TensorrtConfig {
-  int32_t trt_max_workspace_size = 2147483647;
+  int64_t trt_max_workspace_size = 2147483647;
   int32_t trt_max_partition_iterations = 10;
   int32_t trt_min_subgraph_size = 5;
   bool trt_fp16_enable = true;
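The widening is what makes larger workspaces representable: the old default, 2147483647 bytes, is exactly INT32_MAX (one byte short of 2 GiB), so an int32_t field could never hold anything bigger. A standalone sketch of the limit (values chosen for illustration):

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // 2147483647 == INT32_MAX, the largest workspace an int32_t field can hold.
  static_assert(2147483647 == std::numeric_limits<int32_t>::max(),
                "old default equals INT32_MAX");

  // A 4 GiB workspace fits comfortably in int64_t ...
  int64_t workspace = 4LL * 1024 * 1024 * 1024;
  std::printf("4 GiB as int64_t: %lld bytes\n",
              static_cast<long long>(workspace));

  // ... but narrowing it back to int32_t wraps (to 0 here, since 2^32
  // is a multiple of 2^32), which is why the field had to be widened.
  std::printf("narrowed to int32_t: %d\n", static_cast<int32_t>(workspace));
  return 0;
}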
@@ -39,7 +39,7 @@ struct TensorrtConfig {
   bool trt_dump_subgraphs = false;
 
   TensorrtConfig() = default;
-  TensorrtConfig(int32_t trt_max_workspace_size,
+  TensorrtConfig(int64_t trt_max_workspace_size,
                 int32_t trt_max_partition_iterations,
                 int32_t trt_min_subgraph_size,
                 bool trt_fp16_enable,
@@ -14,7 +14,7 @@ void PybindTensorrtConfig(py::module *m) {
   using PyClass = TensorrtConfig;
   py::class_<PyClass>(*m, "TensorrtConfig")
       .def(py::init<>())
-      .def(py::init([](int32_t trt_max_workspace_size,
+      .def(py::init([](int64_t trt_max_workspace_size,
                        int32_t trt_max_partition_iterations,
                        int32_t trt_min_subgraph_size,
                        bool trt_fp16_enable,
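On the Python side, pybind11 range-checks integral arguments, so with the old int32_t parameter a workspace above 2**31 - 1 would be rejected at the binding boundary; with int64_t it converts cleanly. A hedged sketch of that behavior with a hypothetical TrtDemo class (not the actual sherpa-onnx binding):

// Sketch: binding an int64_t field with pybind11. TrtDemo is a
// made-up stand-in for TensorrtConfig, for illustration only.
#include <cstdint>
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct TrtDemo {
  int64_t trt_max_workspace_size = 2147483647;
};

PYBIND11_MODULE(trt_demo, m) {
  py::class_<TrtDemo>(m, "TrtDemo")
      .def(py::init<>())
      // Python ints convert to int64_t here, so values above
      // 2**31 - 1 (e.g. 4 * 1024**3) now round-trip without error.
      .def_readwrite("trt_max_workspace_size",
                     &TrtDemo::trt_max_workspace_size);
}

With the int32_t version, an assignment like cfg.trt_max_workspace_size = 4 * 1024**3 from Python would fail pybind11's out-of-range conversion check.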