Manix
Committed by GitHub

Add config for TensorRT and CUDA execution provider (#992)

Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
@@ -73,7 +73,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
       SHERPA_ONNX_OR(config->model_config.tokens, "");
   recognizer_config.model_config.num_threads =
       SHERPA_ONNX_OR(config->model_config.num_threads, 1);
-  recognizer_config.model_config.provider =
+  recognizer_config.model_config.provider_config.provider =
       SHERPA_ONNX_OR(config->model_config.provider, "cpu");
   recognizer_config.model_config.model_type =
       SHERPA_ONNX_OR(config->model_config.model_type, "");
@@ -570,7 +570,7 @@ SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
       SHERPA_ONNX_OR(config->model_config.tokens, "");
   spotter_config.model_config.num_threads =
       SHERPA_ONNX_OR(config->model_config.num_threads, 1);
-  spotter_config.model_config.provider =
+  spotter_config.model_config.provider_config.provider =
       SHERPA_ONNX_OR(config->model_config.provider, "cpu");
   spotter_config.model_config.model_type =
       SHERPA_ONNX_OR(config->model_config.model_type, "");
@@ -87,6 +87,7 @@ set(sources
   packed-sequence.cc
   pad-sequence.cc
   parse-options.cc
+  provider-config.cc
   provider.cc
   resample.cc
   session.cc
@@ -16,6 +16,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   wenet_ctc.Register(po);
   zipformer2_ctc.Register(po);
   nemo_ctc.Register(po);
+  provider_config.Register(po);
 
   po->Register("tokens", &tokens, "Path to tokens.txt");
 
@@ -29,9 +30,6 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   po->Register("debug", &debug,
                "true to print model information while loading it.");
 
-  po->Register("provider", &provider,
-               "Specify a provider to use: cpu, cuda, coreml");
-
   po->Register("modeling-unit", &modeling_unit,
                "The modeling unit of the model, commonly used units are bpe, "
                "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
@@ -87,6 +85,10 @@ bool OnlineModelConfig::Validate() const {
     return nemo_ctc.Validate();
   }
 
+  if (!provider_config.Validate()) {
+    return false;
+  }
+
   return transducer.Validate();
 }
 
@@ -99,11 +101,11 @@ std::string OnlineModelConfig::ToString() const {
   os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
   os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
   os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
+  os << "provider_config=" << provider_config.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
   os << "warm_up=" << warm_up << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
-  os << "provider=\"" << provider << "\", ";
   os << "model_type=\"" << model_type << "\", ";
   os << "modeling_unit=\"" << modeling_unit << "\", ";
   os << "bpe_vocab=\"" << bpe_vocab << "\")";
@@ -11,6 +11,7 @@
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
 #include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
+#include "sherpa-onnx/csrc/provider-config.h"
 
 namespace sherpa_onnx {
 
@@ -20,11 +21,11 @@ struct OnlineModelConfig {
   OnlineWenetCtcModelConfig wenet_ctc;
   OnlineZipformer2CtcModelConfig zipformer2_ctc;
   OnlineNeMoCtcModelConfig nemo_ctc;
+  ProviderConfig provider_config;
   std::string tokens;
   int32_t num_threads = 1;
   int32_t warm_up = 0;
   bool debug = false;
-  std::string provider = "cpu";
 
   // Valid values:
   //  - conformer, conformer transducer from icefall
@@ -50,8 +51,9 @@ struct OnlineModelConfig {
                     const OnlineWenetCtcModelConfig &wenet_ctc,
                     const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
                     const OnlineNeMoCtcModelConfig &nemo_ctc,
+                    const ProviderConfig &provider_config,
                     const std::string &tokens, int32_t num_threads,
-                    int32_t warm_up, bool debug, const std::string &provider,
+                    int32_t warm_up, bool debug,
                     const std::string &model_type,
                     const std::string &modeling_unit,
                     const std::string &bpe_vocab)
@@ -60,11 +62,11 @@ struct OnlineModelConfig {
         wenet_ctc(wenet_ctc),
         zipformer2_ctc(zipformer2_ctc),
         nemo_ctc(nemo_ctc),
+        provider_config(provider_config),
         tokens(tokens),
         num_threads(num_threads),
         warm_up(warm_up),
         debug(debug),
-        provider(provider),
        model_type(model_type),
        modeling_unit(modeling_unit),
        bpe_vocab(bpe_vocab) {}
+// sherpa-onnx/csrc/provider-config.cc
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela)
+
+#include "sherpa-onnx/csrc/provider-config.h"
+
+#include <sstream>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void CudaConfig::Register(ParseOptions *po) {
+  po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search,
+               "CuDNN convolution algorithm search");
+}
+
+bool CudaConfig::Validate() const {
+  if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) {
+    SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option. "
+                     "Valid options: [1,3]. See the onnxruntime docs.",
+                     cudnn_conv_algo_search);
+    return false;
+  }
+  return true;
+}
+
+std::string CudaConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "CudaConfig(";
+  os << "cudnn_conv_algo_search=" << cudnn_conv_algo_search << ")";
+
+  return os.str();
+}
+
+void TensorrtConfig::Register(ParseOptions *po) {
+  po->Register("trt-max-workspace-size", &trt_max_workspace_size,
+               "Set TensorRT EP GPU memory usage limit.");
+  po->Register("trt-max-partition-iterations", &trt_max_partition_iterations,
+               "Limit partitioning iterations for model conversion.");
+  po->Register("trt-min-subgraph-size", &trt_min_subgraph_size,
+               "Set minimum size for subgraphs in partitioning.");
+  po->Register("trt-fp16-enable", &trt_fp16_enable,
+               "Enable FP16 precision for faster performance.");
+  po->Register("trt-detailed-build-log", &trt_detailed_build_log,
+               "Enable detailed logging of build steps.");
+  po->Register("trt-engine-cache-enable", &trt_engine_cache_enable,
+               "Enable caching of TensorRT engines.");
+  po->Register("trt-timing-cache-enable", &trt_timing_cache_enable,
+               "Enable use of timing cache to speed up builds.");
+  po->Register("trt-engine-cache-path", &trt_engine_cache_path,
+               "Set path to store cached TensorRT engines.");
+  po->Register("trt-timing-cache-path", &trt_timing_cache_path,
+               "Set path for storing timing cache.");
+  po->Register("trt-dump-subgraphs", &trt_dump_subgraphs,
+               "Dump optimized subgraphs for debugging.");
+}
+
+bool TensorrtConfig::Validate() const {
+  if (trt_max_workspace_size < 0) {
+    SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.",
+                     trt_max_workspace_size);
+    return false;
+  }
+  if (trt_max_partition_iterations < 0) {
+    SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.",
+                     trt_max_partition_iterations);
+    return false;
+  }
+  if (trt_min_subgraph_size < 0) {
+    SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.",
+                     trt_min_subgraph_size);
+    return false;
+  }
+
+  return true;
+}
+
+std::string TensorrtConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "TensorrtConfig(";
+  os << "trt_max_workspace_size=" << trt_max_workspace_size << ", ";
+  os << "trt_max_partition_iterations=" << trt_max_partition_iterations
+     << ", ";
+  os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", ";
+  os << "trt_fp16_enable=\"" << (trt_fp16_enable ? "True" : "False")
+     << "\", ";
+  os << "trt_detailed_build_log=\""
+     << (trt_detailed_build_log ? "True" : "False") << "\", ";
+  os << "trt_engine_cache_enable=\""
+     << (trt_engine_cache_enable ? "True" : "False") << "\", ";
+  os << "trt_engine_cache_path=\"" << trt_engine_cache_path << "\", ";
+  os << "trt_timing_cache_enable=\""
+     << (trt_timing_cache_enable ? "True" : "False") << "\", ";
+  os << "trt_timing_cache_path=\"" << trt_timing_cache_path << "\", ";
+  os << "trt_dump_subgraphs=\"" << (trt_dump_subgraphs ? "True" : "False")
+     << "\")";
+  return os.str();
+}
+
+void ProviderConfig::Register(ParseOptions *po) {
+  cuda_config.Register(po);
+  trt_config.Register(po);
+
+  po->Register("device", &device,
+               "GPU device index for the CUDA and TensorRT EPs");
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
+}
+
+bool ProviderConfig::Validate() const {
+  if (device < 0) {
+    SHERPA_ONNX_LOGE("device: '%d' is invalid.", device);
+    return false;
+  }
+
+  if (provider == "cuda" && !cuda_config.Validate()) {
+    return false;
+  }
+
+  if (provider == "trt" && !trt_config.Validate()) {
+    return false;
+  }
+
+  return true;
+}
+
+std::string ProviderConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "ProviderConfig(";
+  os << "device=" << device << ", ";
+  os << "provider=\"" << provider << "\", ";
+  os << "cuda_config=" << cuda_config.ToString() << ", ";
+  os << "trt_config=" << trt_config.ToString() << ")";
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
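Taken together, the Register() calls above surface the new settings as command-line flags on every binary that parses an OnlineModelConfig: --provider and --device, --cuda-cudnn-conv-algo-search for the CUDA EP, and the --trt-* family for the TensorRT EP, with defaults taken from the header below.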
+// sherpa-onnx/csrc/provider-config.h
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela)
+
+#ifndef SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
+#define SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+#include "onnxruntime_cxx_api.h"  // NOLINT
+
+namespace sherpa_onnx {
+
+struct CudaConfig {
+  int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
+
+  CudaConfig() = default;
+  explicit CudaConfig(int32_t cudnn_conv_algo_search)
+      : cudnn_conv_algo_search(cudnn_conv_algo_search) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+struct TensorrtConfig {
+  int32_t trt_max_workspace_size = 2147483647;
+  int32_t trt_max_partition_iterations = 10;
+  int32_t trt_min_subgraph_size = 5;
+  bool trt_fp16_enable = true;
+  bool trt_detailed_build_log = false;
+  bool trt_engine_cache_enable = true;
+  bool trt_timing_cache_enable = true;
+  std::string trt_engine_cache_path = ".";
+  std::string trt_timing_cache_path = ".";
+  bool trt_dump_subgraphs = false;
+
+  TensorrtConfig() = default;
+  TensorrtConfig(int32_t trt_max_workspace_size,
+                 int32_t trt_max_partition_iterations,
+                 int32_t trt_min_subgraph_size,
+                 bool trt_fp16_enable,
+                 bool trt_detailed_build_log,
+                 bool trt_engine_cache_enable,
+                 bool trt_timing_cache_enable,
+                 const std::string &trt_engine_cache_path,
+                 const std::string &trt_timing_cache_path,
+                 bool trt_dump_subgraphs)
+      : trt_max_workspace_size(trt_max_workspace_size),
+        trt_max_partition_iterations(trt_max_partition_iterations),
+        trt_min_subgraph_size(trt_min_subgraph_size),
+        trt_fp16_enable(trt_fp16_enable),
+        trt_detailed_build_log(trt_detailed_build_log),
+        trt_engine_cache_enable(trt_engine_cache_enable),
+        trt_timing_cache_enable(trt_timing_cache_enable),
+        trt_engine_cache_path(trt_engine_cache_path),
+        trt_timing_cache_path(trt_timing_cache_path),
+        trt_dump_subgraphs(trt_dump_subgraphs) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+struct ProviderConfig {
+  TensorrtConfig trt_config;
+  CudaConfig cuda_config;
+  std::string provider = "cpu";
+  int32_t device = 0;  // device is used only by the cuda and trt providers
+
+  ProviderConfig() = default;
+  ProviderConfig(const std::string &provider, int32_t device)
+      : provider(provider), device(device) {}
+  ProviderConfig(const TensorrtConfig &trt_config,
+                 const CudaConfig &cuda_config,
+                 const std::string &provider, int32_t device)
+      : trt_config(trt_config),
+        cuda_config(cuda_config),
+        provider(provider),
+        device(device) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
@@ -7,6 +7,7 @@
 
 #include <string>
 
+#include "sherpa-onnx/csrc/provider-config.h"
 namespace sherpa_onnx {
 
 // Please refer to
@@ -32,11 +32,13 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
 }
 
 static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
-                                                 std::string provider_str) {
-  Provider p = StringToProvider(std::move(provider_str));
+    const std::string &provider_str,
+    const ProviderConfig *provider_config = nullptr) {
+  Provider p = StringToProvider(provider_str);
 
   Ort::SessionOptions sess_opts;
   sess_opts.SetIntraOpNumThreads(num_threads);
+
   sess_opts.SetInterOpNumThreads(num_threads);
 
   std::vector<std::string> available_providers = Ort::GetAvailableProviders();
@@ -64,26 +66,51 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
       break;
     }
     case Provider::kTRT: {
+      if (provider_config == nullptr) {
+        SHERPA_ONNX_LOGE(
+            "TensorRT is supported only for online models for now; "
+            "it must be extended for offline models and others.");
+        exit(1);
+      }
+      auto trt_config = provider_config->trt_config;
       struct TrtPairs {
         const char *op_keys;
         const char *op_values;
       };
 
+      auto device_id = std::to_string(provider_config->device);
+      auto trt_max_workspace_size =
+          std::to_string(trt_config.trt_max_workspace_size);
+      auto trt_max_partition_iterations =
+          std::to_string(trt_config.trt_max_partition_iterations);
+      auto trt_min_subgraph_size =
+          std::to_string(trt_config.trt_min_subgraph_size);
+      auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable);
+      auto trt_detailed_build_log =
+          std::to_string(trt_config.trt_detailed_build_log);
+      auto trt_engine_cache_enable =
+          std::to_string(trt_config.trt_engine_cache_enable);
+      auto trt_timing_cache_enable =
+          std::to_string(trt_config.trt_timing_cache_enable);
+      auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs);
+
       std::vector<TrtPairs> trt_options = {
-          {"device_id", "0"},
-          {"trt_max_workspace_size", "2147483648"},
-          {"trt_max_partition_iterations", "10"},
-          {"trt_min_subgraph_size", "5"},
-          {"trt_fp16_enable", "0"},
-          {"trt_detailed_build_log", "0"},
-          {"trt_engine_cache_enable", "1"},
-          {"trt_engine_cache_path", "."},
-          {"trt_timing_cache_enable", "1"},
-          {"trt_timing_cache_path", "."}};
+          {"device_id", device_id.c_str()},
+          {"trt_max_workspace_size", trt_max_workspace_size.c_str()},
+          {"trt_max_partition_iterations",
+           trt_max_partition_iterations.c_str()},
+          {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
+          {"trt_fp16_enable", trt_fp16_enable.c_str()},
+          {"trt_detailed_build_log", trt_detailed_build_log.c_str()},
+          {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()},
+          {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
+          {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
+          {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
+          {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}
+      };
       // ToDo : Trt configs
       // "trt_int8_enable"
       // "trt_int8_use_native_calibration_table"
-      // "trt_dump_subgraphs"
 
       std::vector<const char *> option_keys, option_values;
       for (const TrtPairs &pair : trt_options) {
@@ -122,10 +149,18 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
                  "CUDAExecutionProvider") != available_providers.end()) {
       // The CUDA provider is available, proceed with setting the options
       OrtCUDAProviderOptions options;
-      options.device_id = 0;
-      // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow
-      options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
-      // set more options on need
+
+      if (provider_config != nullptr) {
+        options.device_id = provider_config->device;
+        options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch(
+            provider_config->cuda_config.cudnn_conv_algo_search);
+      } else {
+        options.device_id = 0;
+        // The default OrtCudnnConvAlgoSearchExhaustive is extremely slow
+        options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
+        // set more options as needed
+      }
       sess_opts.AppendExecutionProvider_CUDA(options);
     } else {
       SHERPA_ONNX_LOGE(
@@ -184,7 +219,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
 }
 
 Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
-  return GetSessionOptionsImpl(config.num_threads, config.provider);
+  return GetSessionOptionsImpl(config.num_threads,
+      config.provider_config.provider, &config.provider_config);
 }
 
 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
@@ -94,7 +94,7 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
   fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(model_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
-  ans.model_config.provider = p;
+  ans.model_config.provider_config.provider = p;
   env->ReleaseStringUTFChars(s, p);
 
   fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
@@ -198,7 +198,7 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
   fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(model_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
-  ans.model_config.provider = p;
+  ans.model_config.provider_config.provider = p;
   env->ReleaseStringUTFChars(s, p);
 
   fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
 set(srcs
   audio-tagging.cc
   circular-buffer.cc
+  cuda-config.cc
   display.cc
   endpoint.cc
   features.cc
@@ -30,11 +31,13 @@ set(srcs
   online-transducer-model-config.cc
   online-wenet-ctc-model-config.cc
   online-zipformer2-ctc-model-config.cc
+  provider-config.cc
   sherpa-onnx.cc
   silero-vad-model-config.cc
   speaker-embedding-extractor.cc
   speaker-embedding-manager.cc
   spoken-language-identification.cc
+  tensorrt-config.cc
   vad-model-config.cc
   vad-model.cc
   voice-activity-detector.cc
+// sherpa-onnx/python/csrc/cuda-config.cc
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#include "sherpa-onnx/python/csrc/cuda-config.h"
+
+#include <memory>
+#include <string>
+
+#include "sherpa-onnx/csrc/provider-config.h"
+
+namespace sherpa_onnx {
+
+void PybindCudaConfig(py::module *m) {
+  using PyClass = CudaConfig;
+  py::class_<PyClass>(*m, "CudaConfig")
+      .def(py::init<>())
+      .def(py::init<int32_t>(), py::arg("cudnn_conv_algo_search") = 1)
+      .def_readwrite("cudnn_conv_algo_search",
+                     &PyClass::cudnn_conv_algo_search)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/python/csrc/cuda-config.h
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindCudaConfig(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
@@ -9,11 +9,13 @@
 
 #include "sherpa-onnx/csrc/online-model-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/csrc/provider-config.h"
 #include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
 #include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
+#include "sherpa-onnx/python/csrc/provider-config.h"
 
 namespace sherpa_onnx {
 
@@ -23,6 +25,7 @@ void PybindOnlineModelConfig(py::module *m) {
   PybindOnlineWenetCtcModelConfig(m);
   PybindOnlineZipformer2CtcModelConfig(m);
   PybindOnlineNeMoCtcModelConfig(m);
+  PybindProviderConfig(m);
 
   using PyClass = OnlineModelConfig;
   py::class_<PyClass>(*m, "OnlineModelConfig")
@@ -30,33 +33,34 @@ void PybindOnlineModelConfig(py::module *m) {
                     const OnlineParaformerModelConfig &,
                     const OnlineWenetCtcModelConfig &,
                     const OnlineZipformer2CtcModelConfig &,
-                    const OnlineNeMoCtcModelConfig &, const std::string &,
-                    int32_t, int32_t, bool, const std::string &,
-                    const std::string &, const std::string &,
+                    const OnlineNeMoCtcModelConfig &,
+                    const ProviderConfig &,
+                    const std::string &, int32_t, int32_t,
+                    bool, const std::string &, const std::string &,
                     const std::string &>(),
           py::arg("transducer") = OnlineTransducerModelConfig(),
           py::arg("paraformer") = OnlineParaformerModelConfig(),
           py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
           py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
-          py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
-          py::arg("num_threads"), py::arg("warm_up") = 0,
-          py::arg("debug") = false, py::arg("provider") = "cpu",
-          py::arg("model_type") = "", py::arg("modeling_unit") = "",
-          py::arg("bpe_vocab") = "")
+          py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(),
+          py::arg("provider_config") = ProviderConfig(),
+          py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
+          py::arg("debug") = false, py::arg("model_type") = "",
+          py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "")
      .def_readwrite("transducer", &PyClass::transducer)
      .def_readwrite("paraformer", &PyClass::paraformer)
      .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
      .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
      .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
+     .def_readwrite("provider_config", &PyClass::provider_config)
      .def_readwrite("tokens", &PyClass::tokens)
      .def_readwrite("num_threads", &PyClass::num_threads)
+     .def_readwrite("warm_up", &PyClass::warm_up)
      .def_readwrite("debug", &PyClass::debug)
-     .def_readwrite("provider", &PyClass::provider)
      .def_readwrite("model_type", &PyClass::model_type)
      .def_readwrite("modeling_unit", &PyClass::modeling_unit)
      .def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
      .def("validate", &PyClass::Validate)
      .def("__str__", &PyClass::ToString);
 }
-
 }  // namespace sherpa_onnx
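With this binding, provider selection moves from the removed provider keyword to a ProviderConfig object. A minimal sketch of the new argument from Python, assuming a locally built _sherpa_onnx extension and hypothetical file paths:

    from _sherpa_onnx import OnlineModelConfig, ProviderConfig

    # provider_config replaces the removed provider="..." keyword
    model_config = OnlineModelConfig(
        tokens="tokens.txt",  # hypothetical path
        num_threads=2,
        provider_config=ProviderConfig(provider="cuda", device=0),
    )
    print(model_config)  # __str__ is bound to OnlineModelConfig::ToString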
+// sherpa-onnx/python/csrc/provider-config.cc
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#include "sherpa-onnx/python/csrc/provider-config.h"
+
+#include <string>
+
+#include "sherpa-onnx/csrc/provider-config.h"
+#include "sherpa-onnx/python/csrc/cuda-config.h"
+#include "sherpa-onnx/python/csrc/tensorrt-config.h"
+
+namespace sherpa_onnx {
+
+void PybindProviderConfig(py::module *m) {
+  PybindCudaConfig(m);
+  PybindTensorrtConfig(m);
+
+  using PyClass = ProviderConfig;
+  py::class_<PyClass>(*m, "ProviderConfig")
+      .def(py::init<>())
+      .def(py::init<const std::string &, int32_t>(),
+           py::arg("provider") = "cpu", py::arg("device") = 0)
+      .def(py::init<const TensorrtConfig &, const CudaConfig &,
+                    const std::string &, int32_t>(),
+           py::arg("trt_config") = TensorrtConfig{},
+           py::arg("cuda_config") = CudaConfig{},
+           py::arg("provider") = "cpu", py::arg("device") = 0)
+      .def_readwrite("trt_config", &PyClass::trt_config)
+      .def_readwrite("cuda_config", &PyClass::cuda_config)
+      .def_readwrite("provider", &PyClass::provider)
+      .def_readwrite("device", &PyClass::device)
+      .def("__str__", &PyClass::ToString)
+      .def("validate", &PyClass::Validate);
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/python/csrc/provider-config.h
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindProviderConfig(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
@@ -51,7 +51,6 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
   PybindEndpoint(&m);
   PybindOnlineRecognizer(&m);
   PybindKeywordSpotter(&m);
-
   PybindDisplay(&m);
 
   PybindOfflineStream(&m);
+// sherpa-onnx/python/csrc/tensorrt-config.cc
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#include "sherpa-onnx/python/csrc/tensorrt-config.h"
+
+#include <memory>
+#include <string>
+
+#include "sherpa-onnx/csrc/provider-config.h"
+
+namespace sherpa_onnx {
+
+void PybindTensorrtConfig(py::module *m) {
+  using PyClass = TensorrtConfig;
+  py::class_<PyClass>(*m, "TensorrtConfig")
+      .def(py::init<>())
+      .def(py::init([](int32_t trt_max_workspace_size,
+                       int32_t trt_max_partition_iterations,
+                       int32_t trt_min_subgraph_size, bool trt_fp16_enable,
+                       bool trt_detailed_build_log,
+                       bool trt_engine_cache_enable,
+                       bool trt_timing_cache_enable,
+                       const std::string &trt_engine_cache_path,
+                       const std::string &trt_timing_cache_path,
+                       bool trt_dump_subgraphs) -> std::unique_ptr<PyClass> {
+             auto ans = std::make_unique<PyClass>();
+
+             ans->trt_max_workspace_size = trt_max_workspace_size;
+             ans->trt_max_partition_iterations = trt_max_partition_iterations;
+             ans->trt_min_subgraph_size = trt_min_subgraph_size;
+             ans->trt_fp16_enable = trt_fp16_enable;
+             ans->trt_detailed_build_log = trt_detailed_build_log;
+             ans->trt_engine_cache_enable = trt_engine_cache_enable;
+             ans->trt_timing_cache_enable = trt_timing_cache_enable;
+             ans->trt_engine_cache_path = trt_engine_cache_path;
+             ans->trt_timing_cache_path = trt_timing_cache_path;
+             ans->trt_dump_subgraphs = trt_dump_subgraphs;
+
+             return ans;
+           }),
+           py::arg("trt_max_workspace_size") = 2147483647,
+           py::arg("trt_max_partition_iterations") = 10,
+           py::arg("trt_min_subgraph_size") = 5,
+           py::arg("trt_fp16_enable") = true,
+           py::arg("trt_detailed_build_log") = false,
+           py::arg("trt_engine_cache_enable") = true,
+           py::arg("trt_timing_cache_enable") = true,
+           py::arg("trt_engine_cache_path") = ".",
+           py::arg("trt_timing_cache_path") = ".",
+           py::arg("trt_dump_subgraphs") = false)
+      .def_readwrite("trt_max_workspace_size",
+                     &PyClass::trt_max_workspace_size)
+      .def_readwrite("trt_max_partition_iterations",
+                     &PyClass::trt_max_partition_iterations)
+      .def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size)
+      .def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable)
+      .def_readwrite("trt_detailed_build_log",
+                     &PyClass::trt_detailed_build_log)
+      .def_readwrite("trt_engine_cache_enable",
+                     &PyClass::trt_engine_cache_enable)
+      .def_readwrite("trt_timing_cache_enable",
+                     &PyClass::trt_timing_cache_enable)
+      .def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path)
+      .def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path)
+      .def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs)
+      .def("__str__", &PyClass::ToString)
+      .def("validate", &PyClass::Validate);
+}
+
+}  // namespace sherpa_onnx
+// sherpa-onnx/python/csrc/tensorrt-config.h
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindTensorrtConfig(py::module *m);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_
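The three bindings above compose as follows; a sketch assuming a built _sherpa_onnx module, with values mirroring the registered defaults:

    from _sherpa_onnx import CudaConfig, ProviderConfig, TensorrtConfig

    cuda_config = CudaConfig(cudnn_conv_algo_search=1)  # 1 == OrtCudnnConvAlgoSearchHeuristic

    trt_config = TensorrtConfig(
        trt_fp16_enable=True,
        trt_engine_cache_enable=True,
        trt_engine_cache_path=".",  # engines are cached here across runs
    )

    provider_config = ProviderConfig(
        trt_config=trt_config,
        cuda_config=cuda_config,
        provider="trt",
        device=0,
    )
    assert provider_config.validate()  # device >= 0 plus the provider-specific checks
    print(provider_config)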
@@ -9,6 +9,7 @@ from _sherpa_onnx import (
     OnlineModelConfig,
     OnlineTransducerModelConfig,
     OnlineStream,
+    ProviderConfig,
 )
 
 from _sherpa_onnx import KeywordSpotter as _KeywordSpotter
@@ -41,6 +42,7 @@ class KeywordSpotter(object):
         keywords_threshold: float = 0.25,
         num_trailing_blanks: int = 1,
         provider: str = "cpu",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -85,6 +87,8 @@ class KeywordSpotter(object):
             between each other.
           provider:
             onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+          device:
+            onnxruntime cuda device index.
         """
         _assert_file_exists(tokens)
         _assert_file_exists(encoder)
@@ -99,11 +103,16 @@ class KeywordSpotter(object):
             joiner=joiner,
         )
 
+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             transducer=transducer_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
         )
 
         feat_config = FeatureExtractorConfig(
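A hedged usage sketch of the new device parameter; the model paths are hypothetical and the remaining keyword arguments follow the existing KeywordSpotter signature:

    spotter = KeywordSpotter(
        tokens="tokens.txt",  # hypothetical paths
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        keywords_file="keywords.txt",
        provider="cuda",
        device=1,  # run on the second GPU rather than the default device 0
    )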
@@ -11,6 +11,9 @@ from _sherpa_onnx import (
 )
 from _sherpa_onnx import OnlineRecognizer as _Recognizer
 from _sherpa_onnx import (
+    CudaConfig,
+    TensorrtConfig,
+    ProviderConfig,
     OnlineRecognizerConfig,
     OnlineRecognizerResult,
     OnlineStream,
@@ -56,7 +59,6 @@ class OnlineRecognizer(object):
         hotwords_score: float = 1.5,
         blank_penalty: float = 0.0,
         hotwords_file: str = "",
-        provider: str = "cpu",
         model_type: str = "",
         modeling_unit: str = "cjkchar",
         bpe_vocab: str = "",
@@ -66,6 +68,19 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        provider: str = "cpu",
+        device: int = 0,
+        cudnn_conv_algo_search: int = 1,
+        trt_max_workspace_size: int = 2147483647,
+        trt_max_partition_iterations: int = 10,
+        trt_min_subgraph_size: int = 5,
+        trt_fp16_enable: bool = True,
+        trt_detailed_build_log: bool = False,
+        trt_engine_cache_enable: bool = True,
+        trt_timing_cache_enable: bool = True,
+        trt_engine_cache_path: str = "",
+        trt_timing_cache_path: str = "",
+        trt_dump_subgraphs: bool = False,
     ):
         """
         Please refer to
@@ -135,8 +150,6 @@ class OnlineRecognizer(object):
             Temperature scaling for output symbol confidence estiamation.
             It affects only confidence values, the decoding uses the original
             logits without temperature.
-          provider:
-            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
           model_type:
             Online transducer model type. Valid values are: conformer, lstm,
             zipformer, zipformer2. All other values lead to loading the model twice.
@@ -156,6 +169,32 @@ class OnlineRecognizer(object):
           rule_fars:
             If not empty, it specifies fst archives for inverse text normalization.
             If there are multiple archives, they are separated by a comma.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+          device:
+            onnxruntime cuda device index.
+          cudnn_conv_algo_search:
+            onnxruntime CuDNN convolution algorithm search option. CUDA EP only.
+          trt_max_workspace_size:
+            TensorRT EP GPU memory usage limit. TensorRT EP only.
+          trt_max_partition_iterations:
+            Limit partitioning iterations for model conversion. TensorRT EP only.
+          trt_min_subgraph_size:
+            Minimum size for subgraphs in partitioning. TensorRT EP only.
+          trt_fp16_enable:
+            Enable FP16 precision for faster performance. TensorRT EP only.
+          trt_detailed_build_log:
+            Enable detailed logging of build steps. TensorRT EP only.
+          trt_engine_cache_enable:
+            Enable caching of TensorRT engines. TensorRT EP only.
+          trt_timing_cache_enable:
+            Enable use of timing cache to speed up builds. TensorRT EP only.
+          trt_engine_cache_path:
+            Path to store cached TensorRT engines. TensorRT EP only.
+          trt_timing_cache_path:
+            Path for storing timing cache. TensorRT EP only.
+          trt_dump_subgraphs:
+            Dump optimized subgraphs for debugging. TensorRT EP only.
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -171,11 +210,35 @@ class OnlineRecognizer(object):
             joiner=joiner,
         )
 
+        cuda_config = CudaConfig(
+            cudnn_conv_algo_search=cudnn_conv_algo_search,
+        )
+
+        trt_config = TensorrtConfig(
+            trt_max_workspace_size=trt_max_workspace_size,
+            trt_max_partition_iterations=trt_max_partition_iterations,
+            trt_min_subgraph_size=trt_min_subgraph_size,
+            trt_fp16_enable=trt_fp16_enable,
+            trt_detailed_build_log=trt_detailed_build_log,
+            trt_engine_cache_enable=trt_engine_cache_enable,
+            trt_timing_cache_enable=trt_timing_cache_enable,
+            trt_engine_cache_path=trt_engine_cache_path,
+            trt_timing_cache_path=trt_timing_cache_path,
+            trt_dump_subgraphs=trt_dump_subgraphs,
+        )
+
+        provider_config = ProviderConfig(
+            trt_config=trt_config,
+            cuda_config=cuda_config,
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             transducer=transducer_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             model_type=model_type,
             modeling_unit=modeling_unit,
             bpe_vocab=bpe_vocab,
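For example, the new keyword arguments let from_transducer target the TensorRT EP with engine caching; a sketch with hypothetical paths:

    recognizer = OnlineRecognizer.from_transducer(
        tokens="tokens.txt",  # hypothetical paths
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        num_threads=2,
        provider="trt",
        device=0,
        trt_fp16_enable=True,
        trt_engine_cache_enable=True,
        trt_engine_cache_path="./trt-cache",  # reuse cached engines across runs
    )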
@@ -251,6 +314,7 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -301,6 +365,8 @@ class OnlineRecognizer(object):
           rule_fars:
             If not empty, it specifies fst archives for inverse text normalization.
             If there are multiple archives, they are separated by a comma.
+          device:
+            onnxruntime cuda device index.
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -314,11 +380,16 @@ class OnlineRecognizer(object):
             decoder=decoder,
         )
 
+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             paraformer=paraformer_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             model_type="paraformer",
             debug=debug,
         )
@@ -367,6 +438,7 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -421,6 +493,8 @@ class OnlineRecognizer(object):
           rule_fars:
             If not empty, it specifies fst archives for inverse text normalization.
             If there are multiple archives, they are separated by a comma.
+          device:
+            onnxruntime cuda device index.
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -430,11 +504,16 @@ class OnlineRecognizer(object):
 
         zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model)
 
+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             zipformer2_ctc=zipformer2_ctc_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             debug=debug,
         )
 
@@ -486,6 +565,7 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -535,6 +615,8 @@ class OnlineRecognizer(object):
           rule_fars:
             If not empty, it specifies fst archives for inverse text normalization.
             If there are multiple archives, they are separated by a comma.
+          device:
+            onnxruntime cuda device index.
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -546,11 +628,16 @@ class OnlineRecognizer(object):
             model=model,
         )
 
+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             nemo_ctc=nemo_ctc_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             debug=debug,
         )
 
@@ -598,6 +685,7 @@ class OnlineRecognizer(object):
         debug: bool = False,
         rule_fsts: str = "",
         rule_fars: str = "",
+        device: int = 0,
     ):
         """
         Please refer to
@@ -650,6 +738,8 @@ class OnlineRecognizer(object):
          rule_fars:
            If not empty, it specifies fst archives for inverse text normalization.
            If there are multiple archives, they are separated by a comma.
+          device:
+            onnxruntime cuda device index.
         """
         self = cls.__new__(cls)
         _assert_file_exists(tokens)
@@ -663,11 +753,16 @@ class OnlineRecognizer(object):
             num_left_chunks=num_left_chunks,
         )
 
+        provider_config = ProviderConfig(
+            provider=provider,
+            device=device,
+        )
+
         model_config = OnlineModelConfig(
             wenet_ctc=wenet_ctc_config,
             tokens=tokens,
             num_threads=num_threads,
-            provider=provider,
+            provider_config=provider_config,
             debug=debug,
         )
 