Committed by
GitHub
Add onnxruntime gpu for cmake (#153)
* add onnxruntime gpu for cmake * fix clang * fix typo * cpplint
正在显示
4 个修改的文件
包含
90 行增加
和
18 行删除
| @@ -19,6 +19,7 @@ option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) | @@ -19,6 +19,7 @@ option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON) | ||
| 19 | option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) | 19 | option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) |
| 20 | option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) | 20 | option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) |
| 21 | option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) | 21 | option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) |
| 22 | +option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) | ||
| 22 | 23 | ||
| 23 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") | 24 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") |
| 24 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") | 25 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") |
| @@ -71,6 +72,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}") | @@ -71,6 +72,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}") | ||
| 71 | message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}") | 72 | message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}") |
| 72 | message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}") | 73 | message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}") |
| 73 | message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}") | 74 | message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}") |
| 75 | +message(STATUS "SHERPA_ONNX_ENABLE_GPU ${SHERPA_ONNX_ENABLE_GPU}") | ||
| 74 | 76 | ||
| 75 | set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") | 77 | set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") |
| 76 | set(CMAKE_CXX_EXTENSIONS OFF) | 78 | set(CMAKE_CXX_EXTENSIONS OFF) |
| @@ -33,6 +33,14 @@ function(download_onnxruntime) | @@ -33,6 +33,14 @@ function(download_onnxruntime) | ||
| 33 | # | 33 | # |
| 34 | # ./include | 34 | # ./include |
| 35 | # It contains all the needed header files | 35 | # It contains all the needed header files |
| 36 | + if(SHERPA_ONNX_ENABLE_GPU) | ||
| 37 | + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.1/onnxruntime-linux-x64-gpu-1.14.1.tgz") | ||
| 38 | + endif() | ||
| 39 | + # After downloading, it contains: | ||
| 40 | + # ./lib/libonnxruntime.so.1.14.1 | ||
| 41 | + # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.1 | ||
| 42 | + # ./lib/libonnxruntime_providers_cuda.so | ||
| 43 | + # ./include, which contains all the needed header files | ||
| 36 | elseif(APPLE) | 44 | elseif(APPLE) |
| 37 | # If you don't have access to the Internet, | 45 | # If you don't have access to the Internet, |
| 38 | # please pre-download onnxruntime | 46 | # please pre-download onnxruntime |
| @@ -97,21 +105,28 @@ function(download_onnxruntime) | @@ -97,21 +105,28 @@ function(download_onnxruntime) | ||
| 97 | message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later") | 105 | message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later") |
| 98 | endif() | 106 | endif() |
| 99 | 107 | ||
| 100 | - foreach(f IN LISTS possible_file_locations) | ||
| 101 | - if(EXISTS ${f}) | ||
| 102 | - set(onnxruntime_URL "${f}") | ||
| 103 | - file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL) | ||
| 104 | - set(onnxruntime_URL2) | ||
| 105 | - break() | ||
| 106 | - endif() | ||
| 107 | - endforeach() | 108 | + if(NOT SHERPA_ONNX_ENABLE_GPU) |
| 109 | + foreach(f IN LISTS possible_file_locations) | ||
| 110 | + if(EXISTS ${f}) | ||
| 111 | + set(onnxruntime_URL "${f}") | ||
| 112 | + file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL) | ||
| 113 | + set(onnxruntime_URL2) | ||
| 114 | + break() | ||
| 115 | + endif() | ||
| 116 | + endforeach() | ||
| 108 | 117 | ||
| 109 | - FetchContent_Declare(onnxruntime | ||
| 110 | - URL | ||
| 111 | - ${onnxruntime_URL} | ||
| 112 | - ${onnxruntime_URL2} | ||
| 113 | - URL_HASH ${onnxruntime_HASH} | ||
| 114 | - ) | 118 | + FetchContent_Declare(onnxruntime |
| 119 | + URL | ||
| 120 | + ${onnxruntime_URL} | ||
| 121 | + ${onnxruntime_URL2} | ||
| 122 | + URL_HASH ${onnxruntime_HASH} | ||
| 123 | + ) | ||
| 124 | + else() | ||
| 125 | + FetchContent_Declare(onnxruntime | ||
| 126 | + URL | ||
| 127 | + ${onnxruntime_URL} | ||
| 128 | + ) | ||
| 129 | + endif() | ||
| 115 | 130 | ||
| 116 | FetchContent_GetProperties(onnxruntime) | 131 | FetchContent_GetProperties(onnxruntime) |
| 117 | if(NOT onnxruntime_POPULATED) | 132 | if(NOT onnxruntime_POPULATED) |
| @@ -134,6 +149,19 @@ function(download_onnxruntime) | @@ -134,6 +149,19 @@ function(download_onnxruntime) | ||
| 134 | IMPORTED_LOCATION ${location_onnxruntime} | 149 | IMPORTED_LOCATION ${location_onnxruntime} |
| 135 | INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" | 150 | INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include" |
| 136 | ) | 151 | ) |
| 152 | + | ||
| 153 | + if(SHERPA_ONNX_ENABLE_GPU) | ||
| 154 | + find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda | ||
| 155 | + PATHS | ||
| 156 | + "${onnxruntime_SOURCE_DIR}/lib" | ||
| 157 | + NO_CMAKE_SYSTEM_PATH | ||
| 158 | + ) | ||
| 159 | + add_library(onnxruntime_providers_cuda SHARED IMPORTED) | ||
| 160 | + set_target_properties(onnxruntime_providers_cuda PROPERTIES | ||
| 161 | + IMPORTED_LOCATION ${location_onnxruntime_cuda_lib} | ||
| 162 | + ) | ||
| 163 | + endif() | ||
| 164 | + | ||
| 137 | if(WIN32) | 165 | if(WIN32) |
| 138 | set_property(TARGET onnxruntime | 166 | set_property(TARGET onnxruntime |
| 139 | PROPERTY | 167 | PROPERTY |
| @@ -185,6 +213,12 @@ if(DEFINED ENV{SHERPA_ONNXRUNTIME_LIB_DIR}) | @@ -185,6 +213,12 @@ if(DEFINED ENV{SHERPA_ONNXRUNTIME_LIB_DIR}) | ||
| 185 | if(NOT EXISTS ${location_onnxruntime_lib}) | 213 | if(NOT EXISTS ${location_onnxruntime_lib}) |
| 186 | set(location_onnxruntime_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime.a) | 214 | set(location_onnxruntime_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime.a) |
| 187 | endif() | 215 | endif() |
| 216 | + if(SHERPA_ONNX_ENABLE_GPU) | ||
| 217 | + set(location_onnxruntime_cuda_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime_providers_cuda.so) | ||
| 218 | + if(NOT EXISTS ${location_onnxruntime_cuda_lib}) | ||
| 219 | + set(location_onnxruntime_cuda_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime_providers_cuda.a) | ||
| 220 | + endif() | ||
| 221 | + endif() | ||
| 188 | else() | 222 | else() |
| 189 | find_library(location_onnxruntime_lib onnxruntime | 223 | find_library(location_onnxruntime_lib onnxruntime |
| 190 | PATHS | 224 | PATHS |
| @@ -192,9 +226,21 @@ else() | @@ -192,9 +226,21 @@ else() | ||
| 192 | /usr/lib | 226 | /usr/lib |
| 193 | /usr/local/lib | 227 | /usr/local/lib |
| 194 | ) | 228 | ) |
| 229 | + | ||
| 230 | + if(SHERPA_ONNX_ENABLE_GPU) | ||
| 231 | + find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda | ||
| 232 | + PATHS | ||
| 233 | + /lib | ||
| 234 | + /usr/lib | ||
| 235 | + /usr/local/lib | ||
| 236 | + ) | ||
| 237 | + endif() | ||
| 195 | endif() | 238 | endif() |
| 196 | 239 | ||
| 197 | message(STATUS "location_onnxruntime_lib: ${location_onnxruntime_lib}") | 240 | message(STATUS "location_onnxruntime_lib: ${location_onnxruntime_lib}") |
| 241 | +if(SHERPA_ONNX_ENABLE_GPU) | ||
| 242 | + message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}") | ||
| 243 | +endif() | ||
| 198 | 244 | ||
| 199 | if(location_onnxruntime_header_dir AND location_onnxruntime_lib) | 245 | if(location_onnxruntime_header_dir AND location_onnxruntime_lib) |
| 200 | add_library(onnxruntime SHARED IMPORTED) | 246 | add_library(onnxruntime SHARED IMPORTED) |
| @@ -202,6 +248,12 @@ if(location_onnxruntime_header_dir AND location_onnxruntime_lib) | @@ -202,6 +248,12 @@ if(location_onnxruntime_header_dir AND location_onnxruntime_lib) | ||
| 202 | IMPORTED_LOCATION ${location_onnxruntime_lib} | 248 | IMPORTED_LOCATION ${location_onnxruntime_lib} |
| 203 | INTERFACE_INCLUDE_DIRECTORIES "${location_onnxruntime_header_dir}" | 249 | INTERFACE_INCLUDE_DIRECTORIES "${location_onnxruntime_header_dir}" |
| 204 | ) | 250 | ) |
| 251 | + if(SHERPA_ONNX_ENABLE_GPU AND location_onnxruntime_cuda_lib) | ||
| 252 | + add_library(onnxruntime_providers_cuda SHARED IMPORTED) | ||
| 253 | + set_target_properties(onnxruntime_providers_cuda PROPERTIES | ||
| 254 | + IMPORTED_LOCATION ${location_onnxruntime_cuda_lib} | ||
| 255 | + ) | ||
| 256 | + endif() | ||
| 205 | else() | 257 | else() |
| 206 | message(STATUS "Could not find a pre-installed onnxruntime. Downloading pre-compiled onnxruntime") | 258 | message(STATUS "Could not find a pre-installed onnxruntime. Downloading pre-compiled onnxruntime") |
| 207 | download_onnxruntime() | 259 | download_onnxruntime() |
| @@ -78,6 +78,12 @@ target_link_libraries(sherpa-onnx-core | @@ -78,6 +78,12 @@ target_link_libraries(sherpa-onnx-core | ||
| 78 | kaldi-native-fbank-core | 78 | kaldi-native-fbank-core |
| 79 | ) | 79 | ) |
| 80 | 80 | ||
| 81 | +if(SHERPA_ONNX_ENABLE_GPU) | ||
| 82 | + target_link_libraries(sherpa-onnx-core | ||
| 83 | + onnxruntime_providers_cuda | ||
| 84 | + ) | ||
| 85 | +endif() | ||
| 86 | + | ||
| 81 | if(SHERPA_ONNX_ENABLE_CHECK) | 87 | if(SHERPA_ONNX_ENABLE_CHECK) |
| 82 | target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1) | 88 | target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1) |
| 83 | 89 |
| @@ -4,8 +4,10 @@ | @@ -4,8 +4,10 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/session.h" | 5 | #include "sherpa-onnx/csrc/session.h" |
| 6 | 6 | ||
| 7 | +#include <algorithm> | ||
| 7 | #include <string> | 8 | #include <string> |
| 8 | #include <utility> | 9 | #include <utility> |
| 10 | +#include <vector> | ||
| 9 | 11 | ||
| 10 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| 11 | #include "sherpa-onnx/csrc/provider.h" | 13 | #include "sherpa-onnx/csrc/provider.h" |
| @@ -27,10 +29,20 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | @@ -27,10 +29,20 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 27 | case Provider::kCPU: | 29 | case Provider::kCPU: |
| 28 | break; // nothing to do for the CPU provider | 30 | break; // nothing to do for the CPU provider |
| 29 | case Provider::kCUDA: { | 31 | case Provider::kCUDA: { |
| 30 | - OrtCUDAProviderOptions options; | ||
| 31 | - options.device_id = 0; | ||
| 32 | - // set more options on need | ||
| 33 | - sess_opts.AppendExecutionProvider_CUDA(options); | 32 | + std::vector<std::string> available_providers = |
| 33 | + Ort::GetAvailableProviders(); | ||
| 34 | + if (std::find(available_providers.begin(), available_providers.end(), | ||
| 35 | + "CUDAExecutionProvider") != available_providers.end()) { | ||
| 36 | + // The CUDA provider is available, proceed with setting the options | ||
| 37 | + OrtCUDAProviderOptions options; | ||
| 38 | + options.device_id = 0; | ||
| 39 | + // set more options on need | ||
| 40 | + sess_opts.AppendExecutionProvider_CUDA(options); | ||
| 41 | + } else { | ||
| 42 | + SHERPA_ONNX_LOGE( | ||
| 43 | + "Please compile with -DSHERPA_ONNX_ENABLE_GPU=ON. Fallback to " | ||
| 44 | + "cpu!"); | ||
| 45 | + } | ||
| 34 | break; | 46 | break; |
| 35 | } | 47 | } |
| 36 | case Provider::kCoreML: { | 48 | case Provider::kCoreML: { |
-
请 注册 或 登录 后发表评论