Expose DirectML provider to python (conflicts resolved from #3359) (#4630)

This commit is contained in:
Cameron Maske 2020-09-08 23:34:09 +02:00 committed by GitHub
parent c239ff0750
commit 4553b2eecd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 37 additions and 3 deletions

View file

@ -21,6 +21,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.3.0.0)
set(DML_SHARED_LIB DirectML.dll)
# Restore nuget packages, which will pull down the DirectML redist package
add_custom_command(

View file

@ -377,6 +377,15 @@ if (onnxruntime_USE_NUPHAR)
)
endif()
if (onnxruntime_USE_DML)
add_custom_command(
TARGET onnxruntime_pybind11_state POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}/${DML_SHARED_LIB}
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/capi/
)
endif()
if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
include(onnxruntime_language_interop_ops.cmake)
endif()

View file

@ -119,7 +119,13 @@ struct OrtStatus {
#define BACKEND_ARMNN ""
#endif
#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN
#if USE_DML
#define BACKEND_DML "-DML"
#else
#define BACKEND_DML ""
#endif
#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML
#include "core/session/onnxruntime_cxx_api.h"
#include "core/providers/providers.h"
#include "core/providers/cpu/cpu_execution_provider.h"
@ -159,6 +165,9 @@ std::string nuphar_settings;
#ifdef USE_ARMNN
#include "core/providers/armnn/armnn_provider_factory.h"
#endif
#ifdef USE_DML
#include "core/providers/dml/dml_provider_factory.h"
#endif
#define PYBIND_UNREFERENCED_PARAMETER(parameter) ((void)(parameter))
@ -176,6 +185,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI(const char* backend_type, int device_id);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN(int use_arena);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id);
} // namespace onnxruntime
#if defined(_MSC_VER)
@ -374,7 +384,7 @@ const std::vector<std::string>& GetAllProviders() {
static std::vector<std::string> all_providers = {kTensorrtExecutionProvider, kCudaExecutionProvider, kMIGraphXExecutionProvider,
kNGraphExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
kNupharExecutionProvider, kVitisAIExecutionProvider, kArmNNExecutionProvider,
kAclExecutionProvider, kCpuExecutionProvider};
kAclExecutionProvider, kDmlExecutionProvider, kCpuExecutionProvider};
return all_providers;
}
@ -572,6 +582,9 @@ void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::s
sess, *onnxruntime::CreateExecutionProviderFactory_ArmNN(sess->GetSessionOptions().enable_cpu_mem_arena));
#endif
} else if (type == kDmlExecutionProvider) {
#ifdef USE_DML
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_DML(0));
#endif
} else {
// unknown provider
throw std::runtime_error("Unknown Provider Type: " + type);
@ -721,6 +734,9 @@ void addGlobalMethods(py::module& m, const Environment& env) {
#endif
#ifdef USE_ARMNN
onnxruntime::CreateExecutionProviderFactory_ArmNN(0)
#endif
#ifdef USE_DML
onnxruntime::CreateExecutionProviderFactory_DML(0)
#endif
};

View file

@ -65,6 +65,9 @@ elif '--use_acl' in sys.argv:
elif '--use_armnn' in sys.argv:
package_name = 'onnxruntime-armnn'
sys.argv.remove('--use_armnn')
elif '--use_dml' in sys.argv:
package_name = 'onnxruntime-dml'
sys.argv.remove('--use_dml')
# PEP 513 defined manylinux1_x86_64 and manylinux1_i686
# PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686
@ -188,6 +191,8 @@ else:
libs.extend(['onnxruntime_providers_tensorrt.dll'])
# nGraph Libs
libs.extend(['ngraph.dll', 'cpu_backend.dll', 'tbb.dll', 'mimalloc-override.dll', 'mimalloc-redirect.dll', 'mimalloc-redirect32.dll'])
# DirectML Libs
libs.extend(['directml.dll'])
# Nuphar Libs
libs.extend(['tvm.dll'])
if nightly_build:

View file

@ -1353,7 +1353,7 @@ def run_nodejs_tests(nodejs_binding_dir):
def build_python_wheel(
source_dir, build_dir, configs, use_cuda, use_ngraph, use_dnnl,
use_tensorrt, use_openvino, use_nuphar, use_vitisai, use_acl, use_armnn,
use_tensorrt, use_openvino, use_nuphar, use_vitisai, use_acl, use_armnn, use_dml,
wheel_name_suffix, enable_training, nightly_build=False, featurizers_build=False, use_ninja=False):
for config in configs:
cwd = get_config_build_dir(build_dir, config)
@ -1402,6 +1402,8 @@ def build_python_wheel(
args.append('--use_acl')
elif use_armnn:
args.append('--use_armnn')
elif use_dml:
args.append('--use_dml')
run_subprocess(args, cwd=cwd)
@ -1794,6 +1796,7 @@ def main():
args.use_vitisai,
args.use_acl,
args.use_armnn,
args.use_dml,
args.wheel_name_suffix,
args.enable_training,
nightly_build=nightly_build,