Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
thewh1teagle
2024-07-22 18:50:48 +0300
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-07-22 23:50:48 +0800
Commit
d32a46169f7e30126eab3495aea79c6329c7237a
d32a4616
1 parent
ea1d81bd
feat: add directml support (#1153)
显示空白字符变更
内嵌
并排对比
正在显示
6 个修改的文件
包含
212 行增加
和
1 行删除
CMakeLists.txt
cmake/onnxruntime-win-x64-directml.cmake
cmake/onnxruntime.cmake
sherpa-onnx/csrc/provider.cc
sherpa-onnx/csrc/provider.h
sherpa-onnx/csrc/session.cc
CMakeLists.txt
查看文件 @
d32a461
...
...
@@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
# Build-configuration switches for sherpa-onnx (default values shown last).
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
# Fixed typo in help string: "webscoket" -> "websocket".
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build websocket server/client" ON)
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
...
...
@@ -94,6 +95,19 @@ to install CUDA toolkit if you have not installed it.")
endif
()
endif
()
# When DirectML is requested, warn about the Windows 10 SDK requirement and
# force shared-library builds (the DirectML onnxruntime package ships DLLs).
if(SHERPA_ONNX_ENABLE_DIRECTML)
  message(WARNING "\
Compiling with DirectML enabled. Please make sure Windows 10 SDK
is installed on your system. Otherwise, you will get errors at runtime.
Please refer to
https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#requirements
to install Windows 10 SDK if you have not installed it.")

  if(NOT BUILD_SHARED_LIBS)
    message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_DIRECTML is ON")
    # FORCE overrides any value the user put in the cache.
    set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
  endif()
endif()
# see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html
# https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake
if
(
MSVC
)
...
...
@@ -160,6 +174,14 @@ else()
add_definitions
(
-DSHERPA_ONNX_ENABLE_TTS=0
)
endif
()
# Propagate the DirectML switch to the preprocessor so C++ sources can
# conditionally compile the DML execution-provider code path.
if(SHERPA_ONNX_ENABLE_DIRECTML)
  message(STATUS "DirectML is enabled")
  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
else()
  message(WARNING "DirectML is disabled")
  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
endif()
if
(
SHERPA_ONNX_ENABLE_WASM_TTS
)
if
(
NOT SHERPA_ONNX_ENABLE_TTS
)
message
(
FATAL_ERROR
"Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS"
)
...
...
cmake/onnxruntime-win-x64-directml.cmake
0 → 100644
查看文件 @
d32a461
# Copyright (c) 2022-2023 Xiaomi Corporation

# This file downloads the DirectML build of onnxruntime.
# It only supports: Windows, x64, shared libraries, and
# SHERPA_ONNX_ENABLE_DIRECTML=ON. Anything else is a configuration error.

message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}")

if(NOT CMAKE_SYSTEM_NAME STREQUAL Windows)
  message(FATAL_ERROR "This file is for Windows only. Given: ${CMAKE_SYSTEM_NAME}")
endif()

if(NOT (CMAKE_VS_PLATFORM_NAME STREQUAL X64 OR CMAKE_VS_PLATFORM_NAME STREQUAL x64))
  message(FATAL_ERROR "This file is for Windows x64 only. Given: ${CMAKE_VS_PLATFORM_NAME}")
endif()

if(NOT BUILD_SHARED_LIBS)
  message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()

if(NOT SHERPA_ONNX_ENABLE_DIRECTML)
  message(FATAL_ERROR "This file is for DirectML. Given SHERPA_ONNX_ENABLE_DIRECTML: ${SHERPA_ONNX_ENABLE_DIRECTML}")
endif()
# Primary URL, mirror URL, and expected checksum of the onnxruntime
# DirectML nuget package.
set(onnxruntime_URL "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a")

# If you don't have access to the Internet,
# please download onnxruntime to one of the following locations.
# You can add more if you want.
set(possible_file_locations
  $ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
  ${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
  ${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
  /tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
)

# A pre-downloaded local copy takes priority over the network URLs.
foreach(f IN LISTS possible_file_locations)
  if(EXISTS ${f})
    set(onnxruntime_URL "${f}")
    file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
    message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
    # Clear the mirror so FetchContent only uses the local file.
    set(onnxruntime_URL2)
    break()
  endif()
endforeach()

FetchContent_Declare(onnxruntime
  URL
    ${onnxruntime_URL}
    ${onnxruntime_URL2}
  URL_HASH ${onnxruntime_HASH}
)
# Download (if needed), expose onnxruntime as an IMPORTED shared library,
# and install/copy its runtime files next to the built binaries.
FetchContent_GetProperties(onnxruntime)
if(NOT onnxruntime_POPULATED)
  message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
  FetchContent_Populate(onnxruntime)
endif()
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")

find_library(location_onnxruntime onnxruntime
  PATHS
  "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native"
  NO_CMAKE_SYSTEM_PATH
)

message(STATUS "location_onnxruntime: ${location_onnxruntime}")

add_library(onnxruntime SHARED IMPORTED)

set_target_properties(onnxruntime PROPERTIES
  IMPORTED_LOCATION ${location_onnxruntime}
  INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/build/native/include"
)

# On Windows the import library (.lib) is what the linker consumes.
set_property(TARGET onnxruntime
  PROPERTY
    IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.lib"
)

# Place the DLL beside the executables so they run from the build tree.
file(COPY
  ${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.dll
  DESTINATION ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
)

file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.*")

message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")

if(SHERPA_ONNX_ENABLE_PYTHON)
  # For Python wheels the libraries live next to the package root.
  install(FILES ${onnxruntime_lib_files} DESTINATION ..)
else()
  install(FILES ${onnxruntime_lib_files} DESTINATION lib)
endif()

install(FILES ${onnxruntime_lib_files} DESTINATION bin)
# Setup DirectML
# URL and expected checksum of the Microsoft.AI.DirectML nuget package.
set(directml_URL "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.0")
set(directml_HASH "SHA256=10d175f8e97447712b3680e3ac020bbb8eafdf651332b48f09ffee2eec801c23")

# Offline fallback locations, mirroring the onnxruntime handling above.
set(possible_directml_file_locations
  $ENV{HOME}/Downloads/Microsoft.AI.DirectML.1.15.0.nupkg
  ${PROJECT_SOURCE_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
  ${PROJECT_BINARY_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
  /tmp/Microsoft.AI.DirectML.1.15.0.nupkg
)

foreach(f IN LISTS possible_directml_file_locations)
  if(EXISTS ${f})
    set(directml_URL "${f}")
    file(TO_CMAKE_PATH "${directml_URL}" directml_URL)
    message(STATUS "Found local downloaded DirectML: ${directml_URL}")
    break()
  endif()
endforeach()

FetchContent_Declare(directml
  URL ${directml_URL}
  URL_HASH ${directml_HASH}
)
# Download (if needed), expose DirectML as an IMPORTED shared library,
# and install/copy its runtime files next to the built binaries.
FetchContent_GetProperties(directml)
if(NOT directml_POPULATED)
  message(STATUS "Downloading DirectML from ${directml_URL}")
  FetchContent_Populate(directml)
endif()
message(STATUS "DirectML is downloaded to ${directml_SOURCE_DIR}")

find_library(location_directml DirectML
  PATHS
  "${directml_SOURCE_DIR}/bin/x64-win"
  NO_CMAKE_SYSTEM_PATH
)

message(STATUS "location_directml: ${location_directml}")

add_library(directml SHARED IMPORTED)

set_target_properties(directml PROPERTIES
  IMPORTED_LOCATION ${location_directml}
  INTERFACE_INCLUDE_DIRECTORIES "${directml_SOURCE_DIR}/bin/x64-win"
)

# The .lib import library is what the linker consumes on Windows.
set_property(TARGET directml
  PROPERTY
    IMPORTED_IMPLIB "${directml_SOURCE_DIR}/bin/x64-win/DirectML.lib"
)

# Place the DLL beside the executables so they run from the build tree.
file(COPY
  ${directml_SOURCE_DIR}/bin/x64-win/DirectML.dll
  DESTINATION ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
)

file(GLOB directml_lib_files "${directml_SOURCE_DIR}/bin/x64-win/DirectML.*")

message(STATUS "DirectML lib files: ${directml_lib_files}")

install(FILES ${directml_lib_files} DESTINATION lib)
install(FILES ${directml_lib_files} DESTINATION bin)
\ No newline at end of file
...
...
cmake/onnxruntime.cmake
查看文件 @
d32a461
...
...
@@ -95,7 +95,10 @@ function(download_onnxruntime)
include
(
onnxruntime-win-arm64
)
else
()
# for 64-bit windows (x64)
if
(
BUILD_SHARED_LIBS
)
if
(
SHERPA_ONNX_ENABLE_DIRECTML
)
message
(
STATUS
"Use DirectML"
)
include
(
onnxruntime-win-x64-directml
)
elseif
(
BUILD_SHARED_LIBS
)
message
(
STATUS
"Use dynamic onnxruntime libraries"
)
if
(
SHERPA_ONNX_ENABLE_GPU
)
include
(
onnxruntime-win-x64-gpu
)
...
...
sherpa-onnx/csrc/provider.cc
查看文件 @
d32a461
...
...
@@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) {
return
Provider
::
kNNAPI
;
}
else
if
(
s
==
"trt"
)
{
return
Provider
::
kTRT
;
}
else
if
(
s
==
"directml"
)
{
return
Provider
::
kDirectML
;
}
else
{
SHERPA_ONNX_LOGE
(
"Unsupported string: %s. Fallback to cpu"
,
s
.
c_str
());
return
Provider
::
kCPU
;
...
...
sherpa-onnx/csrc/provider.h
查看文件 @
d32a461
...
...
@@ -20,6 +20,7 @@ enum class Provider {
kXnnpack
=
3
,
// XnnpackExecutionProvider
kNNAPI
=
4
,
// NnapiExecutionProvider
kTRT
=
5
,
// TensorRTExecutionProvider
kDirectML
=
6
,
// DmlExecutionProvider
};
/**
...
...
sherpa-onnx/csrc/session.cc
查看文件 @
d32a461
...
...
@@ -19,6 +19,10 @@
#include "nnapi_provider_factory.h" // NOLINT
#endif
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
#include "dml_provider_factory.h" // NOLINT
#endif
namespace
sherpa_onnx
{
static
void
OrtStatusFailure
(
OrtStatus
*
status
,
const
char
*
s
)
{
...
...
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
}
break
;
}
case
Provider
:
:
kDirectML
:
{
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
sess_opts
.
DisableMemPattern
();
sess_opts
.
SetExecutionMode
(
ORT_SEQUENTIAL
);
int32_t
device_id
=
0
;
OrtStatus
*
status
=
OrtSessionOptionsAppendExecutionProvider_DML
(
sess_opts
,
device_id
);
if
(
status
)
{
const
auto
&
api
=
Ort
::
GetApi
();
const
char
*
msg
=
api
.
GetErrorMessage
(
status
);
SHERPA_ONNX_LOGE
(
"Failed to enable DirectML: %s. Fallback to cpu"
,
msg
);
api
.
ReleaseStatus
(
status
);
}
#else
SHERPA_ONNX_LOGE
(
"DirectML is for Windows only. Fallback to cpu!"
);
#endif
break
;
}
case
Provider
:
:
kCoreML
:
{
#if defined(__APPLE__)
uint32_t
coreml_flags
=
0
;
...
...
请注册或登录后发表评论。