diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 8a02a358f6..588c4ac391 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -6,6 +6,7 @@ #include "core/session/onnxruntime_c_api.h" #include "BucketizedBufferAllocator.h" +#include "DmlSubAllocator.h" // #define PRINT_OUTSTANDING_ALLOCATIONS namespace Dml @@ -39,7 +40,8 @@ namespace Dml const D3D12_HEAP_PROPERTIES& heapProps, D3D12_HEAP_FLAGS heapFlags, D3D12_RESOURCE_FLAGS resourceFlags, - D3D12_RESOURCE_STATES initialState + D3D12_RESOURCE_STATES initialState, + std::unique_ptr&& subAllocator ) : onnxruntime::IAllocator( OrtMemoryInfo( @@ -53,7 +55,8 @@ namespace Dml m_heapFlags(heapFlags), m_resourceFlags(resourceFlags), m_initialState(initialState), - m_context(context) + m_context(context), + m_subAllocator(std::move(subAllocator)) { } @@ -86,8 +89,8 @@ namespace Dml { // For some reason lotus likes requesting 0 bytes of memory size = std::max(1, size); - - ComPtr resource; + + ComPtr resourceWrapper; uint64_t resourceId = 0; uint64_t bucketSize = 0; @@ -98,7 +101,7 @@ namespace Dml // Find the bucket for this allocation size gsl::index bucketIndex = GetBucketIndexFromSize(size); - + if (gsl::narrow_cast(m_pool.size()) <= bucketIndex) { // Ensure there are sufficient buckets @@ -107,26 +110,17 @@ namespace Dml bucket = &m_pool[bucketIndex]; bucketSize = GetBucketSizeFromIndex(bucketIndex); - + if (bucket->resources.empty()) { // No more resources in this bucket - allocate a new one - auto buffer = CD3DX12_RESOURCE_DESC::Buffer(bucketSize, m_resourceFlags); - ORT_THROW_IF_FAILED(m_device->CreateCommittedResource( - &m_heapProperties, - m_heapFlags, - &buffer, - m_initialState, - nullptr, - IID_GRAPHICS_PPV_ARGS(resource.ReleaseAndGetAddressOf()) - )); - + resourceWrapper = m_subAllocator->Alloc(bucketSize); resourceId = ++m_currentResourceId; } else { // Retrieve a resource from the bucket - resource = std::move(bucket->resources.back().resource); + resourceWrapper = std::move(bucket->resources.back().resource); resourceId = bucket->resources.back().resourceId; bucket->resources.pop_back(); } @@ -135,28 +129,18 @@ namespace Dml { // The allocation will not be pooled. Construct a new one bucketSize = (size + 3) & ~3; - - auto buffer = CD3DX12_RESOURCE_DESC::Buffer(bucketSize, m_resourceFlags); - ORT_THROW_IF_FAILED(m_device->CreateCommittedResource( - &m_heapProperties, - m_heapFlags, - &buffer, - m_initialState, - nullptr, - IID_GRAPHICS_PPV_ARGS(resource.ReleaseAndGetAddressOf()) - )); - + resourceWrapper = m_subAllocator->Alloc(bucketSize); resourceId = ++m_currentResourceId; } - assert(resource->GetDesc().Width == bucketSize); - assert(resource != nullptr); + assert(resourceWrapper->GetD3D12Resource()->GetDesc().Width == bucketSize); + assert(resourceWrapper != nullptr); ComPtr allocInfo = wil::MakeOrThrow( this, ++m_currentAllocationId, resourceId, - resource.Get(), + resourceWrapper.Get(), size ); @@ -196,8 +180,8 @@ namespace Dml // Return the resource to the bucket Bucket* bucket = &m_pool[bucketIndex]; - - Resource resource = {allocInfo->DetachResource(), pooledResourceId}; + + Resource resource = {allocInfo->DetachResourceWrapper(), pooledResourceId}; bucket->resources.push_back(resource); } else @@ -208,14 +192,14 @@ namespace Dml #else m_context->QueueReference(allocInfo->GetResource()); #endif - allocInfo->DetachResource(); + allocInfo->DetachResourceWrapper(); } #if _DEBUG assert(m_outstandingAllocationsById[allocInfo->GetId()] == allocInfo); m_outstandingAllocationsById.erase(allocInfo->GetId()); #endif - + // The allocation info is already destructing at this point } @@ -239,7 +223,7 @@ namespace Dml return allocInfo; } - + void BucketizedBufferAllocator::SetDefaultRoundingMode(AllocatorRoundingMode roundingMode) { m_defaultRoundingMode = roundingMode; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h index 3ca7cbe527..7e3471e276 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h @@ -5,10 +5,12 @@ #include "core/framework/allocator.h" #include "ExecutionContext.h" +#include "DmlResourceWrapper.h" namespace Dml { - + class DmlSubAllocator; + class CPUAllocator : public onnxruntime::IAllocator { public: @@ -28,42 +30,42 @@ namespace Dml BucketizedBufferAllocator* owner, size_t id, uint64_t pooledResourceId, - ID3D12Resource* resource, + DmlResourceWrapper* resourceWrapper, size_t requestedSize) : m_owner(owner) , m_allocationId(id) , m_pooledResourceId(pooledResourceId) - , m_resource(resource) + , m_resourceWrapper(resourceWrapper) , m_requestedSize(requestedSize) {} ~AllocationInfo(); BucketizedBufferAllocator* GetOwner() const - { + { return m_owner; } ID3D12Resource* GetResource() const - { - return m_resource.Get(); + { + return m_resourceWrapper->GetD3D12Resource(); } - ComPtr DetachResource() const - { - return std::move(m_resource); + ComPtr DetachResourceWrapper() const + { + return std::move(m_resourceWrapper); } size_t GetRequestedSize() const - { + { return m_requestedSize; } size_t GetId() const { return m_allocationId; - } - + } + uint64_t GetPooledResourceId() const { return m_pooledResourceId; @@ -73,7 +75,7 @@ namespace Dml BucketizedBufferAllocator* m_owner; size_t m_allocationId; // For debugging purposes uint64_t m_pooledResourceId = 0; - ComPtr m_resource; + ComPtr m_resourceWrapper; // The size requested during Alloc(), which may be smaller than the physical resource size size_t m_requestedSize; @@ -91,12 +93,13 @@ namespace Dml // Constructs a BucketizedBufferAllocator which allocates D3D12 committed resources with the specified heap properties, // resource flags, and initial resource state. BucketizedBufferAllocator( - ID3D12Device* device, + ID3D12Device* device, std::shared_ptr context, const D3D12_HEAP_PROPERTIES& heapProps, D3D12_HEAP_FLAGS heapFlags, D3D12_RESOURCE_FLAGS resourceFlags, - D3D12_RESOURCE_STATES initialState); + D3D12_RESOURCE_STATES initialState, + std::unique_ptr&& subAllocator); // Returns the information associated with an opaque allocation handle returned by IAllocator::Alloc. const AllocationInfo* DecodeDataHandle(const void* opaqueHandle); @@ -116,7 +119,7 @@ namespace Dml // as large as the previous bucket. struct Resource { - ComPtr resource; + ComPtr resource; uint64_t resourceId; }; @@ -148,6 +151,7 @@ namespace Dml uint64_t m_currentResourceId = 0; AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Enabled; std::shared_ptr m_context; + std::unique_ptr m_subAllocator; #if _DEBUG // Useful for debugging; keeps track of all allocations that haven't been freed yet diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp new file mode 100644 index 0000000000..d9bfdc3473 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" +#include "DmlCommittedResourceAllocator.h" +#include "DmlResourceWrapper.h" +#include "DmlCommittedResourceWrapper.h" + +namespace Dml +{ + ComPtr DmlCommittedResourceAllocator::Alloc(size_t size) + { + ComPtr resource; + auto buffer = CD3DX12_RESOURCE_DESC::Buffer(size, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + ORT_THROW_IF_FAILED(m_device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + D3D12_HEAP_FLAG_NONE, + &buffer, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + nullptr, + IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf()) + )); + + ComPtr resourceWrapper; + wil::MakeOrThrow(std::move(resource)).As(&resourceWrapper); + return resourceWrapper; + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h new file mode 100644 index 0000000000..7ad48be32a --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "DmlSubAllocator.h" + +namespace Dml +{ + struct DmlResourceWrapper; + + class DmlCommittedResourceAllocator : public DmlSubAllocator + { + public: + DmlCommittedResourceAllocator(ID3D12Device* device) : m_device(device) {} + Microsoft::WRL::ComPtr Alloc(size_t size) final; + + private: + ID3D12Device* m_device = nullptr; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceWrapper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceWrapper.h new file mode 100644 index 0000000000..cae206b569 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceWrapper.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "DmlResourceWrapper.h" + +namespace Dml +{ + class DmlCommittedResourceWrapper : public Microsoft::WRL::RuntimeClass, DmlResourceWrapper> + { + public: + DmlCommittedResourceWrapper(ComPtr&& d3d12Resource) : m_d3d12Resource(std::move(d3d12Resource)) {} + ID3D12Resource* GetD3D12Resource() const final { return m_d3d12Resource.Get(); } + + private: + ComPtr m_d3d12Resource; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlResourceWrapper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlResourceWrapper.h new file mode 100644 index 0000000000..876487242a --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlResourceWrapper.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +struct ID3D12Resource; + +namespace Dml +{ + interface __declspec(uuid("d430f6f1-5c43-48d1-97e6-f080cc7fa0c5")) + DmlResourceWrapper : public IUnknown + { + public: + virtual ID3D12Resource* GetD3D12Resource() const = 0; + virtual ~DmlResourceWrapper(){} + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlSubAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlSubAllocator.h new file mode 100644 index 0000000000..cfdaf17710 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlSubAllocator.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Dml +{ + struct DmlResourceWrapper; + + class DmlSubAllocator + { + public: + virtual Microsoft::WRL::ComPtr Alloc(size_t size) = 0; + virtual ~DmlSubAllocator(){} + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 09314c4db3..3ae8e14831 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -17,6 +17,8 @@ #include "core/graph/indexed_sub_graph.h" #include "core/framework/compute_capability.h" #include "core/framework/fallback_cpu_capability.h" +#include "DmlCommittedResourceAllocator.h" +#include "DmlCommittedResourceWrapper.h" #ifdef ERROR #undef ERROR @@ -184,7 +186,8 @@ namespace Dml CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), D3D12_HEAP_FLAG_NONE, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + std::make_unique(m_d3d12Device.Get())); m_context->SetAllocator(m_allocator); @@ -1007,7 +1010,11 @@ namespace Dml void* CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) { uint64_t pooledResourceId = 0; // Not a pooled resource - ComPtr allocInfo = wil::MakeOrThrow(nullptr, 0, pooledResourceId, pResource, (size_t)pResource->GetDesc().Width); + + ComPtr resourceWrapper; + wil::MakeOrThrow(pResource).As(&resourceWrapper); + + ComPtr allocInfo = wil::MakeOrThrow(nullptr, 0, pooledResourceId, resourceWrapper.Get(), (size_t)pResource->GetDesc().Width); return allocInfo.Detach(); } void FreeGPUAllocation(void* ptr)