Committed by GitHub
Support TensorRT provider (#921)
Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
正在显示 3 个修改的文件，包含 64 行增加和 1 行删除
| @@ -24,6 +24,8 @@ Provider StringToProvider(std::string s) { | @@ -24,6 +24,8 @@ Provider StringToProvider(std::string s) { | ||
| 24 | return Provider::kXnnpack; | 24 | return Provider::kXnnpack; |
| 25 | } else if (s == "nnapi") { | 25 | } else if (s == "nnapi") { |
| 26 | return Provider::kNNAPI; | 26 | return Provider::kNNAPI; |
| 27 | + } else if (s == "trt") { | ||
| 28 | + return Provider::kTRT; | ||
| 27 | } else { | 29 | } else { |
| 28 | SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); | 30 | SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); |
| 29 | return Provider::kCPU; | 31 | return Provider::kCPU; |
| @@ -18,6 +18,7 @@ enum class Provider { | @@ -18,6 +18,7 @@ enum class Provider { | ||
| 18 | kCoreML = 2, // CoreMLExecutionProvider | 18 | kCoreML = 2, // CoreMLExecutionProvider |
| 19 | kXnnpack = 3, // XnnpackExecutionProvider | 19 | kXnnpack = 3, // XnnpackExecutionProvider |
| 20 | kNNAPI = 4, // NnapiExecutionProvider | 20 | kNNAPI = 4, // NnapiExecutionProvider |
| 21 | + kTRT = 5, // TensorRTExecutionProvider | ||
| 21 | }; | 22 | }; |
| 22 | 23 | ||
| 23 | /** | 24 | /** |
| @@ -21,6 +21,16 @@ | @@ -21,6 +21,16 @@ | ||
| 21 | 21 | ||
| 22 | namespace sherpa_onnx { | 22 | namespace sherpa_onnx { |
| 23 | 23 | ||
| 24 | + | ||
| 25 | +static void OrtStatusFailure(OrtStatus *status, const char *s) { | ||
| 26 | + const auto &api = Ort::GetApi(); | ||
| 27 | + const char *msg = api.GetErrorMessage(status); | ||
| 28 | + SHERPA_ONNX_LOGE( | ||
| 29 | + "Failed to enable TensorRT : %s." | ||
| 30 | + "Available providers: %s. Fallback to cuda", msg, s); | ||
| 31 | + api.ReleaseStatus(status); | ||
| 32 | +} | ||
| 33 | + | ||
| 24 | static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | 34 | static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, |
| 25 | std::string provider_str) { | 35 | std::string provider_str) { |
| 26 | Provider p = StringToProvider(std::move(provider_str)); | 36 | Provider p = StringToProvider(std::move(provider_str)); |
| @@ -53,6 +63,57 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | @@ -53,6 +63,57 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 53 | } | 63 | } |
| 54 | break; | 64 | break; |
| 55 | } | 65 | } |
| 66 | + case Provider::kTRT: { | ||
| 67 | + struct TrtPairs { | ||
| 68 | + const char* op_keys; | ||
| 69 | + const char* op_values; | ||
| 70 | + }; | ||
| 71 | + | ||
| 72 | + std::vector<TrtPairs> trt_options = { | ||
| 73 | + {"device_id", "0"}, | ||
| 74 | + {"trt_max_workspace_size", "2147483648"}, | ||
| 75 | + {"trt_max_partition_iterations", "10"}, | ||
| 76 | + {"trt_min_subgraph_size", "5"}, | ||
| 77 | + {"trt_fp16_enable", "0"}, | ||
| 78 | + {"trt_detailed_build_log", "0"}, | ||
| 79 | + {"trt_engine_cache_enable", "1"}, | ||
| 80 | + {"trt_engine_cache_path", "."}, | ||
| 81 | + {"trt_timing_cache_enable", "1"}, | ||
| 82 | + {"trt_timing_cache_path", "."} | ||
| 83 | + }; | ||
| 84 | + // ToDo : Trt configs | ||
| 85 | + // "trt_int8_enable" | ||
| 86 | + // "trt_int8_use_native_calibration_table" | ||
| 87 | + // "trt_dump_subgraphs" | ||
| 88 | + | ||
| 89 | + std::vector<const char*> option_keys, option_values; | ||
| 90 | + for (const TrtPairs& pair : trt_options) { | ||
| 91 | + option_keys.emplace_back(pair.op_keys); | ||
| 92 | + option_values.emplace_back(pair.op_values); | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + std::vector<std::string> available_providers = | ||
| 96 | + Ort::GetAvailableProviders(); | ||
| 97 | + if (std::find(available_providers.begin(), available_providers.end(), | ||
| 98 | + "TensorrtExecutionProvider") != available_providers.end()) { | ||
| 99 | + const auto& api = Ort::GetApi(); | ||
| 100 | + | ||
| 101 | + OrtTensorRTProviderOptionsV2* tensorrt_options; | ||
| 102 | + OrtStatus *statusC = api.CreateTensorRTProviderOptions( | ||
| 103 | + &tensorrt_options); | ||
| 104 | + OrtStatus *statusU = api.UpdateTensorRTProviderOptions( | ||
| 105 | + tensorrt_options, option_keys.data(), option_values.data(), | ||
| 106 | + option_keys.size()); | ||
| 107 | + sess_opts.AppendExecutionProvider_TensorRT_V2(*tensorrt_options); | ||
| 108 | + | ||
| 109 | + if (statusC) { OrtStatusFailure(statusC, os.str().c_str()); } | ||
| 110 | + if (statusU) { OrtStatusFailure(statusU, os.str().c_str()); } | ||
| 111 | + | ||
| 112 | + api.ReleaseTensorRTProviderOptions(tensorrt_options); | ||
| 113 | + } | ||
| 114 | + // break; is omitted here intentionally so that | ||
| 115 | + // if TRT not available, CUDA will be used | ||
| 116 | + } | ||
| 56 | case Provider::kCUDA: { | 117 | case Provider::kCUDA: { |
| 57 | if (std::find(available_providers.begin(), available_providers.end(), | 118 | if (std::find(available_providers.begin(), available_providers.end(), |
| 58 | "CUDAExecutionProvider") != available_providers.end()) { | 119 | "CUDAExecutionProvider") != available_providers.end()) { |
| @@ -116,7 +177,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | @@ -116,7 +177,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 116 | break; | 177 | break; |
| 117 | } | 178 | } |
| 118 | } | 179 | } |
| 119 | - | ||
| 120 | return sess_opts; | 180 | return sess_opts; |
| 121 | } | 181 | } |
| 122 | 182 |
-
请注册或登录后发表评论