thewh1teagle
Committed by GitHub

feat: add directml support (#1153)

@@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
 option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
 option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON)
 option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
+option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
 option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
 option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
 option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
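With the new option in place, a DirectML build is opt-in. A minimal configure-and-build sketch (not from this commit; the build directory and the Visual Studio x64 generator flag are assumptions, and `BUILD_SHARED_LIBS` is spelled out even though the logic below forces it on):

```
cmake -S . -B build -A x64 -DSHERPA_ONNX_ENABLE_DIRECTML=ON -DBUILD_SHARED_LIBS=ON
cmake --build build --config Release
```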
@@ -94,6 +95,19 @@ to install CUDA toolkit if you have not installed it.")
   endif()
 endif()
 
+if(SHERPA_ONNX_ENABLE_DIRECTML)
+  message(WARNING "\
+Compiling with DirectML enabled. Please make sure Windows 10 SDK
+is installed on your system. Otherwise, you will get errors at runtime.
+Please refer to
+ https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#requirements
+to install Windows 10 SDK if you have not installed it.")
+  if(NOT BUILD_SHARED_LIBS)
+    message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_DIRECTML is ON")
+    set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
+  endif()
+endif()
+
 # see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html
 # https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake
 if(MSVC)
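Forcing `BUILD_SHARED_LIBS` to ON here is load-bearing: the NuGet packages fetched by the new CMake module below ship dynamic libraries (`onnxruntime.dll`, `DirectML.dll`), and that module stops with a `FATAL_ERROR` if a static build reaches it.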
@@ -160,6 +174,14 @@ else()
   add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0)
 endif()
 
+if(SHERPA_ONNX_ENABLE_DIRECTML)
+  message(STATUS "DirectML is enabled")
+  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
+else()
+  message(WARNING "DirectML is disabled")
+  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
+endif()
+
 if(SHERPA_ONNX_ENABLE_WASM_TTS)
   if(NOT SHERPA_ONNX_ENABLE_TTS)
     message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS")
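Defining the macro to an explicit `0` or `1` in both branches (rather than leaving it undefined when the feature is off) lets the C++ sources further down test it uniformly with `#if SHERPA_ONNX_ENABLE_DIRECTML == 1`.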
New file: the `onnxruntime-win-x64-directml` CMake module (referenced by the `include()` added in the next hunk):

@@ -0,0 +1,161 @@
+# Copyright (c) 2022-2023 Xiaomi Corporation
+message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
+message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
+message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}")
+
+if(NOT CMAKE_SYSTEM_NAME STREQUAL Windows)
+  message(FATAL_ERROR "This file is for Windows only. Given: ${CMAKE_SYSTEM_NAME}")
+endif()
+
+if(NOT (CMAKE_VS_PLATFORM_NAME STREQUAL X64 OR CMAKE_VS_PLATFORM_NAME STREQUAL x64))
+  message(FATAL_ERROR "This file is for Windows x64 only. Given: ${CMAKE_VS_PLATFORM_NAME}")
+endif()
+
+if(NOT BUILD_SHARED_LIBS)
+  message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
+endif()
+
+if(NOT SHERPA_ONNX_ENABLE_DIRECTML)
+  message(FATAL_ERROR "This file is for DirectML. Given SHERPA_ONNX_ENABLE_DIRECTML: ${SHERPA_ONNX_ENABLE_DIRECTML}")
+endif()
+
+set(onnxruntime_URL "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
+set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
+set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a")
+
+# If you don't have access to the Internet,
+# please download onnxruntime to one of the following locations.
+# You can add more if you want.
+set(possible_file_locations
+  $ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
+  ${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
+  ${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
+  /tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
+)
+
+foreach(f IN LISTS possible_file_locations)
+  if(EXISTS ${f})
+    set(onnxruntime_URL "${f}")
+    file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
+    message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
+    set(onnxruntime_URL2)
+    break()
+  endif()
+endforeach()
+
+FetchContent_Declare(onnxruntime
+  URL
+    ${onnxruntime_URL}
+    ${onnxruntime_URL2}
+  URL_HASH ${onnxruntime_HASH}
+)
+
+FetchContent_GetProperties(onnxruntime)
+if(NOT onnxruntime_POPULATED)
+  message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
+  FetchContent_Populate(onnxruntime)
+endif()
+message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")
+
+find_library(location_onnxruntime onnxruntime
+  PATHS
+  "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native"
+  NO_CMAKE_SYSTEM_PATH
+)
+
+message(STATUS "location_onnxruntime: ${location_onnxruntime}")
+
+add_library(onnxruntime SHARED IMPORTED)
+
+set_target_properties(onnxruntime PROPERTIES
+  IMPORTED_LOCATION ${location_onnxruntime}
+  INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/build/native/include"
+)
+
+set_property(TARGET onnxruntime
+  PROPERTY
+    IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.lib"
+)
+
+file(COPY ${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.dll
+  DESTINATION
+    ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
+)
+
+file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.*")
+
+message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")
+
+if(SHERPA_ONNX_ENABLE_PYTHON)
+  install(FILES ${onnxruntime_lib_files} DESTINATION ..)
+else()
+  install(FILES ${onnxruntime_lib_files} DESTINATION lib)
+endif()
+
+install(FILES ${onnxruntime_lib_files} DESTINATION bin)
+
+# Setup DirectML
+
+set(directml_URL "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.0")
+set(directml_HASH "SHA256=10d175f8e97447712b3680e3ac020bbb8eafdf651332b48f09ffee2eec801c23")
+
+set(possible_directml_file_locations
+  $ENV{HOME}/Downloads/Microsoft.AI.DirectML.1.15.0.nupkg
+  ${PROJECT_SOURCE_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
+  ${PROJECT_BINARY_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
+  /tmp/Microsoft.AI.DirectML.1.15.0.nupkg
+)
+
+foreach(f IN LISTS possible_directml_file_locations)
+  if(EXISTS ${f})
+    set(directml_URL "${f}")
+    file(TO_CMAKE_PATH "${directml_URL}" directml_URL)
+    message(STATUS "Found local downloaded DirectML: ${directml_URL}")
+    break()
+  endif()
+endforeach()
+
+FetchContent_Declare(directml
+  URL
+    ${directml_URL}
+  URL_HASH ${directml_HASH}
+)
+
+FetchContent_GetProperties(directml)
+if(NOT directml_POPULATED)
+  message(STATUS "Downloading DirectML from ${directml_URL}")
+  FetchContent_Populate(directml)
+endif()
+message(STATUS "DirectML is downloaded to ${directml_SOURCE_DIR}")
+
+find_library(location_directml DirectML
+  PATHS
+  "${directml_SOURCE_DIR}/bin/x64-win"
+  NO_CMAKE_SYSTEM_PATH
+)
+
+message(STATUS "location_directml: ${location_directml}")
+
+add_library(directml SHARED IMPORTED)
+
+set_target_properties(directml PROPERTIES
+  IMPORTED_LOCATION ${location_directml}
+  INTERFACE_INCLUDE_DIRECTORIES "${directml_SOURCE_DIR}/bin/x64-win"
+)
+
+set_property(TARGET directml
+  PROPERTY
+    IMPORTED_IMPLIB "${directml_SOURCE_DIR}/bin/x64-win/DirectML.lib"
+)
+
+file(COPY ${directml_SOURCE_DIR}/bin/x64-win/DirectML.dll
+  DESTINATION
+    ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
+)
+
+file(GLOB directml_lib_files "${directml_SOURCE_DIR}/bin/x64-win/DirectML.*")
+
+message(STATUS "DirectML lib files: ${directml_lib_files}")
+
+install(FILES ${directml_lib_files} DESTINATION lib)
+install(FILES ${directml_lib_files} DESTINATION bin)
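Two details of this module are worth calling out. First, both dependencies are fetched as NuGet packages pinned by SHA256, with `onnxruntime_URL2` acting as a mirror that is cleared once a pre-downloaded local copy is found. Second, the `file(COPY ...)` calls stage `onnxruntime.dll` and `DirectML.dll` into `${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}` so freshly built executables can find them, while the `install(FILES ...)` rules place the same files into both `lib` and `bin` of an installed tree.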
@@ -95,7 +95,10 @@ function(download_onnxruntime)
     include(onnxruntime-win-arm64)
   else()
     # for 64-bit windows (x64)
-    if(BUILD_SHARED_LIBS)
+    if(SHERPA_ONNX_ENABLE_DIRECTML)
+      message(STATUS "Use DirectML")
+      include(onnxruntime-win-x64-directml)
+    elseif(BUILD_SHARED_LIBS)
       message(STATUS "Use dynamic onnxruntime libraries")
       if(SHERPA_ONNX_ENABLE_GPU)
         include(onnxruntime-win-x64-gpu)
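Note the branch order: if both `SHERPA_ONNX_ENABLE_DIRECTML` and `SHERPA_ONNX_ENABLE_GPU` are ON, DirectML wins, because the new branch is tested before the `BUILD_SHARED_LIBS`/GPU path.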
@@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) {
     return Provider::kNNAPI;
   } else if (s == "trt") {
     return Provider::kTRT;
+  } else if (s == "directml") {
+    return Provider::kDirectML;
   } else {
     SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
     return Provider::kCPU;
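A hypothetical snippet (not part of this commit) showing the new mapping from a caller's perspective; the header path is an assumption about the repo layout:

```cpp
#include <cassert>

#include "sherpa-onnx/csrc/provider.h"  // assumed header location

int main() {
  using sherpa_onnx::Provider;
  using sherpa_onnx::StringToProvider;

  // The mapping added by this commit:
  assert(StringToProvider("directml") == Provider::kDirectML);

  // Unknown strings keep the existing behavior: log and fall back to CPU.
  assert(StringToProvider("opencl") == Provider::kCPU);
  return 0;
}
```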
@@ -14,12 +14,13 @@ namespace sherpa_onnx {
 // https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
 // for a list of available providers
 enum class Provider {
-  kCPU = 0,      // CPUExecutionProvider
-  kCUDA = 1,     // CUDAExecutionProvider
-  kCoreML = 2,   // CoreMLExecutionProvider
-  kXnnpack = 3,  // XnnpackExecutionProvider
-  kNNAPI = 4,    // NnapiExecutionProvider
-  kTRT = 5,      // TensorRTExecutionProvider
+  kCPU = 0,       // CPUExecutionProvider
+  kCUDA = 1,      // CUDAExecutionProvider
+  kCoreML = 2,    // CoreMLExecutionProvider
+  kXnnpack = 3,   // XnnpackExecutionProvider
+  kNNAPI = 4,     // NnapiExecutionProvider
+  kTRT = 5,       // TensorRTExecutionProvider
+  kDirectML = 6,  // DmlExecutionProvider
 };
 
 /**
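Appending `kDirectML = 6` after `kTRT` leaves every existing enumerator value unchanged, so code that stores or compares the numeric values is unaffected; the surrounding lines are rewritten only to realign the trailing comments.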
@@ -19,6 +19,10 @@
 #include "nnapi_provider_factory.h"  // NOLINT
 #endif
 
+#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
+#include "dml_provider_factory.h"  // NOLINT
+#endif
+
 namespace sherpa_onnx {
 
 static void OrtStatusFailure(OrtStatus *status, const char *s) {
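`dml_provider_factory.h` is not part of the stock onnxruntime headers; it ships with the DirectML-flavored onnxruntime NuGet package, whose `build/native/include` directory the imported `onnxruntime` target exposes via `INTERFACE_INCLUDE_DIRECTORIES` in the new CMake module above, so no extra include path is needed here.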
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
       }
       break;
     }
+    case Provider::kDirectML: {
+#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
+      sess_opts.DisableMemPattern();
+      sess_opts.SetExecutionMode(ORT_SEQUENTIAL);
+      int32_t device_id = 0;
+      OrtStatus *status =
+          OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
+      if (status) {
+        const auto &api = Ort::GetApi();
+        const char *msg = api.GetErrorMessage(status);
+        SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
+        api.ReleaseStatus(status);
+      }
+#else
+      SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
+#endif
+      break;
+    }
     case Provider::kCoreML: {
 #if defined(__APPLE__)
       uint32_t coreml_flags = 0;
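The `DisableMemPattern()` and `ORT_SEQUENTIAL` calls are not optional: the DirectML execution provider requires memory-pattern optimization to be off and the session to run sequentially, which is why the commit sets both before appending the provider. For reference, a minimal standalone sketch of the same setup against the public onnxruntime C++ API (not from this commit; it assumes a DirectML-enabled onnxruntime build, uses a placeholder model path, and throws on failure instead of logging and falling back as the sherpa-onnx code above does):

```cpp
#include <onnxruntime_cxx_api.h>

#include "dml_provider_factory.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "dml-demo");
  Ort::SessionOptions opts;

  // The DirectML EP requires memory patterns off and sequential execution.
  opts.DisableMemPattern();
  opts.SetExecutionMode(ORT_SEQUENTIAL);

  // Append the DirectML provider for GPU/adapter 0.
  Ort::ThrowOnError(
      OrtSessionOptionsAppendExecutionProvider_DML(opts, /*device_id=*/0));

  Ort::Session session(env, L"model.onnx", opts);  // wide path on Windows
  return 0;
}
```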