Manix
Committed by GitHub

Support TensorRT provider (#921)

Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
@@ -24,6 +24,8 @@ Provider StringToProvider(std::string s) {
     return Provider::kXnnpack;
   } else if (s == "nnapi") {
     return Provider::kNNAPI;
+  } else if (s == "trt") {
+    return Provider::kTRT;
   } else {
     SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
     return Provider::kCPU;
@@ -18,6 +18,7 @@ enum class Provider {
   kCoreML = 2,   // CoreMLExecutionProvider
   kXnnpack = 3,  // XnnpackExecutionProvider
   kNNAPI = 4,    // NnapiExecutionProvider
+  kTRT = 5,      // TensorRTExecutionProvider
 };
 
 /**
@@ -21,6 +21,16 @@
 
 namespace sherpa_onnx {
 
+
+static void OrtStatusFailure(OrtStatus *status, const char *s) {
+  const auto &api = Ort::GetApi();
+  const char *msg = api.GetErrorMessage(status);
+  SHERPA_ONNX_LOGE(
+      "Failed to enable TensorRT : %s."
+      "Available providers: %s. Fallback to cuda", msg, s);
+  api.ReleaseStatus(status);
+}
+
 static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
                                                  std::string provider_str) {
   Provider p = StringToProvider(std::move(provider_str));
@@ -53,6 +63,57 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
       }
       break;
     }
+    case Provider::kTRT: {
+      struct TrtPairs {
+        const char* op_keys;
+        const char* op_values;
+      };
+
+      std::vector<TrtPairs> trt_options = {
+          {"device_id", "0"},
+          {"trt_max_workspace_size", "2147483648"},
+          {"trt_max_partition_iterations", "10"},
+          {"trt_min_subgraph_size", "5"},
+          {"trt_fp16_enable", "0"},
+          {"trt_detailed_build_log", "0"},
+          {"trt_engine_cache_enable", "1"},
+          {"trt_engine_cache_path", "."},
+          {"trt_timing_cache_enable", "1"},
+          {"trt_timing_cache_path", "."}
+      };
+      // ToDo : Trt configs
+      // "trt_int8_enable"
+      // "trt_int8_use_native_calibration_table"
+      // "trt_dump_subgraphs"
+
+      std::vector<const char*> option_keys, option_values;
+      for (const TrtPairs& pair : trt_options) {
+        option_keys.emplace_back(pair.op_keys);
+        option_values.emplace_back(pair.op_values);
+      }
+
+      std::vector<std::string> available_providers =
+          Ort::GetAvailableProviders();
+      if (std::find(available_providers.begin(), available_providers.end(),
+                    "TensorrtExecutionProvider") != available_providers.end()) {
+        const auto& api = Ort::GetApi();
+
+        OrtTensorRTProviderOptionsV2* tensorrt_options;
+        OrtStatus *statusC = api.CreateTensorRTProviderOptions(
+            &tensorrt_options);
+        OrtStatus *statusU = api.UpdateTensorRTProviderOptions(
+            tensorrt_options, option_keys.data(), option_values.data(),
+            option_keys.size());
+        sess_opts.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);
+
+        if (statusC) { OrtStatusFailure(statusC, os.str().c_str()); }
+        if (statusU) { OrtStatusFailure(statusU, os.str().c_str()); }
+
+        api.ReleaseTensorRTProviderOptions(tensorrt_options);
+      }
+      // break; is omitted here intentionally so that
+      // if TRT not available, CUDA will be used
+    }
     case Provider::kCUDA: {
       if (std::find(available_providers.begin(), available_providers.end(),
                     "CUDAExecutionProvider") != available_providers.end()) {
@@ -116,7 +177,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
       break;
     }
   }
-
   return sess_opts;
 }
 