Implement DML copy for Lora Adapters (#22396)

### Description
Request and create DML EP and its data transfer.
Use to copy on device.

The PR includes changes to fix issues in DML provider.

### Motivation and Context
This enables Lora users to run it with DML which is important for GenAI.

Co-authored-by: @PatriceVignola

---------

Co-authored-by: Patrice Vignola <vignola.patrice@gmail.com>
This commit is contained in:
Dmitri Smirnov 2024-10-14 12:26:50 -07:00 committed by GitHub
parent 35adba21c7
commit 87e8a5dfa8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 137 additions and 79 deletions

View file

@ -186,7 +186,7 @@ namespace Dml
}
else
{
if (!m_context->IsClosed())
if (!m_closed)
{
// Free the underlying allocation once queued work has completed.
#ifdef _GAMING_XBOX

View file

@ -46,6 +46,11 @@ namespace Dml
void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);
void Close()
{
m_closed = true;
}
public: // onnxruntime::IAllocator
void* Alloc(size_t size, AllocatorRoundingMode roundingMode);
void* Alloc(size_t size) final;
@ -83,6 +88,7 @@ namespace Dml
std::vector<Bucket> m_pool;
size_t m_currentAllocationId = 0;
uint64_t m_currentResourceId = 0;
bool m_closed = false;
// Unless specifically requested, allocation sizes are not rounded to enable pooling
// until SetDefaultRoundingMode is called. This should be done at completion of session

View file

@ -55,7 +55,7 @@ namespace Dml
// for example, an allocation from BucketizedBufferAllocator attempts to queue a reference
// to its underlying D3D resource when freed. Furthermore, these references are unnecessary
// since Close() already blocks for scheduled GPU work before clearing m_queuedReferences.
if (!m_closing)
if (!m_clearingQueue)
{
QueuedReference queuedReference = {GetLastFenceValue(), object};
@ -70,15 +70,15 @@ namespace Dml
}
}
void CommandQueue::Close()
void CommandQueue::WaitForSignalAndClearQueue()
{
// Wait for flushed work:
assert(!m_closing);
m_closing = true;
assert(!m_clearingQueue);
m_clearingQueue = true;
GpuEvent event = GetCurrentCompletionEvent();
event.WaitForSignal(m_cpuSyncSpinningEnabled);
m_queuedReferences.clear();
m_closing = false;
m_clearingQueue = false;
}
void CommandQueue::ReleaseCompletedReferences()

View file

@ -44,7 +44,7 @@ namespace Dml
}
#endif
void Close();
void WaitForSignalAndClearQueue();
void ReleaseCompletedReferences();
private:
@ -61,7 +61,7 @@ namespace Dml
ComPtr<ID3D12Fence> m_fence;
uint64_t m_lastFenceValue = 0;
bool m_closing = false;
bool m_clearingQueue = false;
bool m_cpuSyncSpinningEnabled = false;
};

View file

@ -11,13 +11,10 @@ namespace Dml
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue,
bool cpuSyncSpinningEnabled,
bool keepOpen
)
bool cpuSyncSpinningEnabled)
: m_queue(std::make_shared<CommandQueue>(queue, cpuSyncSpinningEnabled))
, m_dmlRecorder(d3d12Device, dmlDevice, m_queue)
, m_cpuSyncSpinningEnabled(cpuSyncSpinningEnabled)
, m_keepOpen(keepOpen)
{
ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf())));
}
@ -36,8 +33,6 @@ namespace Dml
D3D12_RESOURCE_STATES srcState,
uint64_t byteCount)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
std::vector<D3D12_RESOURCE_BARRIER> barriers;
@ -84,8 +79,6 @@ namespace Dml
_Out_ uint64_t* completionValue
)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue);
}
@ -95,7 +88,6 @@ namespace Dml
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding);
@ -107,7 +99,6 @@ namespace Dml
gsl::span<const DML_BINDING_DESC> inputBindings,
gsl::span<const DML_BINDING_DESC> outputBindings)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings);
@ -115,7 +106,6 @@ namespace Dml
void ExecutionContext::AddUAVBarrier()
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.AddUAVBarrier();
@ -123,7 +113,6 @@ namespace Dml
void ExecutionContext::ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ResourceBarrier(barriers);
@ -131,7 +120,6 @@ namespace Dml
void ExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** commandList)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
// Ensure the descriptor heap is reset to D3D as something external may change it before recording
@ -142,8 +130,6 @@ namespace Dml
void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder)
{
assert(!m_closed);
// If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct
// ordering of operations on the command queue.
if (m_currentRecorder != newRecorder)
@ -160,8 +146,6 @@ namespace Dml
void ExecutionContext::Flush()
{
assert(!m_closed);
if (!m_currentRecorder || !m_currentRecorder->HasUnsubmittedWork())
{
// Nothing to flush
@ -180,34 +164,21 @@ namespace Dml
void ExecutionContext::QueueReference(IUnknown* object)
{
assert(!m_closed);
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
// value is the one to signal completion.
bool waitForUnsubmittedWork = (m_currentRecorder != nullptr);
m_queue->QueueReference(object, waitForUnsubmittedWork);
}
void ExecutionContext::Close()
void ExecutionContext::WaitForSignalAndClearQueue()
{
assert(!m_closed);
// Discard unflushed work and clear queued references. This prevents the circular reference:
// Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel
m_queue->Close();
// Keep the execution context open when requested, e.g. when used through the python API where there's a single context
// and single command queue
if (!m_keepOpen)
{
m_currentRecorder = nullptr;
m_closed = true;
}
m_queue->WaitForSignalAndClearQueue();
}
GpuEvent ExecutionContext::GetCurrentCompletionEvent()
{
assert(!m_closed);
GpuEvent event = m_queue->GetCurrentCompletionEvent();
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
@ -223,13 +194,11 @@ namespace Dml
void ExecutionContext::ReleaseCompletedReferences()
{
assert(!m_closed);
m_queue->ReleaseCompletedReferences();
}
D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const
{
assert(!m_closed);
return m_queue->GetType();
}

View file

@ -23,14 +23,13 @@ namespace Dml
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue,
bool cpuSyncSpinningEnabled,
bool keepOpen);
bool cpuSyncSpinningEnabled);
void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);
// Waits for flushed work, discards unflushed work, and discards associated references to
// prevent circular references. Must be the last call on the object before destruction.
void Close();
// prevent circular references.
void WaitForSignalAndClearQueue();
// Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition
// barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and
@ -87,7 +86,6 @@ namespace Dml
D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const;
bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; }
bool IsClosed() const { return m_closed; }
private:
Microsoft::WRL::ComPtr<ID3D12Device> m_d3dDevice;
@ -103,10 +101,6 @@ namespace Dml
bool m_closed = false;
bool m_cpuSyncSpinningEnabled = false;
// The python API has a global state used for I/O binding where the execution context is shared between session,
// so we don't want to close the context when one of the sessions is destroyed
bool m_keepOpen = false;
};
} // namespace Dml

View file

@ -106,7 +106,26 @@ namespace Dml
// Release the cached command list references before closing the context
m_capturedGraphs.clear();
m_context->Close();
// Close the allocator before clearing the command queue to stop it from
// appending resources to it in an attempt to keep them alive.
if (m_allocator)
{
m_allocator->Close();
}
// Destroy the allocators. We are closing the execution provider, so from now on the
// only thing it will be used for is doing copies via the DataTransfer, which doesn't
// require allocating any memory.
// TODO: Move the copy functions over to ExecutionContext so that we are able to cleanly
// destroy ExecutionProviderImpl, and instead have the DataTransfer keep the context alive.
m_allocator = nullptr;
m_cpuInputAllocator = nullptr;
// Wait for all pending commands to be done executing and empty the command queue. This will
// Force all kernels and resources in flight to get destroyed and, from this point forward,
// ExecutionProviderImpl will only be used to execute transfer between resources that are
// already existing via the DataTransfer;
m_context->WaitForSignalAndClearQueue();
}
void ExecutionProviderImpl::WaitForOutstandingWork()

View file

@ -86,11 +86,11 @@ std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {
// First, check if an I/O binding API that was used before this session or another session has already created a queue
if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, execution_context.GetAddressOf()))) {
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true, true);
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true);
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, execution_context.Get()));
}
} else {
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false);
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_);
}
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_);

View file

@ -4,10 +4,9 @@
#include "core/session/lora_adapters.h"
#include "lora/adapter_format_utils.h"
#include <unordered_map>
#include "core/framework/data_transfer.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/execution_provider.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/session/allocator_adapters.h"
#include "core/session/ort_apis.h"
@ -16,6 +15,15 @@
#include "core/providers/cuda/cuda_provider_factory.h"
#endif
#ifdef USE_DML
#include "core/session/abi_session_options_impl.h"
#include "core/providers/dml/dml_provider_factory_creator.h"
#include "core/providers/dml/dml_provider_factory.h"
#endif
#include <functional>
#include <unordered_map>
namespace onnxruntime {
#ifdef USE_CUDA
@ -50,28 +58,56 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) {
InitializeParamsValues();
}
static std::unique_ptr<IDataTransfer> GetDataTransfer(const OrtMemoryInfo& mem_info) {
namespace {
struct DataTransfer {
std::unique_ptr<IExecutionProvider> ep;
std::unique_ptr<IDataTransfer> data_transfer;
if (strcmp(mem_info.name, onnxruntime::CPU) == 0) {
return data_transfer;
Status CopyTensor(const Tensor& src, Tensor& dst) const {
return data_transfer->CopyTensor(src, dst);
}
Status Sync() const {
#if USE_DML
return ep->Sync();
#else
return Status::OK();
#endif
}
};
} // namespace
static Status GetDataTransfer(const OrtMemoryInfo& mem_info, [[maybe_unused]] DataTransfer& dt) {
ORT_RETURN_IF(strcmp(mem_info.name, onnxruntime::CPU) == 0, "Expecting on device allocator for LoraAdapter");
Status status;
if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) {
#ifdef USE_CUDA
auto* cuda_provider_info = TryGetProviderInfo_CUDA();
if (cuda_provider_info != nullptr) {
data_transfer = cuda_provider_info->CreateGPUDataTransfer();
dt.data_transfer = cuda_provider_info->CreateGPUDataTransfer();
} else {
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider could not be loaded");
}
#else
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider is not enabled in this build");
#endif
} else if (strcmp(mem_info.name, onnxruntime::DML) == 0) {
#ifdef USE_DML
auto ep_factory = onnxruntime::DMLProviderFactoryCreator::Create(ConfigOptions{}, 0, false, false, false);
dt.ep = ep_factory->CreateProvider();
dt.data_transfer = dt.ep->GetDataTransfer();
#else
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DML provider is not enabled in this build");
#endif
} else {
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported device allocator");
}
return data_transfer;
return status;
}
static Status CreateOrtValueOnDevice(const OrtValue& ort_value_mapped,
const AllocatorPtr& device_allocator,
const IDataTransfer& data_transfer,
const DataTransfer& data_transfer,
OrtValue& out) {
OrtValue result;
const auto& src = ort_value_mapped.Get<Tensor>();
@ -87,12 +123,9 @@ void LoraAdapter::InitializeParamsValues() {
ORT_THROW("Adapter is not loaded yet.");
}
std::unique_ptr<IDataTransfer> data_transfer;
DataTransfer data_transfer;
if (device_allocator_) {
data_transfer = GetDataTransfer(device_allocator_->Info());
if (data_transfer == nullptr) {
ORT_THROW("Data transfer is not available for the specified device allocator, it also must not be a CPU allocator");
}
ORT_THROW_IF_ERROR(GetDataTransfer(device_allocator_->Info(), data_transfer));
}
const auto* params = adapter_->parameters();
@ -100,12 +133,12 @@ void LoraAdapter::InitializeParamsValues() {
std::unordered_map<std::string, Param> params_values;
params_values.reserve(params->size());
// Re-work in two separate loops due to compiler issues
if (data_transfer) {
if (device_allocator_) {
for (const auto* param : *params) {
auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param);
OrtValue ort_value_ondevice;
ORT_THROW_IF_ERROR(CreateOrtValueOnDevice(ort_value, device_allocator_,
*data_transfer, ort_value_ondevice));
data_transfer, ort_value_ondevice));
Param lora_param(std::move(ort_value), std::move(ort_value_ondevice));
params_values.emplace(std::move(name), std::move(lora_param));
}
@ -117,6 +150,10 @@ void LoraAdapter::InitializeParamsValues() {
}
}
if (device_allocator_) {
ORT_THROW_IF_ERROR(data_transfer.Sync());
}
params_values_.swap(params_values);
}

View file

@ -226,7 +226,7 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) {
auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get());
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_device_guid, dml_device.Get()));
context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true, true);
context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true);
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, context.Get()));
}

View file

@ -200,13 +200,11 @@ TEST(LoraAdapterTest, Load) {
}
#ifdef USE_CUDA
TEST(LoraAdapterTest, VerifyDeviceCopy) {
TEST(LoraAdapterTest, VerifyCudaDeviceCopy) {
auto cpu_ep = DefaultCpuExecutionProvider();
auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];
auto cuda_ep = DefaultCudaExecutionProvider();
auto cuda_allocator = cuda_ep->CreatePreferredAllocators()[0];
auto gpu_transfer = cuda_ep->GetDataTransfer();
auto cuda_allocator = DefaultCudaExecutionProvider()->CreatePreferredAllocators()[0];
auto cuda_transfer = DefaultCudaExecutionProvider()->GetDataTransfer();
auto test_params = GenerateTestParameters<float>()();
lora::LoraAdapter adapter(std::move(cuda_allocator));
@ -222,9 +220,9 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) {
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());
Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator);
ASSERT_TRUE(gpu_transfer->CanCopy(tensor_device.Location().device,
copy.Location().device));
ASSERT_STATUS_OK(gpu_transfer->CopyTensor(tensor_device, copy));
ASSERT_TRUE(cuda_transfer->CanCopy(tensor_device.Location().device,
copy.Location().device));
ASSERT_STATUS_OK(cuda_transfer->CopyTensor(tensor_device, copy));
auto expected_span = tensor_cpu.DataAsSpan<float>();
auto copy_span = copy.DataAsSpan<float>();
@ -233,5 +231,40 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) {
}
}
#endif
#ifdef USE_DML
TEST(LoraAdapterTest, VerifyDmlDeviceCopy) {
auto cpu_ep = DefaultCpuExecutionProvider();
auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];
auto dml_allocator = DefaultDmlExecutionProvider()->CreatePreferredAllocators()[0];
auto dml_transfer = DefaultDmlExecutionProvider()->GetDataTransfer();
auto test_params = GenerateTestParameters<float>()();
lora::LoraAdapter adapter(std::move(dml_allocator));
adapter.Load(std::move(test_params));
auto [begin, end] = adapter.GetParamIterators();
for (; begin != end; ++begin) {
const auto& [_, param] = *begin;
const auto& tensor_device = param.GetDeviceOrMapped().Get<Tensor>();
ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::DML));
const auto& tensor_cpu = param.GetMapped().Get<Tensor>();
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());
Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator);
ASSERT_TRUE(dml_transfer->CanCopy(tensor_device.Location().device,
copy.Location().device));
ASSERT_STATUS_OK(dml_transfer->CopyTensor(tensor_device, copy));
auto expected_span = tensor_cpu.DataAsSpan<float>();
auto copy_span = copy.DataAsSpan<float>();
ASSERT_EQ(expected_span, copy_span);
}
}
#endif
} // namespace test
} // namespace onnxruntime