Distinguish between DML and the generic 'GPU' term. This is needed for packaging DML EP in the same ORT GPU pkg. (#22657)

### Description
Distinguish between DML and the generic 'GPU' term. This is needed for
packaging DML EP in the same ORT GPU pkg.

### Motivation and Context
Customer requirement.
This commit is contained in:
Pranav Sharma 2024-10-30 11:58:34 -07:00 committed by GitHub
parent df236c7894
commit 03ea5dc495
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 97 additions and 71 deletions

View file

@ -17,6 +17,7 @@ struct OrtDevice {
static const DeviceType GPU = 1; // Nvidia or AMD
static const DeviceType FPGA = 2;
static const DeviceType NPU = 3; // Ascend
static const DeviceType DML = 4;
struct MemType {
// Pre-defined memory types.

View file

@ -139,13 +139,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
*out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA) == 0 ||
strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 ||
strcmp(name1, onnxruntime::DML) == 0 ||
strcmp(name1, onnxruntime::HIP) == 0 ||
strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::DML) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,

View file

@ -41,23 +41,19 @@ namespace Dml
D3D12_HEAP_FLAGS heapFlags,
D3D12_RESOURCE_FLAGS resourceFlags,
D3D12_RESOURCE_STATES initialState,
std::unique_ptr<DmlSubAllocator>&& subAllocator
)
std::unique_ptr<DmlSubAllocator>&& subAllocator)
: onnxruntime::IAllocator(
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
)
),
m_device(device),
m_heapProperties(heapProps),
m_heapFlags(heapFlags),
m_resourceFlags(resourceFlags),
m_initialState(initialState),
m_context(context),
m_subAllocator(std::move(subAllocator))
{
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))),
m_device(device),
m_heapProperties(heapProps),
m_heapFlags(heapFlags),
m_resourceFlags(resourceFlags),
m_initialState(initialState),
m_context(context),
m_subAllocator(std::move(subAllocator)) {
}
/*static*/ gsl::index BucketizedBufferAllocator::GetBucketIndexFromSize(uint64_t size)

View file

@ -20,15 +20,13 @@ namespace Dml
class DmlExternalBufferAllocator : public onnxruntime::IAllocator
{
public:
DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
))
{
m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
}
DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))) {
m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
}
void* Alloc(size_t size) final
{

View file

@ -73,20 +73,17 @@ namespace Dml
bool enableMetacommands,
bool enableGraphCapture,
bool enableSyncSpinning,
bool disableMemoryArena) :
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
{
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE)
{
// DML requires either DIRECT or COMPUTE command queues.
ORT_THROW_HR(E_INVALIDARG);
}
bool disableMemoryArena) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)) {
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) {
// DML requires either DIRECT or COMPUTE command queues.
ORT_THROW_HR(E_INVALIDARG);
}
ComPtr<ID3D12Device> device;
GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));
ComPtr<ID3D12Device> device;
GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
}
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>

View file

@ -242,8 +242,8 @@ namespace Dml
bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final
{
return (srcDevice.Type() == OrtDevice::GPU) ||
(dstDevice.Type() == OrtDevice::GPU);
return (srcDevice.Type() == OrtDevice::DML) ||
(dstDevice.Type() == OrtDevice::DML);
}
private:

View file

@ -1662,6 +1662,23 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) {
}
}
}
// This function is called when the session is being initialized.
// For now, this function only checks for invalid combination of DML EP with other EPs.
// TODO: extend this function to check for other invalid combinations of EPs.
common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const {
// DML EP is only allowed with CPU EP
bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr;
if (has_dml_ep) {
const auto& ep_list = execution_providers_.GetIds();
for (const auto& ep : ep_list) {
if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue;
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP.");
}
}
return Status::OK();
}
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
// VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong.
@ -1719,6 +1736,11 @@ common::Status InferenceSession::Initialize() {
execution_providers_.SetCpuProviderWasImplicitlyAdded(true);
}
// Check for the presence of an invalid combination of execution providers in the session
// For e.g. we don't support DML EP and other GPU EPs to be present in the same session
// This check is placed here because it serves as a common place for all language bindings.
ORT_RETURN_IF_ERROR_SESSIONID_(HasInvalidCombinationOfExecutionProviders());
// re-acquire mutex
std::lock_guard<std::mutex> l(session_mutex_);

View file

@ -620,7 +620,7 @@ class InferenceSession {
const Environment& session_env);
void ConstructorCommon(const SessionOptions& session_options,
const Environment& session_env);
[[nodiscard]] common::Status HasInvalidCombinationOfExecutionProviders() const;
[[nodiscard]] common::Status SaveModelMetadata(const onnxruntime::Model& model);
#if !defined(ORT_MINIMAL_BUILD)

View file

@ -291,7 +291,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
{OrtDevice::GPU, DmlToCpuMemCpy}};
{OrtDevice::DML, DmlToCpuMemCpy}};
return &map;
}

View file

@ -96,16 +96,22 @@ void addOrtValueMethods(pybind11::module& m) {
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);
#elif USE_DML
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
CreateGenericMLValue(
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else if (device.Type() == OrtDevice::DML) {
#if USE_DML
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
CreateGenericMLValue(
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else if (device.Type() == OrtDevice::NPU) {
#ifdef USE_CANN
@ -116,9 +122,9 @@ void addOrtValueMethods(pybind11::module& m) {
CreateGenericMLValue(nullptr, GetCannAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
true, false, CpuToCannMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CANN device using this package of OnnxRuntime. "
"Please use the CANN package of OnnxRuntime to use this feature.");
throw std::runtime_error(
"Can't allocate memory on the CANN device using this package of OnnxRuntime. "
"Please use the CANN package of OnnxRuntime to use this feature.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
@ -160,19 +166,24 @@ void addOrtValueMethods(pybind11::module& m) {
}
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToRocmMemCpy);
#elif USE_DML
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToDmlMemCpy);
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToRocmMemCpy);
#else
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
#endif
} else if (device.Type() == OrtDevice::DML) {
#if USE_DML
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device");

View file

@ -288,11 +288,9 @@ const char* GetDeviceName(const OrtDevice& device) {
case OrtDevice::CPU:
return CPU;
case OrtDevice::GPU:
#ifdef USE_DML
return DML;
#else
return CUDA;
#endif
case OrtDevice::DML:
return DML;
case OrtDevice::FPGA:
return "FPGA";
case OrtDevice::NPU:
@ -1579,7 +1577,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
.def_static("cann", []() { return OrtDevice::NPU; })
.def_static("fpga", []() { return OrtDevice::FPGA; })
.def_static("npu", []() { return OrtDevice::NPU; })
.def_static("dml", []() { return OrtDevice::GPU; })
.def_static("dml", []() { return OrtDevice::DML; })
.def_static("webgpu", []() { return OrtDevice::GPU; })
.def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });