// sherpa-onnx/csrc/provider-config.cc
//
// Copyright (c) 2024 Uniphore (Author: Manickavela)
#include "sherpa-onnx/csrc/provider-config.h"
#include <sstream>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void CudaConfig::Register(ParseOptions *po) {
  po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search,
               "CuDNN convolution algorithm search");
}

bool CudaConfig::Validate() const {
  if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) {
    SHERPA_ONNX_LOGE(
        "cudnn_conv_algo_search: '%d' is not a valid option. "
        "Valid options: [1, 3]. Check the onnxruntime docs.",
        cudnn_conv_algo_search);
    return false;
  }

  return true;
}

std::string CudaConfig::ToString() const {
  std::ostringstream os;

  os << "CudaConfig(";
  os << "cudnn_conv_algo_search=" << cudnn_conv_algo_search << ")";

  return os.str();
}

void TensorrtConfig::Register(ParseOptions *po) {
  po->Register("trt-max-workspace-size", &trt_max_workspace_size,
               "Set the TensorRT EP GPU memory usage limit (in bytes).");
  po->Register("trt-max-partition-iterations", &trt_max_partition_iterations,
               "Limit partitioning iterations for model conversion.");
  po->Register("trt-min-subgraph-size", &trt_min_subgraph_size,
               "Set the minimum size for subgraphs in partitioning.");
  po->Register("trt-fp16-enable", &trt_fp16_enable,
               "Enable FP16 precision for faster performance.");
  po->Register("trt-detailed-build-log", &trt_detailed_build_log,
               "Enable detailed logging of build steps.");
  po->Register("trt-engine-cache-enable", &trt_engine_cache_enable,
               "Enable caching of TensorRT engines.");
  po->Register("trt-timing-cache-enable", &trt_timing_cache_enable,
               "Enable use of the timing cache to speed up builds.");
  po->Register("trt-engine-cache-path", &trt_engine_cache_path,
               "Set the path to store cached TensorRT engines.");
  po->Register("trt-timing-cache-path", &trt_timing_cache_path,
               "Set the path for storing the timing cache.");
  po->Register("trt-dump-subgraphs", &trt_dump_subgraphs,
               "Dump optimized subgraphs for debugging.");
}

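// Example invocation using the flags registered above (a sketch only; the
// binary name and flag values are illustrative, not prescribed by this file):
//
//   ./some-sherpa-onnx-binary \
//     --provider=trt \
//     --trt-max-workspace-size=2147483647 \
//     --trt-fp16-enable=true \
//     --trt-engine-cache-enable=true \
//     --trt-engine-cache-path=./trt-cache
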
bool TensorrtConfig::Validate() const {
  if (trt_max_workspace_size < 0) {
    std::ostringstream os;
    os << "trt_max_workspace_size: " << trt_max_workspace_size
       << " is not valid.";
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
    return false;
  }

  if (trt_max_partition_iterations < 0) {
    SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.",
                     trt_max_partition_iterations);
    return false;
  }

  if (trt_min_subgraph_size < 0) {
    SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.",
                     trt_min_subgraph_size);
    return false;
  }

  return true;
}

std::string TensorrtConfig::ToString() const {
  std::ostringstream os;

  os << "TensorrtConfig(";
  os << "trt_max_workspace_size=" << trt_max_workspace_size << ", ";
  os << "trt_max_partition_iterations=" << trt_max_partition_iterations
     << ", ";
  os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", ";
  os << "trt_fp16_enable=\"" << (trt_fp16_enable ? "True" : "False") << "\", ";
  os << "trt_detailed_build_log=\""
     << (trt_detailed_build_log ? "True" : "False") << "\", ";
  os << "trt_engine_cache_enable=\""
     << (trt_engine_cache_enable ? "True" : "False") << "\", ";
  os << "trt_engine_cache_path=\"" << trt_engine_cache_path << "\", ";
  os << "trt_timing_cache_enable=\""
     << (trt_timing_cache_enable ? "True" : "False") << "\", ";
  os << "trt_timing_cache_path=\"" << trt_timing_cache_path << "\", ";
  os << "trt_dump_subgraphs=\"" << (trt_dump_subgraphs ? "True" : "False")
     << "\")";

  return os.str();
}

void ProviderConfig::Register(ParseOptions *po) {
  cuda_config.Register(po);
  trt_config.Register(po);

  po->Register("device", &device,
               "GPU device index for the CUDA and TensorRT EPs");
  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml, trt");
}

bool ProviderConfig::Validate() const {
  if (device < 0) {
    SHERPA_ONNX_LOGE("device: '%d' is invalid.", device);
    return false;
  }

  if (provider == "cuda" && !cuda_config.Validate()) {
    return false;
  }

  if (provider == "trt" && !trt_config.Validate()) {
    return false;
  }

  return true;
}

std::string ProviderConfig::ToString() const {
  std::ostringstream os;

  os << "ProviderConfig(";
  os << "device=" << device << ", ";
  os << "provider=\"" << provider << "\", ";
  os << "cuda_config=" << cuda_config.ToString() << ", ";
  os << "trt_config=" << trt_config.ToString() << ")";

  return os.str();
}

} // namespace sherpa_onnx
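
// A minimal driver sketch showing how this config is wired up end to end
// (a sketch only, assuming sherpa-onnx's ParseOptions; this main() is
// illustrative and not part of the library):
//
//   #include <cstdio>
//
//   #include "sherpa-onnx/csrc/parse-options.h"
//   #include "sherpa-onnx/csrc/provider-config.h"
//
//   int main(int argc, char *argv[]) {
//     sherpa_onnx::ParseOptions po("Provider config demo");
//     sherpa_onnx::ProviderConfig config;
//
//     config.Register(&po);  // exposes --provider, --device, --trt-*, ...
//     po.Read(argc, argv);   // e.g. --provider=cuda --device=0
//
//     if (!config.Validate()) {
//       return 1;
//     }
//
//     fprintf(stderr, "%s\n", config.ToString().c_str());
//     return 0;
//   }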