Yuekai Zhang
Committed by GitHub

Add onnxruntime gpu for cmake (#153)

* add onnxruntime gpu for cmake

* fix clang

* fix typo

* cpplint
CMakeLists.txt

@@ -19,6 +19,7 @@ option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
 option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
 option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
 option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
+option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)

 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")

@@ -71,6 +72,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}")
 message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}")
 message(STATUS "SHERPA_ONNX_ENABLE_C_API ${SHERPA_ONNX_ENABLE_C_API}")
 message(STATUS "SHERPA_ONNX_ENABLE_WEBSOCKET ${SHERPA_ONNX_ENABLE_WEBSOCKET}")
+message(STATUS "SHERPA_ONNX_ENABLE_GPU ${SHERPA_ONNX_ENABLE_GPU}")

 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
 set(CMAKE_CXX_EXTENSIONS OFF)
cmake/onnxruntime.cmake

@@ -33,6 +33,14 @@ function(download_onnxruntime)
     #
     # ./include
     #   It contains all the needed header files
+    if(SHERPA_ONNX_ENABLE_GPU)
+      set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.1/onnxruntime-linux-x64-gpu-1.14.1.tgz")
+    endif()
+    # After downloading, it contains:
+    #   ./lib/libonnxruntime.so.1.14.1
+    #   ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.1
+    #   ./lib/libonnxruntime_providers_cuda.so
+    #   ./include, which contains all the needed header files
   elseif(APPLE)
     # If you don't have access to the Internet,
     # please pre-download onnxruntime

@@ -97,21 +105,28 @@ function(download_onnxruntime)
     message(FATAL_ERROR "Only support Linux, macOS, and Windows at present. Will support other OSes later")
   endif()

-  foreach(f IN LISTS possible_file_locations)
-    if(EXISTS ${f})
-      set(onnxruntime_URL "${f}")
-      file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
-      set(onnxruntime_URL2)
-      break()
-    endif()
-  endforeach()
+  if(NOT SHERPA_ONNX_ENABLE_GPU)
+    foreach(f IN LISTS possible_file_locations)
+      if(EXISTS ${f})
+        set(onnxruntime_URL "${f}")
+        file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
+        set(onnxruntime_URL2)
+        break()
+      endif()
+    endforeach()

-  FetchContent_Declare(onnxruntime
-    URL
-      ${onnxruntime_URL}
-      ${onnxruntime_URL2}
-    URL_HASH ${onnxruntime_HASH}
-  )
+    FetchContent_Declare(onnxruntime
+      URL
+        ${onnxruntime_URL}
+        ${onnxruntime_URL2}
+      URL_HASH ${onnxruntime_HASH}
+    )
+  else()
+    FetchContent_Declare(onnxruntime
+      URL
+        ${onnxruntime_URL}
+    )
+  endif()

   FetchContent_GetProperties(onnxruntime)
   if(NOT onnxruntime_POPULATED)

@@ -134,6 +149,19 @@ function(download_onnxruntime)
     IMPORTED_LOCATION ${location_onnxruntime}
     INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include"
   )
+
+  if(SHERPA_ONNX_ENABLE_GPU)
+    find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
+      PATHS
+      "${onnxruntime_SOURCE_DIR}/lib"
+      NO_CMAKE_SYSTEM_PATH
+    )
+    add_library(onnxruntime_providers_cuda SHARED IMPORTED)
+    set_target_properties(onnxruntime_providers_cuda PROPERTIES
+      IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
+    )
+  endif()
+
   if(WIN32)
     set_property(TARGET onnxruntime
       PROPERTY

@@ -185,6 +213,12 @@ if(DEFINED ENV{SHERPA_ONNXRUNTIME_LIB_DIR})
   if(NOT EXISTS ${location_onnxruntime_lib})
     set(location_onnxruntime_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime.a)
   endif()
+  if(SHERPA_ONNX_ENABLE_GPU)
+    set(location_onnxruntime_cuda_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime_providers_cuda.so)
+    if(NOT EXISTS ${location_onnxruntime_cuda_lib})
+      set(location_onnxruntime_cuda_lib $ENV{SHERPA_ONNXRUNTIME_LIB_DIR}/libonnxruntime_providers_cuda.a)
+    endif()
+  endif()
 else()
   find_library(location_onnxruntime_lib onnxruntime
     PATHS

@@ -192,9 +226,21 @@ else()
       /usr/lib
       /usr/local/lib
   )
+
+  if(SHERPA_ONNX_ENABLE_GPU)
+    find_library(location_onnxruntime_cuda_lib onnxruntime_providers_cuda
+      PATHS
+        /lib
+        /usr/lib
+        /usr/local/lib
+    )
+  endif()
 endif()

 message(STATUS "location_onnxruntime_lib: ${location_onnxruntime_lib}")
+if(SHERPA_ONNX_ENABLE_GPU)
+  message(STATUS "location_onnxruntime_cuda_lib: ${location_onnxruntime_cuda_lib}")
+endif()

 if(location_onnxruntime_header_dir AND location_onnxruntime_lib)
   add_library(onnxruntime SHARED IMPORTED)

@@ -202,6 +248,12 @@ if(location_onnxruntime_header_dir AND location_onnxruntime_lib)
     IMPORTED_LOCATION ${location_onnxruntime_lib}
     INTERFACE_INCLUDE_DIRECTORIES "${location_onnxruntime_header_dir}"
   )
+  if(SHERPA_ONNX_ENABLE_GPU AND location_onnxruntime_cuda_lib)
+    add_library(onnxruntime_providers_cuda SHARED IMPORTED)
+    set_target_properties(onnxruntime_providers_cuda PROPERTIES
+      IMPORTED_LOCATION ${location_onnxruntime_cuda_lib}
+    )
+  endif()
 else()
   message(STATUS "Could not find a pre-installed onnxruntime. Downloading pre-compiled onnxruntime")
   download_onnxruntime()
sherpa-onnx/csrc/CMakeLists.txt

@@ -78,6 +78,12 @@ target_link_libraries(sherpa-onnx-core
   kaldi-native-fbank-core
 )

+if(SHERPA_ONNX_ENABLE_GPU)
+  target_link_libraries(sherpa-onnx-core
+    onnxruntime_providers_cuda
+  )
+endif()
+
 if(SHERPA_ONNX_ENABLE_CHECK)
   target_compile_definitions(sherpa-onnx-core PUBLIC SHERPA_ONNX_ENABLE_CHECK=1)

sherpa-onnx/csrc/session.cc

@@ -4,8 +4,10 @@

 #include "sherpa-onnx/csrc/session.h"

+#include <algorithm>
 #include <string>
 #include <utility>
+#include <vector>

 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/provider.h"

@@ -27,10 +29,20 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
     case Provider::kCPU:
       break;  // nothing to do for the CPU provider
     case Provider::kCUDA: {
-      OrtCUDAProviderOptions options;
-      options.device_id = 0;
-      // set more options on need
-      sess_opts.AppendExecutionProvider_CUDA(options);
+      std::vector<std::string> available_providers =
+          Ort::GetAvailableProviders();
+      if (std::find(available_providers.begin(), available_providers.end(),
+                    "CUDAExecutionProvider") != available_providers.end()) {
+        // The CUDA provider is available, proceed with setting the options
+        OrtCUDAProviderOptions options;
+        options.device_id = 0;
+        // set more options on need
+        sess_opts.AppendExecutionProvider_CUDA(options);
+      } else {
+        SHERPA_ONNX_LOGE(
+            "Please compile with -DSHERPA_ONNX_ENABLE_GPU=ON. Fallback to "
+            "cpu!");
+      }
       break;
     }
     case Provider::kCoreML: {
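
To see whether the onnxruntime library in use actually exposes the CUDA execution provider that the patch above probes for, one can call Ort::GetAvailableProviders() from a small standalone program. The sketch below is not part of this commit; it assumes onnxruntime_cxx_api.h is on the include path and that the binary is linked against the same libonnxruntime the project uses.

    // Minimal provider check (illustrative only, not from this commit).
    #include <iostream>
    #include <string>
    #include <vector>

    #include "onnxruntime_cxx_api.h"

    int main() {
      // Ort::GetAvailableProviders() lists the execution providers compiled
      // into the linked onnxruntime. A GPU-enabled package should report
      // "CUDAExecutionProvider" alongside "CPUExecutionProvider".
      std::vector<std::string> providers = Ort::GetAvailableProviders();
      for (const std::string &name : providers) {
        std::cout << name << "\n";
      }
      return 0;
    }

If "CUDAExecutionProvider" is not listed, GetSessionOptionsImpl() above logs the error message and falls back to the CPU provider.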