mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Distinguish between DML and the generic 'GPU' term. This is needed for packaging DML EP in the same ORT GPU pkg. (#22657)
### Description
Distinguish between DML and the generic 'GPU' term. This is needed for packaging DML EP in the same ORT GPU pkg.

### Motivation and Context
Customer requirement.
This commit is contained in:
parent
df236c7894
commit
03ea5dc495
11 changed files with 97 additions and 71 deletions
|
|
@ -17,6 +17,7 @@ struct OrtDevice {
|
|||
static const DeviceType GPU = 1; // Nvidia or AMD
|
||||
static const DeviceType FPGA = 2;
|
||||
static const DeviceType NPU = 3; // Ascend
|
||||
static const DeviceType DML = 4;
|
||||
|
||||
struct MemType {
|
||||
// Pre-defined memory types.
|
||||
|
|
|
|||
|
|
@ -139,13 +139,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
|
|||
*out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1);
|
||||
} else if (strcmp(name1, onnxruntime::CUDA) == 0 ||
|
||||
strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 ||
|
||||
strcmp(name1, onnxruntime::DML) == 0 ||
|
||||
strcmp(name1, onnxruntime::HIP) == 0 ||
|
||||
strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
|
||||
strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) {
|
||||
*out = new OrtMemoryInfo(
|
||||
name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
|
||||
mem_type1);
|
||||
} else if (strcmp(name1, onnxruntime::DML) == 0) {
|
||||
*out = new OrtMemoryInfo(
|
||||
name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
|
||||
mem_type1);
|
||||
} else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) {
|
||||
*out = new OrtMemoryInfo(
|
||||
name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
|
||||
|
|
|
|||
|
|
@ -41,23 +41,19 @@ namespace Dml
|
|||
D3D12_HEAP_FLAGS heapFlags,
|
||||
D3D12_RESOURCE_FLAGS resourceFlags,
|
||||
D3D12_RESOURCE_STATES initialState,
|
||||
std::unique_ptr<DmlSubAllocator>&& subAllocator
|
||||
)
|
||||
std::unique_ptr<DmlSubAllocator>&& subAllocator)
|
||||
: onnxruntime::IAllocator(
|
||||
OrtMemoryInfo(
|
||||
"DML",
|
||||
OrtAllocatorType::OrtDeviceAllocator,
|
||||
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
|
||||
)
|
||||
),
|
||||
m_device(device),
|
||||
m_heapProperties(heapProps),
|
||||
m_heapFlags(heapFlags),
|
||||
m_resourceFlags(resourceFlags),
|
||||
m_initialState(initialState),
|
||||
m_context(context),
|
||||
m_subAllocator(std::move(subAllocator))
|
||||
{
|
||||
OrtMemoryInfo(
|
||||
"DML",
|
||||
OrtAllocatorType::OrtDeviceAllocator,
|
||||
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))),
|
||||
m_device(device),
|
||||
m_heapProperties(heapProps),
|
||||
m_heapFlags(heapFlags),
|
||||
m_resourceFlags(resourceFlags),
|
||||
m_initialState(initialState),
|
||||
m_context(context),
|
||||
m_subAllocator(std::move(subAllocator)) {
|
||||
}
|
||||
|
||||
/*static*/ gsl::index BucketizedBufferAllocator::GetBucketIndexFromSize(uint64_t size)
|
||||
|
|
|
|||
|
|
@ -20,15 +20,13 @@ namespace Dml
|
|||
class DmlExternalBufferAllocator : public onnxruntime::IAllocator
|
||||
{
|
||||
public:
|
||||
DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(
|
||||
OrtMemoryInfo(
|
||||
"DML",
|
||||
OrtAllocatorType::OrtDeviceAllocator,
|
||||
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
|
||||
))
|
||||
{
|
||||
m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
|
||||
}
|
||||
DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(
|
||||
OrtMemoryInfo(
|
||||
"DML",
|
||||
OrtAllocatorType::OrtDeviceAllocator,
|
||||
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))) {
|
||||
m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
|
||||
}
|
||||
|
||||
void* Alloc(size_t size) final
|
||||
{
|
||||
|
|
|
|||
|
|
@ -73,20 +73,17 @@ namespace Dml
|
|||
bool enableMetacommands,
|
||||
bool enableGraphCapture,
|
||||
bool enableSyncSpinning,
|
||||
bool disableMemoryArena) :
|
||||
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
|
||||
{
|
||||
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
|
||||
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE)
|
||||
{
|
||||
// DML requires either DIRECT or COMPUTE command queues.
|
||||
ORT_THROW_HR(E_INVALIDARG);
|
||||
}
|
||||
bool disableMemoryArena) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)) {
|
||||
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
|
||||
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) {
|
||||
// DML requires either DIRECT or COMPUTE command queues.
|
||||
ORT_THROW_HR(E_INVALIDARG);
|
||||
}
|
||||
|
||||
ComPtr<ID3D12Device> device;
|
||||
GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));
|
||||
ComPtr<ID3D12Device> device;
|
||||
GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));
|
||||
|
||||
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
|
||||
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
|
||||
|
|
|
|||
|
|
@ -242,8 +242,8 @@ namespace Dml
|
|||
|
||||
bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final
|
||||
{
|
||||
return (srcDevice.Type() == OrtDevice::GPU) ||
|
||||
(dstDevice.Type() == OrtDevice::GPU);
|
||||
return (srcDevice.Type() == OrtDevice::DML) ||
|
||||
(dstDevice.Type() == OrtDevice::DML);
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -1662,6 +1662,23 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Called from InferenceSession::Initialize() to validate the set of registered execution providers.
|
||||
// For now, this function only rejects a session that combines the DML EP with any EP other than the CPU EP:
// it returns an INVALID_ARGUMENT status in that case and Status::OK() otherwise.
|
||||
// TODO: extend this function to check for other invalid combinations of EPs.
|
||||
common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const {
|
||||
// DML EP is only allowed with CPU EP
|
||||
bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr;
|
||||
if (has_dml_ep) {
|
||||
const auto& ep_list = execution_providers_.GetIds();
|
||||
for (const auto& ep : ep_list) {
|
||||
// The DML and CPU provider ids are skipped; the first id that is neither fails the check.
if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue;
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP.");
|
||||
}
|
||||
}
|
||||
// Either DML is not registered, or only DML and CPU providers were found.
return Status::OK();
|
||||
}
|
||||
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
// VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong.
|
||||
|
|
@ -1719,6 +1736,11 @@ common::Status InferenceSession::Initialize() {
|
|||
execution_providers_.SetCpuProviderWasImplicitlyAdded(true);
|
||||
}
|
||||
|
||||
// Check for the presence of an invalid combination of execution providers in the session
|
||||
// For e.g. we don't support DML EP and other GPU EPs to be present in the same session
|
||||
// This check is placed here because it serves as a common place for all language bindings.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(HasInvalidCombinationOfExecutionProviders());
|
||||
|
||||
// re-acquire mutex
|
||||
std::lock_guard<std::mutex> l(session_mutex_);
|
||||
|
||||
|
|
|
|||
|
|
@ -620,7 +620,7 @@ class InferenceSession {
|
|||
const Environment& session_env);
|
||||
void ConstructorCommon(const SessionOptions& session_options,
|
||||
const Environment& session_env);
|
||||
|
||||
[[nodiscard]] common::Status HasInvalidCombinationOfExecutionProviders() const;
|
||||
[[nodiscard]] common::Status SaveModelMetadata(const onnxruntime::Model& model);
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
|
|
|||
|
|
@ -291,7 +291,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
|
|||
|
||||
const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
|
||||
static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
|
||||
{OrtDevice::GPU, DmlToCpuMemCpy}};
|
||||
{OrtDevice::DML, DmlToCpuMemCpy}};
|
||||
|
||||
return &map;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,16 +96,22 @@ void addOrtValueMethods(pybind11::module& m) {
|
|||
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
|
||||
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
|
||||
CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);
|
||||
#elif USE_DML
|
||||
// InputDeflist is null because OrtValue creation is not tied to a specific model
|
||||
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
|
||||
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
|
||||
CreateGenericMLValue(
|
||||
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
|
||||
"Please use the CUDA package of OnnxRuntime to use this feature.");
|
||||
throw std::runtime_error(
|
||||
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
|
||||
"Please use the CUDA package of OnnxRuntime to use this feature.");
|
||||
#endif
|
||||
} else if (device.Type() == OrtDevice::DML) {
|
||||
#if USE_DML
|
||||
// InputDeflist is null because OrtValue creation is not tied to a specific model
|
||||
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
|
||||
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
|
||||
CreateGenericMLValue(
|
||||
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
|
||||
"Please use the CUDA package of OnnxRuntime to use this feature.");
|
||||
#endif
|
||||
} else if (device.Type() == OrtDevice::NPU) {
|
||||
#ifdef USE_CANN
|
||||
|
|
@ -116,9 +122,9 @@ void addOrtValueMethods(pybind11::module& m) {
|
|||
CreateGenericMLValue(nullptr, GetCannAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
|
||||
true, false, CpuToCannMemCpy);
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Can't allocate memory on the CANN device using this package of OnnxRuntime. "
|
||||
"Please use the CANN package of OnnxRuntime to use this feature.");
|
||||
throw std::runtime_error(
|
||||
"Can't allocate memory on the CANN device using this package of OnnxRuntime. "
|
||||
"Please use the CANN package of OnnxRuntime to use this feature.");
|
||||
#endif
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
|
||||
|
|
@ -160,19 +166,24 @@ void addOrtValueMethods(pybind11::module& m) {
|
|||
}
|
||||
|
||||
onnxruntime::python::CopyDataToTensor(
|
||||
py_values,
|
||||
values_type,
|
||||
*(ml_value->GetMutable<Tensor>()),
|
||||
CpuToRocmMemCpy);
|
||||
#elif USE_DML
|
||||
onnxruntime::python::CopyDataToTensor(
|
||||
py_values,
|
||||
values_type,
|
||||
*(ml_value->GetMutable<Tensor>()),
|
||||
CpuToDmlMemCpy);
|
||||
py_values,
|
||||
values_type,
|
||||
*(ml_value->GetMutable<Tensor>()),
|
||||
CpuToRocmMemCpy);
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Unsupported GPU device: Cannot find the supported GPU device.");
|
||||
throw std::runtime_error(
|
||||
"Unsupported GPU device: Cannot find the supported GPU device.");
|
||||
#endif
|
||||
} else if (device.Type() == OrtDevice::DML) {
|
||||
#if USE_DML
|
||||
onnxruntime::python::CopyDataToTensor(
|
||||
py_values,
|
||||
values_type,
|
||||
*(ml_value->GetMutable<Tensor>()),
|
||||
CpuToDmlMemCpy);
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Unsupported GPU device: Cannot find the supported GPU device.");
|
||||
#endif
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device");
|
||||
|
|
|
|||
|
|
@ -288,11 +288,9 @@ const char* GetDeviceName(const OrtDevice& device) {
|
|||
case OrtDevice::CPU:
|
||||
return CPU;
|
||||
case OrtDevice::GPU:
|
||||
#ifdef USE_DML
|
||||
return DML;
|
||||
#else
|
||||
return CUDA;
|
||||
#endif
|
||||
case OrtDevice::DML:
|
||||
return DML;
|
||||
case OrtDevice::FPGA:
|
||||
return "FPGA";
|
||||
case OrtDevice::NPU:
|
||||
|
|
@ -1579,7 +1577,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
|
|||
.def_static("cann", []() { return OrtDevice::NPU; })
|
||||
.def_static("fpga", []() { return OrtDevice::FPGA; })
|
||||
.def_static("npu", []() { return OrtDevice::NPU; })
|
||||
.def_static("dml", []() { return OrtDevice::GPU; })
|
||||
.def_static("dml", []() { return OrtDevice::DML; })
|
||||
.def_static("webgpu", []() { return OrtDevice::GPU; })
|
||||
.def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue