mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
This commit is contained in:
parent
c239ff0750
commit
4553b2eecd
5 changed files with 37 additions and 3 deletions
1
cmake/external/dml.cmake
vendored
1
cmake/external/dml.cmake
vendored
|
|
@ -21,6 +21,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
|
|||
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
|
||||
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
|
||||
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.3.0.0)
|
||||
set(DML_SHARED_LIB DirectML.dll)
|
||||
|
||||
# Restore nuget packages, which will pull down the DirectML redist package
|
||||
add_custom_command(
|
||||
|
|
|
|||
|
|
@ -377,6 +377,15 @@ if (onnxruntime_USE_NUPHAR)
|
|||
)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_DML)
|
||||
add_custom_command(
|
||||
TARGET onnxruntime_pybind11_state POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}/${DML_SHARED_LIB}
|
||||
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/capi/
|
||||
)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
|
||||
include(onnxruntime_language_interop_ops.cmake)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -119,7 +119,13 @@ struct OrtStatus {
|
|||
#define BACKEND_ARMNN ""
|
||||
#endif
|
||||
|
||||
#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN
|
||||
#if USE_DML
|
||||
#define BACKEND_DML "-DML"
|
||||
#else
|
||||
#define BACKEND_DML ""
|
||||
#endif
|
||||
|
||||
#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/providers/providers.h"
|
||||
#include "core/providers/cpu/cpu_execution_provider.h"
|
||||
|
|
@ -159,6 +165,9 @@ std::string nuphar_settings;
|
|||
#ifdef USE_ARMNN
|
||||
#include "core/providers/armnn/armnn_provider_factory.h"
|
||||
#endif
|
||||
#ifdef USE_DML
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
#endif
|
||||
|
||||
#define PYBIND_UNREFERENCED_PARAMETER(parameter) ((void)(parameter))
|
||||
|
||||
|
|
@ -176,6 +185,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar
|
|||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI(const char* backend_type, int device_id);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id);
|
||||
} // namespace onnxruntime
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
|
@ -374,7 +384,7 @@ const std::vector<std::string>& GetAllProviders() {
|
|||
static std::vector<std::string> all_providers = {kTensorrtExecutionProvider, kCudaExecutionProvider, kMIGraphXExecutionProvider,
|
||||
kNGraphExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
|
||||
kNupharExecutionProvider, kVitisAIExecutionProvider, kArmNNExecutionProvider,
|
||||
kAclExecutionProvider, kCpuExecutionProvider};
|
||||
kAclExecutionProvider, kDmlExecutionProvider, kCpuExecutionProvider};
|
||||
return all_providers;
|
||||
}
|
||||
|
||||
|
|
@ -572,6 +582,9 @@ void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::s
|
|||
sess, *onnxruntime::CreateExecutionProviderFactory_ArmNN(sess->GetSessionOptions().enable_cpu_mem_arena));
|
||||
#endif
|
||||
} else if (type == kDmlExecutionProvider) {
|
||||
#ifdef USE_DML
|
||||
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_DML(0));
|
||||
#endif
|
||||
} else {
|
||||
// unknown provider
|
||||
throw std::runtime_error("Unknown Provider Type: " + type);
|
||||
|
|
@ -721,6 +734,9 @@ void addGlobalMethods(py::module& m, const Environment& env) {
|
|||
#endif
|
||||
#ifdef USE_ARMNN
|
||||
onnxruntime::CreateExecutionProviderFactory_ArmNN(0)
|
||||
#endif
|
||||
#ifdef USE_DML
|
||||
onnxruntime::CreateExecutionProviderFactory_DML(0)
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
|||
5
setup.py
5
setup.py
|
|
@ -65,6 +65,9 @@ elif '--use_acl' in sys.argv:
|
|||
elif '--use_armnn' in sys.argv:
|
||||
package_name = 'onnxruntime-armnn'
|
||||
sys.argv.remove('--use_armnn')
|
||||
elif '--use_dml' in sys.argv:
|
||||
package_name = 'onnxruntime-dml'
|
||||
sys.argv.remove('--use_dml')
|
||||
|
||||
# PEP 513 defined manylinux1_x86_64 and manylinux1_i686
|
||||
# PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686
|
||||
|
|
@ -188,6 +191,8 @@ else:
|
|||
libs.extend(['onnxruntime_providers_tensorrt.dll'])
|
||||
# nGraph Libs
|
||||
libs.extend(['ngraph.dll', 'cpu_backend.dll', 'tbb.dll', 'mimalloc-override.dll', 'mimalloc-redirect.dll', 'mimalloc-redirect32.dll'])
|
||||
# DirectML Libs
|
||||
libs.extend(['directml.dll'])
|
||||
# Nuphar Libs
|
||||
libs.extend(['tvm.dll'])
|
||||
if nightly_build:
|
||||
|
|
|
|||
|
|
@ -1353,7 +1353,7 @@ def run_nodejs_tests(nodejs_binding_dir):
|
|||
|
||||
def build_python_wheel(
|
||||
source_dir, build_dir, configs, use_cuda, use_ngraph, use_dnnl,
|
||||
use_tensorrt, use_openvino, use_nuphar, use_vitisai, use_acl, use_armnn,
|
||||
use_tensorrt, use_openvino, use_nuphar, use_vitisai, use_acl, use_armnn, use_dml,
|
||||
wheel_name_suffix, enable_training, nightly_build=False, featurizers_build=False, use_ninja=False):
|
||||
for config in configs:
|
||||
cwd = get_config_build_dir(build_dir, config)
|
||||
|
|
@ -1402,6 +1402,8 @@ def build_python_wheel(
|
|||
args.append('--use_acl')
|
||||
elif use_armnn:
|
||||
args.append('--use_armnn')
|
||||
elif use_dml:
|
||||
args.append('--use_dml')
|
||||
|
||||
run_subprocess(args, cwd=cwd)
|
||||
|
||||
|
|
@ -1794,6 +1796,7 @@ def main():
|
|||
args.use_vitisai,
|
||||
args.use_acl,
|
||||
args.use_armnn,
|
||||
args.use_dml,
|
||||
args.wheel_name_suffix,
|
||||
args.enable_training,
|
||||
nightly_build=nightly_build,
|
||||
|
|
|
|||
Loading…
Reference in a new issue