正在显示
1 个修改的文件
包含
25 行增加
和
27 行删除
| @@ -7,9 +7,9 @@ | @@ -7,9 +7,9 @@ | ||
| 7 | 7 | ||
| 8 | #include <string> | 8 | #include <string> |
| 9 | 9 | ||
| 10 | -#include "sherpa-onnx/csrc/parse-options.h" | ||
| 11 | -#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | #include "onnxruntime_cxx_api.h" // NOLINT | 10 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 13 | 13 | ||
| 14 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | 15 | ||
| @@ -40,25 +40,23 @@ struct TensorrtConfig { | @@ -40,25 +40,23 @@ struct TensorrtConfig { | ||
| 40 | 40 | ||
| 41 | TensorrtConfig() = default; | 41 | TensorrtConfig() = default; |
| 42 | TensorrtConfig(int64_t trt_max_workspace_size, | 42 | TensorrtConfig(int64_t trt_max_workspace_size, |
| 43 | - int32_t trt_max_partition_iterations, | ||
| 44 | - int32_t trt_min_subgraph_size, | ||
| 45 | - bool trt_fp16_enable, | ||
| 46 | - bool trt_detailed_build_log, | ||
| 47 | - bool trt_engine_cache_enable, | ||
| 48 | - bool trt_timing_cache_enable, | ||
| 49 | - const std::string &trt_engine_cache_path, | ||
| 50 | - const std::string &trt_timing_cache_path, | ||
| 51 | - bool trt_dump_subgraphs) | 43 | + int32_t trt_max_partition_iterations, |
| 44 | + int32_t trt_min_subgraph_size, bool trt_fp16_enable, | ||
| 45 | + bool trt_detailed_build_log, bool trt_engine_cache_enable, | ||
| 46 | + bool trt_timing_cache_enable, | ||
| 47 | + const std::string &trt_engine_cache_path, | ||
| 48 | + const std::string &trt_timing_cache_path, | ||
| 49 | + bool trt_dump_subgraphs) | ||
| 52 | : trt_max_workspace_size(trt_max_workspace_size), | 50 | : trt_max_workspace_size(trt_max_workspace_size), |
| 53 | - trt_max_partition_iterations(trt_max_partition_iterations), | ||
| 54 | - trt_min_subgraph_size(trt_min_subgraph_size), | ||
| 55 | - trt_fp16_enable(trt_fp16_enable), | ||
| 56 | - trt_detailed_build_log(trt_detailed_build_log), | ||
| 57 | - trt_engine_cache_enable(trt_engine_cache_enable), | ||
| 58 | - trt_timing_cache_enable(trt_timing_cache_enable), | ||
| 59 | - trt_engine_cache_path(trt_engine_cache_path), | ||
| 60 | - trt_timing_cache_path(trt_timing_cache_path), | ||
| 61 | - trt_dump_subgraphs(trt_dump_subgraphs) {} | 51 | + trt_max_partition_iterations(trt_max_partition_iterations), |
| 52 | + trt_min_subgraph_size(trt_min_subgraph_size), | ||
| 53 | + trt_fp16_enable(trt_fp16_enable), | ||
| 54 | + trt_detailed_build_log(trt_detailed_build_log), | ||
| 55 | + trt_engine_cache_enable(trt_engine_cache_enable), | ||
| 56 | + trt_timing_cache_enable(trt_timing_cache_enable), | ||
| 57 | + trt_engine_cache_path(trt_engine_cache_path), | ||
| 58 | + trt_timing_cache_path(trt_timing_cache_path), | ||
| 59 | + trt_dump_subgraphs(trt_dump_subgraphs) {} | ||
| 62 | 60 | ||
| 63 | void Register(ParseOptions *po); | 61 | void Register(ParseOptions *po); |
| 64 | bool Validate() const; | 62 | bool Validate() const; |
| @@ -74,15 +72,15 @@ struct ProviderConfig { | @@ -74,15 +72,15 @@ struct ProviderConfig { | ||
| 74 | // device only used for cuda and trt | 72 | // device only used for cuda and trt |
| 75 | 73 | ||
| 76 | ProviderConfig() = default; | 74 | ProviderConfig() = default; |
| 77 | - ProviderConfig(const std::string &provider, | ||
| 78 | - int32_t device) | 75 | + ProviderConfig(const std::string &provider, int32_t device) |
| 79 | : provider(provider), device(device) {} | 76 | : provider(provider), device(device) {} |
| 80 | ProviderConfig(const TensorrtConfig &trt_config, | 77 | ProviderConfig(const TensorrtConfig &trt_config, |
| 81 | - const CudaConfig &cuda_config, | ||
| 82 | - const std::string &provider, | ||
| 83 | - int32_t device) | ||
| 84 | - : trt_config(trt_config), cuda_config(cuda_config), | ||
| 85 | - provider(provider), device(device) {} | 78 | + const CudaConfig &cuda_config, const std::string &provider, |
| 79 | + int32_t device) | ||
| 80 | + : trt_config(trt_config), | ||
| 81 | + cuda_config(cuda_config), | ||
| 82 | + provider(provider), | ||
| 83 | + device(device) {} | ||
| 86 | 84 | ||
| 87 | void Register(ParseOptions *po); | 85 | void Register(ParseOptions *po); |
| 88 | bool Validate() const; | 86 | bool Validate() const; |
-
请 注册 或 登录 后发表评论