正在显示
6 个修改的文件
包含
218 行增加
和
7 行删除
| @@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) | @@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF) | ||
| 30 | option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) | 30 | option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON) |
| 31 | option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) | 31 | option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build webscoket server/client" ON) |
| 32 | option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) | 32 | option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) |
| 33 | +option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF) | ||
| 33 | option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF) | 34 | option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF) |
| 34 | option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF) | 35 | option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF) |
| 35 | option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF) | 36 | option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF) |
| @@ -94,6 +95,19 @@ to install CUDA toolkit if you have not installed it.") | @@ -94,6 +95,19 @@ to install CUDA toolkit if you have not installed it.") | ||
| 94 | endif() | 95 | endif() |
| 95 | endif() | 96 | endif() |
| 96 | 97 | ||
| 98 | +if(SHERPA_ONNX_ENABLE_DIRECTML) | ||
| 99 | + message(WARNING "\ | ||
| 100 | +Compiling with DirectML enabled. Please make sure Windows 10 SDK | ||
| 101 | +is installed on your system. Otherwise, you will get errors at runtime. | ||
| 102 | +Please refer to | ||
| 103 | + https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#requirements | ||
| 104 | +to install Windows 10 SDK if you have not installed it.") | ||
| 105 | + if(NOT BUILD_SHARED_LIBS) | ||
| 106 | + message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_DIRECTML is ON") | ||
| 107 | + set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) | ||
| 108 | + endif() | ||
| 109 | +endif() | ||
| 110 | + | ||
| 97 | # see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html | 111 | # see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html |
| 98 | # https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake | 112 | # https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake |
| 99 | if(MSVC) | 113 | if(MSVC) |
| @@ -160,6 +174,14 @@ else() | @@ -160,6 +174,14 @@ else() | ||
| 160 | add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0) | 174 | add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0) |
| 161 | endif() | 175 | endif() |
| 162 | 176 | ||
| 177 | +if(SHERPA_ONNX_ENABLE_DIRECTML) | ||
| 178 | + message(STATUS "DirectML is enabled") | ||
| 179 | + add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1) | ||
| 180 | +else() | ||
| 181 | + message(WARNING "DirectML is disabled") | ||
| 182 | + add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0) | ||
| 183 | +endif() | ||
| 184 | + | ||
| 163 | if(SHERPA_ONNX_ENABLE_WASM_TTS) | 185 | if(SHERPA_ONNX_ENABLE_WASM_TTS) |
| 164 | if(NOT SHERPA_ONNX_ENABLE_TTS) | 186 | if(NOT SHERPA_ONNX_ENABLE_TTS) |
| 165 | message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS") | 187 | message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS") |
cmake/onnxruntime-win-x64-directml.cmake
0 → 100644
| 1 | +# Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 2 | +message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") | ||
| 3 | +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") | ||
| 4 | +message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}") | ||
| 5 | + | ||
| 6 | +if(NOT CMAKE_SYSTEM_NAME STREQUAL Windows) | ||
| 7 | + message(FATAL_ERROR "This file is for Windows only. Given: ${CMAKE_SYSTEM_NAME}") | ||
| 8 | +endif() | ||
| 9 | + | ||
| 10 | +if(NOT (CMAKE_VS_PLATFORM_NAME STREQUAL X64 OR CMAKE_VS_PLATFORM_NAME STREQUAL x64)) | ||
| 11 | + message(FATAL_ERROR "This file is for Windows x64 only. Given: ${CMAKE_VS_PLATFORM_NAME}") | ||
| 12 | +endif() | ||
| 13 | + | ||
| 14 | +if(NOT BUILD_SHARED_LIBS) | ||
| 15 | + message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}") | ||
| 16 | +endif() | ||
| 17 | + | ||
| 18 | +if(NOT SHERPA_ONNX_ENABLE_DIRECTML) | ||
| 19 | + message(FATAL_ERROR "This file is for DirectML. Given SHERPA_ONNX_ENABLE_DIRECTML: ${SHERPA_ONNX_ENABLE_DIRECTML}") | ||
| 20 | +endif() | ||
| 21 | + | ||
| 22 | +set(onnxruntime_URL "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg") | ||
| 23 | +set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg") | ||
| 24 | +set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a") | ||
| 25 | + | ||
| 26 | +# If you don't have access to the Internet, | ||
| 27 | +# please download onnxruntime to one of the following locations. | ||
| 28 | +# You can add more if you want. | ||
| 29 | +set(possible_file_locations | ||
| 30 | + $ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg | ||
| 31 | + ${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg | ||
| 32 | + ${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg | ||
| 33 | + /tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg | ||
| 34 | +) | ||
| 35 | + | ||
| 36 | +foreach(f IN LISTS possible_file_locations) | ||
| 37 | + if(EXISTS ${f}) | ||
| 38 | + set(onnxruntime_URL "${f}") | ||
| 39 | + file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL) | ||
| 40 | + message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}") | ||
| 41 | + set(onnxruntime_URL2) | ||
| 42 | + break() | ||
| 43 | + endif() | ||
| 44 | +endforeach() | ||
| 45 | + | ||
| 46 | +FetchContent_Declare(onnxruntime | ||
| 47 | + URL | ||
| 48 | + ${onnxruntime_URL} | ||
| 49 | + ${onnxruntime_URL2} | ||
| 50 | + URL_HASH ${onnxruntime_HASH} | ||
| 51 | +) | ||
| 52 | + | ||
| 53 | +FetchContent_GetProperties(onnxruntime) | ||
| 54 | +if(NOT onnxruntime_POPULATED) | ||
| 55 | + message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}") | ||
| 56 | + FetchContent_Populate(onnxruntime) | ||
| 57 | +endif() | ||
| 58 | +message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}") | ||
| 59 | + | ||
| 60 | +find_library(location_onnxruntime onnxruntime | ||
| 61 | + PATHS | ||
| 62 | + "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native" | ||
| 63 | + NO_CMAKE_SYSTEM_PATH | ||
| 64 | +) | ||
| 65 | + | ||
| 66 | +message(STATUS "location_onnxruntime: ${location_onnxruntime}") | ||
| 67 | + | ||
| 68 | +add_library(onnxruntime SHARED IMPORTED) | ||
| 69 | + | ||
| 70 | +set_target_properties(onnxruntime PROPERTIES | ||
| 71 | + IMPORTED_LOCATION ${location_onnxruntime} | ||
| 72 | + INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/build/native/include" | ||
| 73 | +) | ||
| 74 | + | ||
| 75 | +set_property(TARGET onnxruntime | ||
| 76 | + PROPERTY | ||
| 77 | + IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.lib" | ||
| 78 | +) | ||
| 79 | + | ||
| 80 | +file(COPY ${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.dll | ||
| 81 | + DESTINATION | ||
| 82 | + ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE} | ||
| 83 | +) | ||
| 84 | + | ||
| 85 | +file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.*") | ||
| 86 | + | ||
| 87 | +message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}") | ||
| 88 | + | ||
| 89 | +if(SHERPA_ONNX_ENABLE_PYTHON) | ||
| 90 | + install(FILES ${onnxruntime_lib_files} DESTINATION ..) | ||
| 91 | +else() | ||
| 92 | + install(FILES ${onnxruntime_lib_files} DESTINATION lib) | ||
| 93 | +endif() | ||
| 94 | + | ||
| 95 | +install(FILES ${onnxruntime_lib_files} DESTINATION bin) | ||
| 96 | + | ||
| 97 | +# Setup DirectML | ||
| 98 | + | ||
| 99 | +set(directml_URL "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.0") | ||
| 100 | +set(directml_HASH "SHA256=10d175f8e97447712b3680e3ac020bbb8eafdf651332b48f09ffee2eec801c23") | ||
| 101 | + | ||
| 102 | +set(possible_directml_file_locations | ||
| 103 | + $ENV{HOME}/Downloads/Microsoft.AI.DirectML.1.15.0.nupkg | ||
| 104 | + ${PROJECT_SOURCE_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg | ||
| 105 | + ${PROJECT_BINARY_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg | ||
| 106 | + /tmp/Microsoft.AI.DirectML.1.15.0.nupkg | ||
| 107 | +) | ||
| 108 | + | ||
| 109 | +foreach(f IN LISTS possible_directml_file_locations) | ||
| 110 | + if(EXISTS ${f}) | ||
| 111 | + set(directml_URL "${f}") | ||
| 112 | + file(TO_CMAKE_PATH "${directml_URL}" directml_URL) | ||
| 113 | + message(STATUS "Found local downloaded DirectML: ${directml_URL}") | ||
| 114 | + break() | ||
| 115 | + endif() | ||
| 116 | +endforeach() | ||
| 117 | + | ||
| 118 | +FetchContent_Declare(directml | ||
| 119 | + URL | ||
| 120 | + ${directml_URL} | ||
| 121 | + URL_HASH ${directml_HASH} | ||
| 122 | +) | ||
| 123 | + | ||
| 124 | +FetchContent_GetProperties(directml) | ||
| 125 | +if(NOT directml_POPULATED) | ||
| 126 | + message(STATUS "Downloading DirectML from ${directml_URL}") | ||
| 127 | + FetchContent_Populate(directml) | ||
| 128 | +endif() | ||
| 129 | +message(STATUS "DirectML is downloaded to ${directml_SOURCE_DIR}") | ||
| 130 | + | ||
| 131 | +find_library(location_directml DirectML | ||
| 132 | + PATHS | ||
| 133 | + "${directml_SOURCE_DIR}/bin/x64-win" | ||
| 134 | + NO_CMAKE_SYSTEM_PATH | ||
| 135 | +) | ||
| 136 | + | ||
| 137 | +message(STATUS "location_directml: ${location_directml}") | ||
| 138 | + | ||
| 139 | +add_library(directml SHARED IMPORTED) | ||
| 140 | + | ||
| 141 | +set_target_properties(directml PROPERTIES | ||
| 142 | + IMPORTED_LOCATION ${location_directml} | ||
| 143 | + INTERFACE_INCLUDE_DIRECTORIES "${directml_SOURCE_DIR}/bin/x64-win" | ||
| 144 | +) | ||
| 145 | + | ||
| 146 | +set_property(TARGET directml | ||
| 147 | + PROPERTY | ||
| 148 | + IMPORTED_IMPLIB "${directml_SOURCE_DIR}/bin/x64-win/DirectML.lib" | ||
| 149 | +) | ||
| 150 | + | ||
| 151 | +file(COPY ${directml_SOURCE_DIR}/bin/x64-win/DirectML.dll | ||
| 152 | + DESTINATION | ||
| 153 | + ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE} | ||
| 154 | +) | ||
| 155 | + | ||
| 156 | +file(GLOB directml_lib_files "${directml_SOURCE_DIR}/bin/x64-win/DirectML.*") | ||
| 157 | + | ||
| 158 | +message(STATUS "DirectML lib files: ${directml_lib_files}") | ||
| 159 | + | ||
| 160 | +install(FILES ${directml_lib_files} DESTINATION lib) | ||
| 161 | +install(FILES ${directml_lib_files} DESTINATION bin) |
| @@ -95,7 +95,10 @@ function(download_onnxruntime) | @@ -95,7 +95,10 @@ function(download_onnxruntime) | ||
| 95 | include(onnxruntime-win-arm64) | 95 | include(onnxruntime-win-arm64) |
| 96 | else() | 96 | else() |
| 97 | # for 64-bit windows (x64) | 97 | # for 64-bit windows (x64) |
| 98 | - if(BUILD_SHARED_LIBS) | 98 | + if(SHERPA_ONNX_ENABLE_DIRECTML) |
| 99 | + message(STATUS "Use DirectML") | ||
| 100 | + include(onnxruntime-win-x64-directml) | ||
| 101 | + elseif(BUILD_SHARED_LIBS) | ||
| 99 | message(STATUS "Use dynamic onnxruntime libraries") | 102 | message(STATUS "Use dynamic onnxruntime libraries") |
| 100 | if(SHERPA_ONNX_ENABLE_GPU) | 103 | if(SHERPA_ONNX_ENABLE_GPU) |
| 101 | include(onnxruntime-win-x64-gpu) | 104 | include(onnxruntime-win-x64-gpu) |
| @@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) { | @@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) { | ||
| 26 | return Provider::kNNAPI; | 26 | return Provider::kNNAPI; |
| 27 | } else if (s == "trt") { | 27 | } else if (s == "trt") { |
| 28 | return Provider::kTRT; | 28 | return Provider::kTRT; |
| 29 | + } else if (s == "directml") { | ||
| 30 | + return Provider::kDirectML; | ||
| 29 | } else { | 31 | } else { |
| 30 | SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); | 32 | SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); |
| 31 | return Provider::kCPU; | 33 | return Provider::kCPU; |
| @@ -14,12 +14,13 @@ namespace sherpa_onnx { | @@ -14,12 +14,13 @@ namespace sherpa_onnx { | ||
| 14 | // https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java | 14 | // https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java |
| 15 | // for a list of available providers | 15 | // for a list of available providers |
| 16 | enum class Provider { | 16 | enum class Provider { |
| 17 | - kCPU = 0, // CPUExecutionProvider | ||
| 18 | - kCUDA = 1, // CUDAExecutionProvider | ||
| 19 | - kCoreML = 2, // CoreMLExecutionProvider | ||
| 20 | - kXnnpack = 3, // XnnpackExecutionProvider | ||
| 21 | - kNNAPI = 4, // NnapiExecutionProvider | ||
| 22 | - kTRT = 5, // TensorRTExecutionProvider | 17 | + kCPU = 0, // CPUExecutionProvider |
| 18 | + kCUDA = 1, // CUDAExecutionProvider | ||
| 19 | + kCoreML = 2, // CoreMLExecutionProvider | ||
| 20 | + kXnnpack = 3, // XnnpackExecutionProvider | ||
| 21 | + kNNAPI = 4, // NnapiExecutionProvider | ||
| 22 | + kTRT = 5, // TensorRTExecutionProvider | ||
| 23 | + kDirectML = 6, // DmlExecutionProvider | ||
| 23 | }; | 24 | }; |
| 24 | 25 | ||
| 25 | /** | 26 | /** |
| @@ -19,6 +19,10 @@ | @@ -19,6 +19,10 @@ | ||
| 19 | #include "nnapi_provider_factory.h" // NOLINT | 19 | #include "nnapi_provider_factory.h" // NOLINT |
| 20 | #endif | 20 | #endif |
| 21 | 21 | ||
| 22 | +#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1 | ||
| 23 | +#include "dml_provider_factory.h" // NOLINT | ||
| 24 | +#endif | ||
| 25 | + | ||
| 22 | namespace sherpa_onnx { | 26 | namespace sherpa_onnx { |
| 23 | 27 | ||
| 24 | static void OrtStatusFailure(OrtStatus *status, const char *s) { | 28 | static void OrtStatusFailure(OrtStatus *status, const char *s) { |
| @@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl( | @@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl( | ||
| 167 | } | 171 | } |
| 168 | break; | 172 | break; |
| 169 | } | 173 | } |
| 174 | + case Provider::kDirectML: { | ||
| 175 | +#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1 | ||
| 176 | + sess_opts.DisableMemPattern(); | ||
| 177 | + sess_opts.SetExecutionMode(ORT_SEQUENTIAL); | ||
| 178 | + int32_t device_id = 0; | ||
| 179 | + OrtStatus *status = | ||
| 180 | + OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id); | ||
| 181 | + if (status) { | ||
| 182 | + const auto &api = Ort::GetApi(); | ||
| 183 | + const char *msg = api.GetErrorMessage(status); | ||
| 184 | + SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg); | ||
| 185 | + api.ReleaseStatus(status); | ||
| 186 | + } | ||
| 187 | +#else | ||
| 188 | + SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!"); | ||
| 189 | +#endif | ||
| 190 | + break; | ||
| 191 | + } | ||
| 170 | case Provider::kCoreML: { | 192 | case Provider::kCoreML: { |
| 171 | #if defined(__APPLE__) | 193 | #if defined(__APPLE__) |
| 172 | uint32_t coreml_flags = 0; | 194 | uint32_t coreml_flags = 0; |
-
请 注册 或 登录 后发表评论