mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[DML EP] Decouple the bucketized allocator from the individual block allocation logic (#14056)
### Description Decouple the DML bucketized allocator from the individual block allocation logic ### Motivation and Context This is the first step into using tiled/placed resources instead of committed resources. Given the potential impact of changing the allocation logic and the large number of edge cases, I decided to take a step-by-step approach. It will also reduce the size of the PRs to a reasonable length, while making sure each PR has a single responsibility. Decoupling the logic that way will make it easier in the future to easily plug in different kind of "suballocators" if we want to play around with the allocation logic. Currently, the only suballocator is a committed resource, but placed resources are the next step and will come in a future PR.
This commit is contained in:
parent
f344d4b3d1
commit
b6ea60436d
8 changed files with 148 additions and 54 deletions
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/session/onnxruntime_c_api.h"
|
||||
|
||||
#include "BucketizedBufferAllocator.h"
|
||||
#include "DmlSubAllocator.h"
|
||||
// #define PRINT_OUTSTANDING_ALLOCATIONS
|
||||
|
||||
namespace Dml
|
||||
|
|
@ -39,7 +40,8 @@ namespace Dml
|
|||
const D3D12_HEAP_PROPERTIES& heapProps,
|
||||
D3D12_HEAP_FLAGS heapFlags,
|
||||
D3D12_RESOURCE_FLAGS resourceFlags,
|
||||
D3D12_RESOURCE_STATES initialState
|
||||
D3D12_RESOURCE_STATES initialState,
|
||||
std::unique_ptr<DmlSubAllocator>&& subAllocator
|
||||
)
|
||||
: onnxruntime::IAllocator(
|
||||
OrtMemoryInfo(
|
||||
|
|
@ -53,7 +55,8 @@ namespace Dml
|
|||
m_heapFlags(heapFlags),
|
||||
m_resourceFlags(resourceFlags),
|
||||
m_initialState(initialState),
|
||||
m_context(context)
|
||||
m_context(context),
|
||||
m_subAllocator(std::move(subAllocator))
|
||||
{
|
||||
}
|
||||
|
||||
|
|
@ -86,8 +89,8 @@ namespace Dml
|
|||
{
|
||||
// For some reason lotus likes requesting 0 bytes of memory
|
||||
size = std::max<size_t>(1, size);
|
||||
|
||||
ComPtr<ID3D12Resource> resource;
|
||||
|
||||
ComPtr<DmlResourceWrapper> resourceWrapper;
|
||||
uint64_t resourceId = 0;
|
||||
uint64_t bucketSize = 0;
|
||||
|
||||
|
|
@ -98,7 +101,7 @@ namespace Dml
|
|||
|
||||
// Find the bucket for this allocation size
|
||||
gsl::index bucketIndex = GetBucketIndexFromSize(size);
|
||||
|
||||
|
||||
if (gsl::narrow_cast<gsl::index>(m_pool.size()) <= bucketIndex)
|
||||
{
|
||||
// Ensure there are sufficient buckets
|
||||
|
|
@ -107,26 +110,17 @@ namespace Dml
|
|||
|
||||
bucket = &m_pool[bucketIndex];
|
||||
bucketSize = GetBucketSizeFromIndex(bucketIndex);
|
||||
|
||||
|
||||
if (bucket->resources.empty())
|
||||
{
|
||||
// No more resources in this bucket - allocate a new one
|
||||
auto buffer = CD3DX12_RESOURCE_DESC::Buffer(bucketSize, m_resourceFlags);
|
||||
ORT_THROW_IF_FAILED(m_device->CreateCommittedResource(
|
||||
&m_heapProperties,
|
||||
m_heapFlags,
|
||||
&buffer,
|
||||
m_initialState,
|
||||
nullptr,
|
||||
IID_GRAPHICS_PPV_ARGS(resource.ReleaseAndGetAddressOf())
|
||||
));
|
||||
|
||||
resourceWrapper = m_subAllocator->Alloc(bucketSize);
|
||||
resourceId = ++m_currentResourceId;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Retrieve a resource from the bucket
|
||||
resource = std::move(bucket->resources.back().resource);
|
||||
resourceWrapper = std::move(bucket->resources.back().resource);
|
||||
resourceId = bucket->resources.back().resourceId;
|
||||
bucket->resources.pop_back();
|
||||
}
|
||||
|
|
@ -135,28 +129,18 @@ namespace Dml
|
|||
{
|
||||
// The allocation will not be pooled. Construct a new one
|
||||
bucketSize = (size + 3) & ~3;
|
||||
|
||||
auto buffer = CD3DX12_RESOURCE_DESC::Buffer(bucketSize, m_resourceFlags);
|
||||
ORT_THROW_IF_FAILED(m_device->CreateCommittedResource(
|
||||
&m_heapProperties,
|
||||
m_heapFlags,
|
||||
&buffer,
|
||||
m_initialState,
|
||||
nullptr,
|
||||
IID_GRAPHICS_PPV_ARGS(resource.ReleaseAndGetAddressOf())
|
||||
));
|
||||
|
||||
resourceWrapper = m_subAllocator->Alloc(bucketSize);
|
||||
resourceId = ++m_currentResourceId;
|
||||
}
|
||||
|
||||
assert(resource->GetDesc().Width == bucketSize);
|
||||
assert(resource != nullptr);
|
||||
assert(resourceWrapper->GetD3D12Resource()->GetDesc().Width == bucketSize);
|
||||
assert(resourceWrapper != nullptr);
|
||||
|
||||
ComPtr<AllocationInfo> allocInfo = wil::MakeOrThrow<AllocationInfo>(
|
||||
this,
|
||||
++m_currentAllocationId,
|
||||
resourceId,
|
||||
resource.Get(),
|
||||
resourceWrapper.Get(),
|
||||
size
|
||||
);
|
||||
|
||||
|
|
@ -196,8 +180,8 @@ namespace Dml
|
|||
|
||||
// Return the resource to the bucket
|
||||
Bucket* bucket = &m_pool[bucketIndex];
|
||||
|
||||
Resource resource = {allocInfo->DetachResource(), pooledResourceId};
|
||||
|
||||
Resource resource = {allocInfo->DetachResourceWrapper(), pooledResourceId};
|
||||
bucket->resources.push_back(resource);
|
||||
}
|
||||
else
|
||||
|
|
@ -208,14 +192,14 @@ namespace Dml
|
|||
#else
|
||||
m_context->QueueReference(allocInfo->GetResource());
|
||||
#endif
|
||||
allocInfo->DetachResource();
|
||||
allocInfo->DetachResourceWrapper();
|
||||
}
|
||||
|
||||
#if _DEBUG
|
||||
assert(m_outstandingAllocationsById[allocInfo->GetId()] == allocInfo);
|
||||
m_outstandingAllocationsById.erase(allocInfo->GetId());
|
||||
#endif
|
||||
|
||||
|
||||
// The allocation info is already destructing at this point
|
||||
}
|
||||
|
||||
|
|
@ -239,7 +223,7 @@ namespace Dml
|
|||
|
||||
return allocInfo;
|
||||
}
|
||||
|
||||
|
||||
void BucketizedBufferAllocator::SetDefaultRoundingMode(AllocatorRoundingMode roundingMode)
|
||||
{
|
||||
m_defaultRoundingMode = roundingMode;
|
||||
|
|
|
|||
|
|
@ -5,10 +5,12 @@
|
|||
|
||||
#include "core/framework/allocator.h"
|
||||
#include "ExecutionContext.h"
|
||||
#include "DmlResourceWrapper.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
class DmlSubAllocator;
|
||||
|
||||
class CPUAllocator : public onnxruntime::IAllocator
|
||||
{
|
||||
public:
|
||||
|
|
@ -28,42 +30,42 @@ namespace Dml
|
|||
BucketizedBufferAllocator* owner,
|
||||
size_t id,
|
||||
uint64_t pooledResourceId,
|
||||
ID3D12Resource* resource,
|
||||
DmlResourceWrapper* resourceWrapper,
|
||||
size_t requestedSize)
|
||||
: m_owner(owner)
|
||||
, m_allocationId(id)
|
||||
, m_pooledResourceId(pooledResourceId)
|
||||
, m_resource(resource)
|
||||
, m_resourceWrapper(resourceWrapper)
|
||||
, m_requestedSize(requestedSize)
|
||||
{}
|
||||
|
||||
~AllocationInfo();
|
||||
|
||||
BucketizedBufferAllocator* GetOwner() const
|
||||
{
|
||||
{
|
||||
return m_owner;
|
||||
}
|
||||
|
||||
ID3D12Resource* GetResource() const
|
||||
{
|
||||
return m_resource.Get();
|
||||
{
|
||||
return m_resourceWrapper->GetD3D12Resource();
|
||||
}
|
||||
|
||||
ComPtr<ID3D12Resource> DetachResource() const
|
||||
{
|
||||
return std::move(m_resource);
|
||||
ComPtr<DmlResourceWrapper> DetachResourceWrapper() const
|
||||
{
|
||||
return std::move(m_resourceWrapper);
|
||||
}
|
||||
|
||||
size_t GetRequestedSize() const
|
||||
{
|
||||
{
|
||||
return m_requestedSize;
|
||||
}
|
||||
|
||||
size_t GetId() const
|
||||
{
|
||||
return m_allocationId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
uint64_t GetPooledResourceId() const
|
||||
{
|
||||
return m_pooledResourceId;
|
||||
|
|
@ -73,7 +75,7 @@ namespace Dml
|
|||
BucketizedBufferAllocator* m_owner;
|
||||
size_t m_allocationId; // For debugging purposes
|
||||
uint64_t m_pooledResourceId = 0;
|
||||
ComPtr<ID3D12Resource> m_resource;
|
||||
ComPtr<DmlResourceWrapper> m_resourceWrapper;
|
||||
|
||||
// The size requested during Alloc(), which may be smaller than the physical resource size
|
||||
size_t m_requestedSize;
|
||||
|
|
@ -91,12 +93,13 @@ namespace Dml
|
|||
// Constructs a BucketizedBufferAllocator which allocates D3D12 committed resources with the specified heap properties,
|
||||
// resource flags, and initial resource state.
|
||||
BucketizedBufferAllocator(
|
||||
ID3D12Device* device,
|
||||
ID3D12Device* device,
|
||||
std::shared_ptr<ExecutionContext> context,
|
||||
const D3D12_HEAP_PROPERTIES& heapProps,
|
||||
D3D12_HEAP_FLAGS heapFlags,
|
||||
D3D12_RESOURCE_FLAGS resourceFlags,
|
||||
D3D12_RESOURCE_STATES initialState);
|
||||
D3D12_RESOURCE_STATES initialState,
|
||||
std::unique_ptr<DmlSubAllocator>&& subAllocator);
|
||||
|
||||
// Returns the information associated with an opaque allocation handle returned by IAllocator::Alloc.
|
||||
const AllocationInfo* DecodeDataHandle(const void* opaqueHandle);
|
||||
|
|
@ -116,7 +119,7 @@ namespace Dml
|
|||
// as large as the previous bucket.
|
||||
struct Resource
|
||||
{
|
||||
ComPtr<ID3D12Resource> resource;
|
||||
ComPtr<DmlResourceWrapper> resource;
|
||||
uint64_t resourceId;
|
||||
};
|
||||
|
||||
|
|
@ -148,6 +151,7 @@ namespace Dml
|
|||
uint64_t m_currentResourceId = 0;
|
||||
AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Enabled;
|
||||
std::shared_ptr<ExecutionContext> m_context;
|
||||
std::unique_ptr<DmlSubAllocator> m_subAllocator;
|
||||
|
||||
#if _DEBUG
|
||||
// Useful for debugging; keeps track of all allocations that haven't been freed yet
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
#include "DmlCommittedResourceAllocator.h"
|
||||
#include "DmlResourceWrapper.h"
|
||||
#include "DmlCommittedResourceWrapper.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
ComPtr<DmlResourceWrapper> DmlCommittedResourceAllocator::Alloc(size_t size)
|
||||
{
|
||||
ComPtr<ID3D12Resource> resource;
|
||||
auto buffer = CD3DX12_RESOURCE_DESC::Buffer(size, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS);
|
||||
ORT_THROW_IF_FAILED(m_device->CreateCommittedResource(
|
||||
&CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT),
|
||||
D3D12_HEAP_FLAG_NONE,
|
||||
&buffer,
|
||||
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
|
||||
nullptr,
|
||||
IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf())
|
||||
));
|
||||
|
||||
ComPtr<DmlResourceWrapper> resourceWrapper;
|
||||
wil::MakeOrThrow<DmlCommittedResourceWrapper>(std::move(resource)).As(&resourceWrapper);
|
||||
return resourceWrapper;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "DmlSubAllocator.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
struct DmlResourceWrapper;
|
||||
|
||||
class DmlCommittedResourceAllocator : public DmlSubAllocator
|
||||
{
|
||||
public:
|
||||
DmlCommittedResourceAllocator(ID3D12Device* device) : m_device(device) {}
|
||||
Microsoft::WRL::ComPtr<DmlResourceWrapper> Alloc(size_t size) final;
|
||||
|
||||
private:
|
||||
ID3D12Device* m_device = nullptr;
|
||||
};
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "DmlResourceWrapper.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
class DmlCommittedResourceWrapper : public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, DmlResourceWrapper>
|
||||
{
|
||||
public:
|
||||
DmlCommittedResourceWrapper(ComPtr<ID3D12Resource>&& d3d12Resource) : m_d3d12Resource(std::move(d3d12Resource)) {}
|
||||
ID3D12Resource* GetD3D12Resource() const final { return m_d3d12Resource.Get(); }
|
||||
|
||||
private:
|
||||
ComPtr<ID3D12Resource> m_d3d12Resource;
|
||||
};
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
struct ID3D12Resource;
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
interface __declspec(uuid("d430f6f1-5c43-48d1-97e6-f080cc7fa0c5"))
|
||||
DmlResourceWrapper : public IUnknown
|
||||
{
|
||||
public:
|
||||
virtual ID3D12Resource* GetD3D12Resource() const = 0;
|
||||
virtual ~DmlResourceWrapper(){}
|
||||
};
|
||||
}
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
struct DmlResourceWrapper;
|
||||
|
||||
class DmlSubAllocator
|
||||
{
|
||||
public:
|
||||
virtual Microsoft::WRL::ComPtr<DmlResourceWrapper> Alloc(size_t size) = 0;
|
||||
virtual ~DmlSubAllocator(){}
|
||||
};
|
||||
}
|
||||
|
|
@ -17,6 +17,8 @@
|
|||
#include "core/graph/indexed_sub_graph.h"
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/framework/fallback_cpu_capability.h"
|
||||
#include "DmlCommittedResourceAllocator.h"
|
||||
#include "DmlCommittedResourceWrapper.h"
|
||||
|
||||
#ifdef ERROR
|
||||
#undef ERROR
|
||||
|
|
@ -184,7 +186,8 @@ namespace Dml
|
|||
CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT),
|
||||
D3D12_HEAP_FLAG_NONE,
|
||||
D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS,
|
||||
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
|
||||
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
|
||||
std::make_unique<DmlCommittedResourceAllocator>(m_d3d12Device.Get()));
|
||||
|
||||
m_context->SetAllocator(m_allocator);
|
||||
|
||||
|
|
@ -1007,7 +1010,11 @@ namespace Dml
|
|||
void* CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource)
|
||||
{
|
||||
uint64_t pooledResourceId = 0; // Not a pooled resource
|
||||
ComPtr<AllocationInfo> allocInfo = wil::MakeOrThrow<AllocationInfo>(nullptr, 0, pooledResourceId, pResource, (size_t)pResource->GetDesc().Width);
|
||||
|
||||
ComPtr<DmlResourceWrapper> resourceWrapper;
|
||||
wil::MakeOrThrow<DmlCommittedResourceWrapper>(pResource).As(&resourceWrapper);
|
||||
|
||||
ComPtr<AllocationInfo> allocInfo = wil::MakeOrThrow<AllocationInfo>(nullptr, 0, pooledResourceId, resourceWrapper.Get(), (size_t)pResource->GetDesc().Width);
|
||||
return allocInfo.Detach();
|
||||
}
|
||||
void FreeGPUAllocation(void* ptr)
|
||||
|
|
|
|||
Loading…
Reference in a new issue