Add config for TensorRT and CUDA execution provider (#992)
Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
Showing 21 changed files with 622 additions and 49 deletions.
@@ -73,7 +73,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
       SHERPA_ONNX_OR(config->model_config.tokens, "");
   recognizer_config.model_config.num_threads =
       SHERPA_ONNX_OR(config->model_config.num_threads, 1);
-  recognizer_config.model_config.provider =
+  recognizer_config.model_config.provider_config.provider =
       SHERPA_ONNX_OR(config->model_config.provider, "cpu");
   recognizer_config.model_config.model_type =
       SHERPA_ONNX_OR(config->model_config.model_type, "");
@@ -570,7 +570,7 @@ SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
       SHERPA_ONNX_OR(config->model_config.tokens, "");
   spotter_config.model_config.num_threads =
       SHERPA_ONNX_OR(config->model_config.num_threads, 1);
-  spotter_config.model_config.provider =
+  spotter_config.model_config.provider_config.provider =
       SHERPA_ONNX_OR(config->model_config.provider, "cpu");
   spotter_config.model_config.model_type =
       SHERPA_ONNX_OR(config->model_config.model_type, "");
@@ -16,6 +16,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   wenet_ctc.Register(po);
   zipformer2_ctc.Register(po);
   nemo_ctc.Register(po);
+  provider_config.Register(po);

   po->Register("tokens", &tokens, "Path to tokens.txt");

@@ -29,9 +30,6 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   po->Register("debug", &debug,
                "true to print model information while loading it.");

-  po->Register("provider", &provider,
-               "Specify a provider to use: cpu, cuda, coreml");
-
   po->Register("modeling-unit", &modeling_unit,
                "The modeling unit of the model, commonly used units are bpe, "
                "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
@@ -87,6 +85,10 @@ bool OnlineModelConfig::Validate() const {
     return nemo_ctc.Validate();
   }

+  if (!provider_config.Validate()) {
+    return false;
+  }
+
   return transducer.Validate();
 }

@@ -99,11 +101,11 @@ std::string OnlineModelConfig::ToString() const {
   os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
   os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
   os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
+  os << "provider_config=" << provider_config.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "warm_up=" << warm_up << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
-  os << "provider=\"" << provider << "\", ";
   os << "model_type=\"" << model_type << "\", ";
   os << "modeling_unit=\"" << modeling_unit << "\", ";
   os << "bpe_vocab=\"" << bpe_vocab << "\")";
@@ -11,6 +11,7 @@
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
 #include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
+#include "sherpa-onnx/csrc/provider-config.h"

 namespace sherpa_onnx {

@@ -20,11 +21,11 @@ struct OnlineModelConfig {
   OnlineWenetCtcModelConfig wenet_ctc;
   OnlineZipformer2CtcModelConfig zipformer2_ctc;
   OnlineNeMoCtcModelConfig nemo_ctc;
+  ProviderConfig provider_config;
   std::string tokens;
   int32_t num_threads = 1;
   int32_t warm_up = 0;
   bool debug = false;
-  std::string provider = "cpu";

   // Valid values:
   // - conformer, conformer transducer from icefall
@@ -50,8 +51,9 @@ struct OnlineModelConfig {
                     const OnlineWenetCtcModelConfig &wenet_ctc,
                     const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
                     const OnlineNeMoCtcModelConfig &nemo_ctc,
+                    const ProviderConfig &provider_config,
                     const std::string &tokens, int32_t num_threads,
-                    int32_t warm_up, bool debug, const std::string &provider,
+                    int32_t warm_up, bool debug,
                     const std::string &model_type,
                     const std::string &modeling_unit,
                     const std::string &bpe_vocab)
@@ -60,11 +62,11 @@ struct OnlineModelConfig {
         wenet_ctc(wenet_ctc),
         zipformer2_ctc(zipformer2_ctc),
         nemo_ctc(nemo_ctc),
+        provider_config(provider_config),
         tokens(tokens),
         num_threads(num_threads),
         warm_up(warm_up),
         debug(debug),
-        provider(provider),
         model_type(model_type),
         modeling_unit(modeling_unit),
         bpe_vocab(bpe_vocab) {}
sherpa-onnx/csrc/provider-config.cc
0 → 100644
+// sherpa-onnx/csrc/provider-config.cc
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela)
+
+#include "sherpa-onnx/csrc/provider-config.h"
+
+#include <sstream>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void CudaConfig::Register(ParseOptions *po) {
+  po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search,
+               "CuDNN convolution algorithm search");
+}
+
+bool CudaConfig::Validate() const {
+  if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) {
+    SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option. "
+                     "Options: [1,3]. Check the onnxruntime docs.",
+                     cudnn_conv_algo_search);
+    return false;
+  }
+  return true;
+}
+
+std::string CudaConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "CudaConfig(";
+  os << "cudnn_conv_algo_search=" << cudnn_conv_algo_search << ")";
+
+  return os.str();
+}
+
+void TensorrtConfig::Register(ParseOptions *po) {
+  po->Register("trt-max-workspace-size", &trt_max_workspace_size,
+               "Set TensorRT EP GPU memory usage limit.");
+  po->Register("trt-max-partition-iterations", &trt_max_partition_iterations,
+               "Limit partitioning iterations for model conversion.");
+  po->Register("trt-min-subgraph-size", &trt_min_subgraph_size,
+               "Set minimum size for subgraphs in partitioning.");
+  po->Register("trt-fp16-enable", &trt_fp16_enable,
+               "Enable FP16 precision for faster performance.");
+  po->Register("trt-detailed-build-log", &trt_detailed_build_log,
+               "Enable detailed logging of build steps.");
+  po->Register("trt-engine-cache-enable", &trt_engine_cache_enable,
+               "Enable caching of TensorRT engines.");
+  po->Register("trt-timing-cache-enable", &trt_timing_cache_enable,
+               "Enable use of timing cache to speed up builds.");
+  po->Register("trt-engine-cache-path", &trt_engine_cache_path,
+               "Set path to store cached TensorRT engines.");
+  po->Register("trt-timing-cache-path", &trt_timing_cache_path,
+               "Set path for storing timing cache.");
+  po->Register("trt-dump-subgraphs", &trt_dump_subgraphs,
+               "Dump optimized subgraphs for debugging.");
+}
+
+bool TensorrtConfig::Validate() const {
+  if (trt_max_workspace_size < 0) {
+    SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.",
+                     trt_max_workspace_size);
+    return false;
+  }
+  if (trt_max_partition_iterations < 0) {
+    SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.",
+                     trt_max_partition_iterations);
+    return false;
+  }
+  if (trt_min_subgraph_size < 0) {
+    SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.",
+                     trt_min_subgraph_size);
+    return false;
+  }
+
+  return true;
+}
+
+std::string TensorrtConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "TensorrtConfig(";
+  os << "trt_max_workspace_size=" << trt_max_workspace_size << ", ";
+  os << "trt_max_partition_iterations="
+     << trt_max_partition_iterations << ", ";
+  os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", ";
+  os << "trt_fp16_enable=\""
+     << (trt_fp16_enable ? "True" : "False") << "\", ";
+  os << "trt_detailed_build_log=\""
+     << (trt_detailed_build_log ? "True" : "False") << "\", ";
+  os << "trt_engine_cache_enable=\""
+     << (trt_engine_cache_enable ? "True" : "False") << "\", ";
+  os << "trt_engine_cache_path=\""
+     << trt_engine_cache_path.c_str() << "\", ";
+  os << "trt_timing_cache_enable=\""
+     << (trt_timing_cache_enable ? "True" : "False") << "\", ";
+  os << "trt_timing_cache_path=\""
+     << trt_timing_cache_path.c_str() << "\", ";
+  os << "trt_dump_subgraphs=\""
+     << (trt_dump_subgraphs ? "True" : "False") << "\")";
+  return os.str();
+}
+
+void ProviderConfig::Register(ParseOptions *po) {
+  cuda_config.Register(po);
+  trt_config.Register(po);
+
+  po->Register("device", &device,
+               "GPU device index for the CUDA and TensorRT EPs");
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
+}
+
+bool ProviderConfig::Validate() const {
+  if (device < 0) {
+    SHERPA_ONNX_LOGE("device: '%d' is invalid.", device);
+    return false;
+  }
+
+  if (provider == "cuda" && !cuda_config.Validate()) {
+    return false;
+  }
+
+  if (provider == "trt" && !trt_config.Validate()) {
+    return false;
+  }
+
+  return true;
+}
+
+std::string ProviderConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "ProviderConfig(";
+  os << "device=" << device << ", ";
+  os << "provider=\"" << provider << "\", ";
+  os << "cuda_config=" << cuda_config.ToString() << ", ";
+  os << "trt_config=" << trt_config.ToString() << ")";
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
sherpa-onnx/csrc/provider-config.h
0 → 100644
+// sherpa-onnx/csrc/provider-config.h
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela)
+
+#ifndef SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
+#define SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "onnxruntime_cxx_api.h"  // NOLINT
+
+namespace sherpa_onnx {
+
+struct CudaConfig {
+  int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
+
+  CudaConfig() = default;
+  explicit CudaConfig(int32_t cudnn_conv_algo_search)
+      : cudnn_conv_algo_search(cudnn_conv_algo_search) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+struct TensorrtConfig {
+  int32_t trt_max_workspace_size = 2147483647;
+  int32_t trt_max_partition_iterations = 10;
+  int32_t trt_min_subgraph_size = 5;
+  bool trt_fp16_enable = true;
+  bool trt_detailed_build_log = false;
+  bool trt_engine_cache_enable = true;
+  bool trt_timing_cache_enable = true;
+  std::string trt_engine_cache_path = ".";
+  std::string trt_timing_cache_path = ".";
+  bool trt_dump_subgraphs = false;
+
+  TensorrtConfig() = default;
+  TensorrtConfig(int32_t trt_max_workspace_size,
+                 int32_t trt_max_partition_iterations,
+                 int32_t trt_min_subgraph_size,
+                 bool trt_fp16_enable,
+                 bool trt_detailed_build_log,
+                 bool trt_engine_cache_enable,
+                 bool trt_timing_cache_enable,
+                 const std::string &trt_engine_cache_path,
+                 const std::string &trt_timing_cache_path,
+                 bool trt_dump_subgraphs)
+      : trt_max_workspace_size(trt_max_workspace_size),
+        trt_max_partition_iterations(trt_max_partition_iterations),
+        trt_min_subgraph_size(trt_min_subgraph_size),
+        trt_fp16_enable(trt_fp16_enable),
+        trt_detailed_build_log(trt_detailed_build_log),
+        trt_engine_cache_enable(trt_engine_cache_enable),
+        trt_timing_cache_enable(trt_timing_cache_enable),
+        trt_engine_cache_path(trt_engine_cache_path),
+        trt_timing_cache_path(trt_timing_cache_path),
+        trt_dump_subgraphs(trt_dump_subgraphs) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+struct ProviderConfig {
+  TensorrtConfig trt_config;
+  CudaConfig cuda_config;
+  std::string provider = "cpu";
+  int32_t device = 0;  // used only by the CUDA and TensorRT EPs
+
+  ProviderConfig() = default;
+  ProviderConfig(const std::string &provider, int32_t device)
+      : provider(provider), device(device) {}
+  ProviderConfig(const TensorrtConfig &trt_config,
+                 const CudaConfig &cuda_config,
+                 const std::string &provider,
+                 int32_t device)
+      : trt_config(trt_config), cuda_config(cuda_config),
+        provider(provider), device(device) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
@@ -32,11 +32,13 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
 }

 static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
-                                                 std::string provider_str) {
-  Provider p = StringToProvider(std::move(provider_str));
+    const std::string &provider_str,
+    const ProviderConfig *provider_config = nullptr) {
+  Provider p = StringToProvider(provider_str);

   Ort::SessionOptions sess_opts;
   sess_opts.SetIntraOpNumThreads(num_threads);
+
   sess_opts.SetInterOpNumThreads(num_threads);

   std::vector<std::string> available_providers = Ort::GetAvailableProviders();
@@ -64,26 +66,51 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
       break;
     }
     case Provider::kTRT: {
+      if (provider_config == nullptr) {
+        SHERPA_ONNX_LOGE("TensorRT is supported for online models only; "
+                         "it must be extended for offline and others");
+        exit(1);
+      }
+      auto trt_config = provider_config->trt_config;
       struct TrtPairs {
         const char *op_keys;
         const char *op_values;
       };

+      auto device_id = std::to_string(provider_config->device);
+      auto trt_max_workspace_size =
+          std::to_string(trt_config.trt_max_workspace_size);
+      auto trt_max_partition_iterations =
+          std::to_string(trt_config.trt_max_partition_iterations);
+      auto trt_min_subgraph_size =
+          std::to_string(trt_config.trt_min_subgraph_size);
+      auto trt_fp16_enable =
+          std::to_string(trt_config.trt_fp16_enable);
+      auto trt_detailed_build_log =
+          std::to_string(trt_config.trt_detailed_build_log);
+      auto trt_engine_cache_enable =
+          std::to_string(trt_config.trt_engine_cache_enable);
+      auto trt_timing_cache_enable =
+          std::to_string(trt_config.trt_timing_cache_enable);
+      auto trt_dump_subgraphs =
+          std::to_string(trt_config.trt_dump_subgraphs);
+
       std::vector<TrtPairs> trt_options = {
-          {"device_id", "0"},
-          {"trt_max_workspace_size", "2147483648"},
-          {"trt_max_partition_iterations", "10"},
-          {"trt_min_subgraph_size", "5"},
-          {"trt_fp16_enable", "0"},
-          {"trt_detailed_build_log", "0"},
-          {"trt_engine_cache_enable", "1"},
-          {"trt_engine_cache_path", "."},
-          {"trt_timing_cache_enable", "1"},
-          {"trt_timing_cache_path", "."}};
+          {"device_id", device_id.c_str()},
+          {"trt_max_workspace_size", trt_max_workspace_size.c_str()},
+          {"trt_max_partition_iterations",
+           trt_max_partition_iterations.c_str()},
+          {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
+          {"trt_fp16_enable", trt_fp16_enable.c_str()},
+          {"trt_detailed_build_log", trt_detailed_build_log.c_str()},
+          {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
+          {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
+          {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
+          {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
+          {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}};
       // ToDo : Trt configs
       // "trt_int8_enable"
       // "trt_int8_use_native_calibration_table"
-      // "trt_dump_subgraphs"

       std::vector<const char *> option_keys, option_values;
       for (const TrtPairs &pair : trt_options) {
@@ -122,10 +149,18 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
                     "CUDAExecutionProvider") != available_providers.end()) {
         // The CUDA provider is available, proceed with setting the options
         OrtCUDAProviderOptions options;
-        options.device_id = 0;
-        // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
-        options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
-        // set more options on need
+
+        if (provider_config != nullptr) {
+          options.device_id = provider_config->device;
+          options.cudnn_conv_algo_search =
+              OrtCudnnConvAlgoSearch(provider_config->cuda_config
+                                         .cudnn_conv_algo_search);
+        } else {
+          options.device_id = 0;
+          // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
+          options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
+          // set more options on need
+        }
         sess_opts.AppendExecutionProvider_CUDA(options);
       } else {
         SHERPA_ONNX_LOGE(
@@ -184,7 +219,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
 }

 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
-  return GetSessionOptionsImpl(config.num_threads, config.provider);
+  return GetSessionOptionsImpl(config.num_threads,
+      config.provider_config.provider, &config.provider_config);
 }

 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
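For reference, the key/value pairs built above are the same strings onnxruntime accepts for the TensorRT EP through its Python API. A minimal illustrative sketch, not part of this PR ("model.onnx" is a placeholder path):

```python
# Illustrative only: equivalent TensorRT EP options passed through
# onnxruntime's Python API instead of sherpa-onnx's session-options.cc.
import onnxruntime as ort

sess = ort.InferenceSession(
    "model.onnx",  # placeholder
    providers=[
        ("TensorrtExecutionProvider", {
            "device_id": 0,
            "trt_fp16_enable": True,
            "trt_engine_cache_enable": True,
            "trt_engine_cache_path": ".",
        }),
        "CPUExecutionProvider",  # fallback if TensorRT is unavailable
    ],
)
```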
@@ -94,7 +94,7 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
   fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(model_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
-  ans.model_config.provider = p;
+  ans.model_config.provider_config.provider = p;
   env->ReleaseStringUTFChars(s, p);

   fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
@@ -198,7 +198,7 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
   fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(model_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
-  ans.model_config.provider = p;
+  ans.model_config.provider_config.provider = p;
   env->ReleaseStringUTFChars(s, p);

   fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
 set(srcs
   audio-tagging.cc
   circular-buffer.cc
+  cuda-config.cc
   display.cc
   endpoint.cc
   features.cc
@@ -30,11 +31,13 @@ set(srcs
   online-transducer-model-config.cc
   online-wenet-ctc-model-config.cc
   online-zipformer2-ctc-model-config.cc
+  provider-config.cc
   sherpa-onnx.cc
   silero-vad-model-config.cc
   speaker-embedding-extractor.cc
   speaker-embedding-manager.cc
   spoken-language-identification.cc
+  tensorrt-config.cc
   vad-model-config.cc
   vad-model.cc
   voice-activity-detector.cc
sherpa-onnx/python/csrc/cuda-config.cc
0 → 100644
+// sherpa-onnx/python/csrc/cuda-config.cc
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#include "sherpa-onnx/python/csrc/cuda-config.h"
+
+#include <memory>
+#include <string>
+
+#include "sherpa-onnx/csrc/provider-config.h"
+
+namespace sherpa_onnx {
+
+void PybindCudaConfig(py::module *m) {
+  using PyClass = CudaConfig;
+  py::class_<PyClass>(*m, "CudaConfig")
+      .def(py::init<>())
+      .def(py::init<int32_t>(),
+           py::arg("cudnn_conv_algo_search") = 1)
+      .def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
sherpa-onnx/python/csrc/cuda-config.h
0 → 100644
+// sherpa-onnx/python/csrc/cuda-config.h
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindCudaConfig(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
@@ -9,11 +9,13 @@

 #include "sherpa-onnx/csrc/online-model-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/csrc/provider-config.h"
 #include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
 #include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
+#include "sherpa-onnx/python/csrc/provider-config.h"

 namespace sherpa_onnx {

@@ -23,6 +25,7 @@ void PybindOnlineModelConfig(py::module *m) {
   PybindOnlineWenetCtcModelConfig(m);
   PybindOnlineZipformer2CtcModelConfig(m);
   PybindOnlineNeMoCtcModelConfig(m);
+  PybindProviderConfig(m);

   using PyClass = OnlineModelConfig;
   py::class_<PyClass>(*m, "OnlineModelConfig")
@@ -30,33 +33,34 @@ void PybindOnlineModelConfig(py::module *m) {
                     const OnlineParaformerModelConfig &,
                     const OnlineWenetCtcModelConfig &,
                     const OnlineZipformer2CtcModelConfig &,
-                    const OnlineNeMoCtcModelConfig &, const std::string &,
-                    int32_t, int32_t, bool, const std::string &,
-                    const std::string &, const std::string &,
+                    const OnlineNeMoCtcModelConfig &,
+                    const ProviderConfig &,
+                    const std::string &, int32_t, int32_t,
+                    bool, const std::string &, const std::string &,
                     const std::string &>(),
           py::arg("transducer") = OnlineTransducerModelConfig(),
           py::arg("paraformer") = OnlineParaformerModelConfig(),
           py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
           py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
-          py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
-          py::arg("num_threads"), py::arg("warm_up") = 0,
-          py::arg("debug") = false, py::arg("provider") = "cpu",
-          py::arg("model_type") = "", py::arg("modeling_unit") = "",
-          py::arg("bpe_vocab") = "")
+          py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(),
+          py::arg("provider_config") = ProviderConfig(),
+          py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
+          py::arg("debug") = false, py::arg("model_type") = "",
+          py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
       .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
       .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
+      .def_readwrite("provider_config", &PyClass::provider_config)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("num_threads", &PyClass::num_threads)
+      .def_readwrite("warm_up", &PyClass::warm_up)
       .def_readwrite("debug", &PyClass::debug)
-      .def_readwrite("provider", &PyClass::provider)
       .def_readwrite("model_type", &PyClass::model_type)
       .def_readwrite("modeling_unit", &PyClass::modeling_unit)
       .def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
       .def("validate", &PyClass::Validate)
       .def("__str__", &PyClass::ToString);
 }
-
 }  // namespace sherpa_onnx
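With the regenerated binding, `OnlineModelConfig` takes a `provider_config` object instead of a bare `provider` string. A minimal sketch of the new surface (keyword names follow the `py::arg()` list above; the transducer keyword names and all file paths are illustrative assumptions):

```python
from _sherpa_onnx import (
    OnlineModelConfig,
    OnlineTransducerModelConfig,
    ProviderConfig,  # bound by PybindProviderConfig, see provider-config.cc below
)

model_config = OnlineModelConfig(
    transducer=OnlineTransducerModelConfig(
        encoder="encoder.onnx",  # placeholder paths
        decoder="decoder.onnx",
        joiner="joiner.onnx",
    ),
    provider_config=ProviderConfig(provider="cuda", device=0),
    tokens="tokens.txt",
    num_threads=2,
)
print(model_config)  # ToString() now reports provider_config=ProviderConfig(...)
```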
sherpa-onnx/python/csrc/provider-config.cc
0 → 100644
+// sherpa-onnx/python/csrc/provider-config.cc
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#include "sherpa-onnx/python/csrc/provider-config.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/provider-config.h"
+#include "sherpa-onnx/python/csrc/cuda-config.h"
+#include "sherpa-onnx/python/csrc/tensorrt-config.h"
+
+namespace sherpa_onnx {
+
+void PybindProviderConfig(py::module *m) {
+  PybindCudaConfig(m);
+  PybindTensorrtConfig(m);
+
+  using PyClass = ProviderConfig;
+  py::class_<PyClass>(*m, "ProviderConfig")
+      .def(py::init<>())
+      .def(py::init<const std::string &, int32_t>(),
+           py::arg("provider") = "cpu",
+           py::arg("device") = 0)
+      .def(py::init<const TensorrtConfig &, const CudaConfig &,
+                    const std::string &, int32_t>(),
+           py::arg("trt_config") = TensorrtConfig{},
+           py::arg("cuda_config") = CudaConfig{},
+           py::arg("provider") = "cpu",
+           py::arg("device") = 0)
+      .def_readwrite("trt_config", &PyClass::trt_config)
+      .def_readwrite("cuda_config", &PyClass::cuda_config)
+      .def_readwrite("provider", &PyClass::provider)
+      .def_readwrite("device", &PyClass::device)
+      .def("__str__", &PyClass::ToString)
+      .def("validate", &PyClass::Validate);
+}
+
+}  // namespace sherpa_onnx
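Together with the `CudaConfig` and `TensorrtConfig` bindings, the whole provider configuration can be assembled from Python. A sketch; keyword names mirror the `py::arg()` declarations above:

```python
from _sherpa_onnx import CudaConfig, ProviderConfig, TensorrtConfig

cuda_config = CudaConfig(cudnn_conv_algo_search=1)  # heuristic search
trt_config = TensorrtConfig(
    trt_fp16_enable=True,
    trt_engine_cache_enable=True,
    trt_engine_cache_path=".",
)

provider_config = ProviderConfig(
    trt_config=trt_config,
    cuda_config=cuda_config,
    provider="cuda",
    device=0,
)
assert provider_config.validate()
print(provider_config)  # ProviderConfig(device=0, provider="cuda", ...)
```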
sherpa-onnx/python/csrc/provider-config.h
0 → 100644
+// sherpa-onnx/python/csrc/provider-config.h
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindProviderConfig(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
@@ -51,7 +51,6 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
   PybindEndpoint(&m);
   PybindOnlineRecognizer(&m);
   PybindKeywordSpotter(&m);
-
   PybindDisplay(&m);

   PybindOfflineStream(&m);
sherpa-onnx/python/csrc/tensorrt-config.cc
0 → 100644
+// sherpa-onnx/python/csrc/tensorrt-config.cc
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#include "sherpa-onnx/python/csrc/tensorrt-config.h"
+
+#include <memory>
+#include <string>
+
+#include "sherpa-onnx/csrc/provider-config.h"
+
+namespace sherpa_onnx {
+
+void PybindTensorrtConfig(py::module *m) {
+  using PyClass = TensorrtConfig;
+  py::class_<PyClass>(*m, "TensorrtConfig")
+      .def(py::init<>())
+      .def(py::init([](int32_t trt_max_workspace_size,
+                       int32_t trt_max_partition_iterations,
+                       int32_t trt_min_subgraph_size,
+                       bool trt_fp16_enable,
+                       bool trt_detailed_build_log,
+                       bool trt_engine_cache_enable,
+                       bool trt_timing_cache_enable,
+                       const std::string &trt_engine_cache_path,
+                       const std::string &trt_timing_cache_path,
+                       bool trt_dump_subgraphs) -> std::unique_ptr<PyClass> {
+             auto ans = std::make_unique<PyClass>();
+
+             ans->trt_max_workspace_size = trt_max_workspace_size;
+             ans->trt_max_partition_iterations = trt_max_partition_iterations;
+             ans->trt_min_subgraph_size = trt_min_subgraph_size;
+             ans->trt_fp16_enable = trt_fp16_enable;
+             ans->trt_detailed_build_log = trt_detailed_build_log;
+             ans->trt_engine_cache_enable = trt_engine_cache_enable;
+             ans->trt_timing_cache_enable = trt_timing_cache_enable;
+             ans->trt_engine_cache_path = trt_engine_cache_path;
+             ans->trt_timing_cache_path = trt_timing_cache_path;
+             ans->trt_dump_subgraphs = trt_dump_subgraphs;
+
+             return ans;
+           }),
+           py::arg("trt_max_workspace_size") = 2147483647,
+           py::arg("trt_max_partition_iterations") = 10,
+           py::arg("trt_min_subgraph_size") = 5,
+           py::arg("trt_fp16_enable") = true,
+           py::arg("trt_detailed_build_log") = false,
+           py::arg("trt_engine_cache_enable") = true,
+           py::arg("trt_timing_cache_enable") = true,
+           py::arg("trt_engine_cache_path") = ".",
+           py::arg("trt_timing_cache_path") = ".",
+           py::arg("trt_dump_subgraphs") = false)
+      .def_readwrite("trt_max_workspace_size",
+                     &PyClass::trt_max_workspace_size)
+      .def_readwrite("trt_max_partition_iterations",
+                     &PyClass::trt_max_partition_iterations)
+      .def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size)
+      .def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable)
+      .def_readwrite("trt_detailed_build_log",
+                     &PyClass::trt_detailed_build_log)
+      .def_readwrite("trt_engine_cache_enable",
+                     &PyClass::trt_engine_cache_enable)
+      .def_readwrite("trt_timing_cache_enable",
+                     &PyClass::trt_timing_cache_enable)
+      .def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path)
+      .def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path)
+      .def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs)
+      .def("__str__", &PyClass::ToString)
+      .def("validate", &PyClass::Validate);
+}
+
+}  // namespace sherpa_onnx
sherpa-onnx/python/csrc/tensorrt-config.h
0 → 100644
+// sherpa-onnx/python/csrc/tensorrt-config.h
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindTensorrtConfig(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_
@@ -9,6 +9,7 @@ from _sherpa_onnx import (
     OnlineModelConfig,
     OnlineTransducerModelConfig,
     OnlineStream,
+    ProviderConfig,
 )

 from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
@@ -41,6 +42,7 @@ class KeywordSpotter(object):
         keywords_threshold: float = 0.25,
         num_trailing_blanks: int = 1,
         provider: str = "cpu",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -85,6 +87,8 @@ class KeywordSpotter(object):
           between each other.
         provider:
           onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+        device:
+          onnxruntime cuda device index.
         """
         _assert_file_exists(tokens)
         _assert_file_exists(encoder)
@@ -99,11 +103,16 @@ class KeywordSpotter(object):
             joiner=joiner,
         )

+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
        model_config = OnlineModelConfig(
            transducer=transducer_config,
            tokens=tokens,
            num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
        )

        feat_config = FeatureExtractorConfig(
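A sketch of how the new `device` argument reaches the spotter from user code, assuming the upstream `sherpa_onnx.KeywordSpotter` wrapper and its existing `keywords_file` argument (not shown in this diff); all paths are placeholders:

```python
import sherpa_onnx

kws = sherpa_onnx.KeywordSpotter(
    tokens="tokens.txt",        # placeholder paths
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    keywords_file="keywords.txt",  # assumed existing argument
    provider="cuda",
    device=1,  # select the second GPU instead of the default 0
)
```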
@@ -11,6 +11,9 @@ from _sherpa_onnx import (
 )
 from _sherpa_onnx import OnlineRecognizer as _Recognizer
 from _sherpa_onnx import (
+    CudaConfig,
+    TensorrtConfig,
+    ProviderConfig,
     OnlineRecognizerConfig,
     OnlineRecognizerResult,
     OnlineStream,
@@ -56,7 +59,6 @@ class OnlineRecognizer(object):
         hotwords_score: float = 1.5,
         blank_penalty: float = 0.0,
         hotwords_file: str = "",
-        provider: str = "cpu",
         model_type: str = "",
         modeling_unit: str = "cjkchar",
         bpe_vocab: str = "",
@@ -66,6 +68,19 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        provider: str = "cpu",
+        device: int = 0,
+        cudnn_conv_algo_search: int = 1,
+        trt_max_workspace_size: int = 2147483647,
+        trt_max_partition_iterations: int = 10,
+        trt_min_subgraph_size: int = 5,
+        trt_fp16_enable: bool = True,
+        trt_detailed_build_log: bool = False,
+        trt_engine_cache_enable: bool = True,
+        trt_timing_cache_enable: bool = True,
+        trt_engine_cache_path: str = "",
+        trt_timing_cache_path: str = "",
+        trt_dump_subgraphs: bool = False,
     ):
         """
         Please refer to
@@ -135,8 +150,6 @@ class OnlineRecognizer(object):
           Temperature scaling for output symbol confidence estimation.
           It affects only confidence values, the decoding uses the original
           logits without temperature.
-        provider:
-          onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
         model_type:
           Online transducer model type. Valid values are: conformer, lstm,
           zipformer, zipformer2. All other values lead to loading the model twice.
@@ -156,6 +169,32 @@ class OnlineRecognizer(object):
         rule_fars:
           If not empty, it specifies fst archives for inverse text normalization.
           If there are multiple archives, they are separated by a comma.
+        provider:
+          onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+        device:
+          onnxruntime cuda device index.
+        cudnn_conv_algo_search:
+          onnxruntime CuDNN convolution search algorithm selection (CUDA EP).
+        trt_max_workspace_size:
+          Set TensorRT EP GPU memory usage limit (TensorRT EP).
+        trt_max_partition_iterations:
+          Limit partitioning iterations for model conversion (TensorRT EP).
+        trt_min_subgraph_size:
+          Set minimum size for subgraphs in partitioning (TensorRT EP).
+        trt_fp16_enable:
+          Enable FP16 precision for faster performance (TensorRT EP).
+        trt_detailed_build_log:
+          Enable detailed logging of build steps (TensorRT EP).
+        trt_engine_cache_enable:
+          Enable caching of TensorRT engines (TensorRT EP).
+        trt_timing_cache_enable:
+          Enable use of timing cache to speed up builds (TensorRT EP).
+        trt_engine_cache_path:
+          Set path to store cached TensorRT engines (TensorRT EP).
+        trt_timing_cache_path:
+          Set path for storing timing cache (TensorRT EP).
+        trt_dump_subgraphs:
+          Dump optimized subgraphs for debugging (TensorRT EP).
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -171,11 +210,35 @@ class OnlineRecognizer(object):
             joiner=joiner,
         )

+        cuda_config = CudaConfig(
+            cudnn_conv_algo_search=cudnn_conv_algo_search,
+        )
+
+        trt_config = TensorrtConfig(
+            trt_max_workspace_size=trt_max_workspace_size,
+            trt_max_partition_iterations=trt_max_partition_iterations,
+            trt_min_subgraph_size=trt_min_subgraph_size,
+            trt_fp16_enable=trt_fp16_enable,
+            trt_detailed_build_log=trt_detailed_build_log,
+            trt_engine_cache_enable=trt_engine_cache_enable,
+            trt_timing_cache_enable=trt_timing_cache_enable,
+            trt_engine_cache_path=trt_engine_cache_path,
+            trt_timing_cache_path=trt_timing_cache_path,
+            trt_dump_subgraphs=trt_dump_subgraphs,
+        )
+
+        provider_config = ProviderConfig(
+            trt_config=trt_config,
+            cuda_config=cuda_config,
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             transducer=transducer_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             model_type=model_type,
             modeling_unit=modeling_unit,
             bpe_vocab=bpe_vocab,
@@ -251,6 +314,7 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -301,6 +365,8 @@ class OnlineRecognizer(object):
         rule_fars:
           If not empty, it specifies fst archives for inverse text normalization.
           If there are multiple archives, they are separated by a comma.
+        device:
+          onnxruntime cuda device index.
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -314,11 +380,16 @@ class OnlineRecognizer(object):
             decoder=decoder,
         )

+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             paraformer=paraformer_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             model_type="paraformer",
             debug=debug,
         )
@@ -367,6 +438,7 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -421,6 +493,8 @@ class OnlineRecognizer(object):
         rule_fars:
           If not empty, it specifies fst archives for inverse text normalization.
           If there are multiple archives, they are separated by a comma.
+        device:
+          onnxruntime cuda device index.
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -430,11 +504,16 @@ class OnlineRecognizer(object):

         zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)

+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             zipformer2_ctc=zipformer2_ctc_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             debug=debug,
         )

@@ -486,6 +565,7 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -535,6 +615,8 @@ class OnlineRecognizer(object):
         rule_fars:
           If not empty, it specifies fst archives for inverse text normalization.
           If there are multiple archives, they are separated by a comma.
+        device:
+          onnxruntime cuda device index.
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -546,11 +628,16 @@ class OnlineRecognizer(object):
             model=model,
         )

+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             nemo_ctc=nemo_ctc_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             debug=debug,
         )

@@ -598,6 +685,7 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -650,6 +738,8 @@ class OnlineRecognizer(object):
         rule_fars:
           If not empty, it specifies fst archives for inverse text normalization.
           If there are multiple archives, they are separated by a comma.
+        device:
+          onnxruntime cuda device index.
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -663,11 +753,16 @@ class OnlineRecognizer(object):
             num_left_chunks=num_left_chunks,
         )

+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             wenet_ctc=wenet_ctc_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             debug=debug,
         )

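End-to-end, the new arguments let a TensorRT session be configured entirely from the high-level API. A sketch, assuming the transducer constructor shown above is the upstream `from_transducer` classmethod and using placeholder model paths:

```python
import sherpa_onnx

recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens="tokens.txt",          # placeholder paths
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    provider="trt",               # TensorRT EP
    device=0,
    trt_fp16_enable=True,
    trt_engine_cache_enable=True,
    trt_engine_cache_path="./trt_cache",
    trt_timing_cache_path="./trt_cache",
)
```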