Revert Implement DML copy for Lora Adapters (#22814)

Revert https://github.com/microsoft/onnxruntime/pull/22396
This commit is contained in:
Xiang Zhang 2024-11-12 17:45:59 -05:00 committed by GitHub
parent 7fa69461fd
commit 69a36eb231
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 78 additions and 157 deletions

View file

@ -182,7 +182,7 @@ namespace Dml
}
else
{
if (!m_closed)
if (!m_context->IsClosed())
{
// Free the underlying allocation once queued work has completed.
#ifdef _GAMING_XBOX

View file

@ -46,11 +46,6 @@ namespace Dml
void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);
void Close()
{
m_closed = true;
}
public: // onnxruntime::IAllocator
void* Alloc(size_t size, AllocatorRoundingMode roundingMode);
void* Alloc(size_t size) final;
@ -88,7 +83,6 @@ namespace Dml
std::vector<Bucket> m_pool;
size_t m_currentAllocationId = 0;
uint64_t m_currentResourceId = 0;
bool m_closed = false;
// Unless specifically requested, allocation sizes are not rounded to enable pooling
// until SetDefaultRoundingMode is called. This should be done at completion of session

View file

@ -55,7 +55,7 @@ namespace Dml
// for example, an allocation from BucketizedBufferAllocator attempts to queue a reference
// to its underlying D3D resource when freed. Furthermore, these references are unnecessary
// since Close() already blocks for scheduled GPU work before clearing m_queuedReferences.
if (!m_clearingQueue)
if (!m_closing)
{
QueuedReference queuedReference = {GetLastFenceValue(), object};
@ -70,15 +70,15 @@ namespace Dml
}
}
void CommandQueue::WaitForSignalAndClearQueue()
void CommandQueue::Close()
{
// Wait for flushed work:
assert(!m_clearingQueue);
m_clearingQueue = true;
assert(!m_closing);
m_closing = true;
GpuEvent event = GetCurrentCompletionEvent();
event.WaitForSignal(m_cpuSyncSpinningEnabled);
m_queuedReferences.clear();
m_clearingQueue = false;
m_closing = false;
}
void CommandQueue::ReleaseCompletedReferences()

View file

@ -44,7 +44,7 @@ namespace Dml
}
#endif
void WaitForSignalAndClearQueue();
void Close();
void ReleaseCompletedReferences();
private:
@ -61,7 +61,7 @@ namespace Dml
ComPtr<ID3D12Fence> m_fence;
uint64_t m_lastFenceValue = 0;
bool m_clearingQueue = false;
bool m_closing = false;
bool m_cpuSyncSpinningEnabled = false;
};

View file

@ -11,10 +11,13 @@ namespace Dml
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue,
bool cpuSyncSpinningEnabled)
bool cpuSyncSpinningEnabled,
bool keepOpen
)
: m_queue(std::make_shared<CommandQueue>(queue, cpuSyncSpinningEnabled))
, m_dmlRecorder(d3d12Device, dmlDevice, m_queue)
, m_cpuSyncSpinningEnabled(cpuSyncSpinningEnabled)
, m_keepOpen(keepOpen)
{
ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf())));
}
@ -33,6 +36,8 @@ namespace Dml
D3D12_RESOURCE_STATES srcState,
uint64_t byteCount)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
std::vector<D3D12_RESOURCE_BARRIER> barriers;
@ -79,6 +84,8 @@ namespace Dml
_Out_ uint64_t* completionValue
)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue);
}
@ -88,6 +95,7 @@ namespace Dml
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding);
@ -99,6 +107,7 @@ namespace Dml
gsl::span<const DML_BINDING_DESC> inputBindings,
gsl::span<const DML_BINDING_DESC> outputBindings)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings);
@ -106,6 +115,7 @@ namespace Dml
void ExecutionContext::AddUAVBarrier()
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.AddUAVBarrier();
@ -113,6 +123,7 @@ namespace Dml
void ExecutionContext::ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ResourceBarrier(barriers);
@ -120,6 +131,7 @@ namespace Dml
void ExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** commandList)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);
// Ensure the descriptor heap is reset to D3D as something external may change it before recording
@ -130,6 +142,8 @@ namespace Dml
void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder)
{
assert(!m_closed);
// If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct
// ordering of operations on the command queue.
if (m_currentRecorder != newRecorder)
@ -146,6 +160,8 @@ namespace Dml
void ExecutionContext::Flush()
{
assert(!m_closed);
if (!m_currentRecorder || !m_currentRecorder->HasUnsubmittedWork())
{
// Nothing to flush
@ -164,21 +180,34 @@ namespace Dml
void ExecutionContext::QueueReference(IUnknown* object)
{
assert(!m_closed);
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
// value is the one to signal completion.
bool waitForUnsubmittedWork = (m_currentRecorder != nullptr);
m_queue->QueueReference(object, waitForUnsubmittedWork);
}
void ExecutionContext::WaitForSignalAndClearQueue()
void ExecutionContext::Close()
{
assert(!m_closed);
// Discard unflushed work and clear queued references. This prevents the circular reference:
// Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel
m_queue->WaitForSignalAndClearQueue();
m_queue->Close();
// Keep the execution context open when requested, e.g. when used through the python API where there's a single context
// and single command queue
if (!m_keepOpen)
{
m_currentRecorder = nullptr;
m_closed = true;
}
}
GpuEvent ExecutionContext::GetCurrentCompletionEvent()
{
assert(!m_closed);
GpuEvent event = m_queue->GetCurrentCompletionEvent();
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
@ -194,11 +223,13 @@ namespace Dml
void ExecutionContext::ReleaseCompletedReferences()
{
assert(!m_closed);
m_queue->ReleaseCompletedReferences();
}
D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const
{
assert(!m_closed);
return m_queue->GetType();
}

View file

@ -23,13 +23,14 @@ namespace Dml
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue,
bool cpuSyncSpinningEnabled);
bool cpuSyncSpinningEnabled,
bool keepOpen);
void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);
// Waits for flushed work, discards unflushed work, and discards associated references to
// prevent circular references.
void WaitForSignalAndClearQueue();
// prevent circular references. Must be the last call on the object before destruction.
void Close();
// Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition
// barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and
@ -86,6 +87,7 @@ namespace Dml
D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const;
bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; }
bool IsClosed() const { return m_closed; }
private:
Microsoft::WRL::ComPtr<ID3D12Device> m_d3dDevice;
@ -101,6 +103,10 @@ namespace Dml
bool m_closed = false;
bool m_cpuSyncSpinningEnabled = false;
// The python API has a global state used for I/O binding where the execution context is shared between session,
// so we don't want to close the context when one of the sessions is destroyed
bool m_keepOpen = false;
};
} // namespace Dml

View file

@ -106,26 +106,7 @@ namespace Dml
// Release the cached command list references before closing the context
m_capturedGraphs.clear();
// Close the allocator before clearing the command queue to stop it from
// appending resources to it in an attempt to keep them alive.
if (m_allocator)
{
m_allocator->Close();
}
// Destroy the allocators. We are closing the execution provider, so from now on the
// only thing it will be used for is doing copies via the DataTransfer, which doesn't
// require allocating any memory.
// TODO: Move the copy functions over to ExecutionContext so that we are able to cleanly
// destroy ExecutionProviderImpl, and instead have the DataTransfer keep the context alive.
m_allocator = nullptr;
m_cpuInputAllocator = nullptr;
// Wait for all pending commands to be done executing and empty the command queue. This will
// Force all kernels and resources in flight to get destroyed and, from this point forward,
// ExecutionProviderImpl will only be used to execute transfer between resources that are
// already existing via the DataTransfer;
m_context->WaitForSignalAndClearQueue();
m_context->Close();
}
void ExecutionProviderImpl::WaitForOutstandingWork()

View file

@ -86,11 +86,11 @@ std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {
// First, check if an I/O binding API that was used before this session or another session has already created a queue
if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, execution_context.GetAddressOf()))) {
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true);
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true, true);
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, execution_context.Get()));
}
} else {
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_);
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false);
}
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_);

View file

@ -4,9 +4,10 @@
#include "core/session/lora_adapters.h"
#include "lora/adapter_format_utils.h"
#include <unordered_map>
#include "core/framework/data_transfer.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/execution_provider.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/session/allocator_adapters.h"
#include "core/session/ort_apis.h"
@ -15,15 +16,6 @@
#include "core/providers/cuda/cuda_provider_factory.h"
#endif
#ifdef USE_DML
#include "core/session/abi_session_options_impl.h"
#include "core/providers/dml/dml_provider_factory_creator.h"
#include "core/providers/dml/dml_provider_factory.h"
#endif
#include <functional>
#include <unordered_map>
namespace onnxruntime {
#ifdef USE_CUDA
@ -58,58 +50,28 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) {
InitializeParamsValues();
}
namespace {
struct DataTransfer {
std::unique_ptr<IExecutionProvider> ep;
static std::unique_ptr<IDataTransfer> GetDataTransfer(const OrtMemoryInfo& mem_info) {
std::unique_ptr<IDataTransfer> data_transfer;
bool is_dml = false;
Status CopyTensor(const Tensor& src, Tensor& dst) const {
return data_transfer->CopyTensor(src, dst);
}
Status Sync() const {
if (is_dml) {
return ep->Sync();
} else {
return Status::OK();
}
}
};
} // namespace
static Status GetDataTransfer(const OrtMemoryInfo& mem_info, [[maybe_unused]] DataTransfer& dt) {
ORT_RETURN_IF(strcmp(mem_info.name, onnxruntime::CPU) == 0, "Expecting on device allocator for LoraAdapter");
if (strcmp(mem_info.name, onnxruntime::CPU) == 0) {
return data_transfer;
}
Status status;
if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) {
#ifdef USE_CUDA
auto* cuda_provider_info = TryGetProviderInfo_CUDA();
if (cuda_provider_info != nullptr) {
dt.data_transfer = cuda_provider_info->CreateGPUDataTransfer();
} else {
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider could not be loaded");
data_transfer = cuda_provider_info->CreateGPUDataTransfer();
}
#else
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider is not enabled in this build");
#endif
} else if (strcmp(mem_info.name, onnxruntime::DML) == 0) {
#ifdef USE_DML
auto ep_factory = onnxruntime::DMLProviderFactoryCreator::Create(ConfigOptions{}, 0, false, false, false);
dt.ep = ep_factory->CreateProvider();
dt.is_dml = true;
dt.data_transfer = dt.ep->GetDataTransfer();
#else
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DML provider is not enabled in this build");
#endif
} else {
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported device allocator");
}
return status;
return data_transfer;
}
static Status CreateOrtValueOnDevice(const OrtValue& ort_value_mapped,
const AllocatorPtr& device_allocator,
const DataTransfer& data_transfer,
const IDataTransfer& data_transfer,
OrtValue& out) {
OrtValue result;
const auto& src = ort_value_mapped.Get<Tensor>();
@ -125,9 +87,12 @@ void LoraAdapter::InitializeParamsValues() {
ORT_THROW("Adapter is not loaded yet.");
}
DataTransfer data_transfer;
std::unique_ptr<IDataTransfer> data_transfer;
if (device_allocator_) {
ORT_THROW_IF_ERROR(GetDataTransfer(device_allocator_->Info(), data_transfer));
data_transfer = GetDataTransfer(device_allocator_->Info());
if (data_transfer == nullptr) {
ORT_THROW("Data transfer is not available for the specified device allocator, it also must not be a CPU allocator");
}
}
const auto* params = adapter_->parameters();
@ -135,12 +100,12 @@ void LoraAdapter::InitializeParamsValues() {
std::unordered_map<std::string, Param> params_values;
params_values.reserve(params->size());
// Re-work in two separate loops due to compiler issues
if (device_allocator_) {
if (data_transfer) {
for (const auto* param : *params) {
auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param);
OrtValue ort_value_ondevice;
ORT_THROW_IF_ERROR(CreateOrtValueOnDevice(ort_value, device_allocator_,
data_transfer, ort_value_ondevice));
*data_transfer, ort_value_ondevice));
Param lora_param(std::move(ort_value), std::move(ort_value_ondevice));
params_values.emplace(std::move(name), std::move(lora_param));
}
@ -152,10 +117,6 @@ void LoraAdapter::InitializeParamsValues() {
}
}
if (device_allocator_) {
ORT_THROW_IF_ERROR(data_transfer.Sync());
}
params_values_.swap(params_values);
}

View file

@ -226,7 +226,7 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) {
auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get());
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_device_guid, dml_device.Get()));
context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true);
context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true, true);
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, context.Get()));
}

View file

@ -200,19 +200,13 @@ TEST(LoraAdapterTest, Load) {
}
#ifdef USE_CUDA
TEST(LoraAdapterTest, VerifyCudaDeviceCopy) {
if (DefaultCudaExecutionProvider() == nullptr) {
GTEST_SKIP() << "Skip This Test Due to this EP is null";
}
#ifdef USE_DML
if (DefaultDmlExecutionProvider() != nullptr) {
GTEST_FAIL() << "It should not run with DML EP";
}
#endif
TEST(LoraAdapterTest, VerifyDeviceCopy) {
auto cpu_ep = DefaultCpuExecutionProvider();
auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];
auto cuda_allocator = DefaultCudaExecutionProvider()->CreatePreferredAllocators()[0];
auto cuda_transfer = DefaultCudaExecutionProvider()->GetDataTransfer();
auto cuda_ep = DefaultCudaExecutionProvider();
auto cuda_allocator = cuda_ep->CreatePreferredAllocators()[0];
auto gpu_transfer = cuda_ep->GetDataTransfer();
auto test_params = GenerateTestParameters<float>()();
lora::LoraAdapter adapter(std::move(cuda_allocator));
@ -228,54 +222,9 @@ TEST(LoraAdapterTest, VerifyCudaDeviceCopy) {
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());
Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator);
ASSERT_TRUE(cuda_transfer->CanCopy(tensor_device.Location().device,
copy.Location().device));
ASSERT_STATUS_OK(cuda_transfer->CopyTensor(tensor_device, copy));
auto expected_span = tensor_cpu.DataAsSpan<float>();
auto copy_span = copy.DataAsSpan<float>();
ASSERT_EQ(expected_span, copy_span);
}
}
#endif
#ifdef USE_DML
TEST(LoraAdapterTest, VerifyDmlDeviceCopy) {
// NO_DML_TEST is set, DML test is skipped
if (DefaultDmlExecutionProvider() == nullptr) {
GTEST_SKIP() << "Skip This Test Due to this EP is null";
}
#ifdef USE_CUDA
if (DefaultCudaExecutionProvider() != nullptr) {
GTEST_FAIL() << "It should not run with CUDA EP";
}
#endif
auto cpu_ep = DefaultCpuExecutionProvider();
auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];
auto dml_allocator = DefaultDmlExecutionProvider()->CreatePreferredAllocators()[0];
auto dml_transfer = DefaultDmlExecutionProvider()->GetDataTransfer();
auto test_params = GenerateTestParameters<float>()();
lora::LoraAdapter adapter(std::move(dml_allocator));
adapter.Load(std::move(test_params));
auto [begin, end] = adapter.GetParamIterators();
for (; begin != end; ++begin) {
const auto& [_, param] = *begin;
const auto& tensor_device = param.GetDeviceOrMapped().Get<Tensor>();
ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::DML));
const auto& tensor_cpu = param.GetMapped().Get<Tensor>();
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());
Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator);
ASSERT_TRUE(dml_transfer->CanCopy(tensor_device.Location().device,
ASSERT_TRUE(gpu_transfer->CanCopy(tensor_device.Location().device,
copy.Location().device));
ASSERT_STATUS_OK(dml_transfer->CopyTensor(tensor_device, copy));
ASSERT_STATUS_OK(gpu_transfer->CopyTensor(tensor_device, copy));
auto expected_span = tensor_cpu.DataAsSpan<float>();
auto copy_span = copy.DataAsSpan<float>();
@ -284,6 +233,5 @@ TEST(LoraAdapterTest, VerifyDmlDeviceCopy) {
}
}
#endif
} // namespace test
} // namespace onnxruntime