mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
Implement DML copy for Lora Adapters (#22396)
### Description Request and create the DML EP and its data transfer, and use them to copy LoRA adapter parameters onto the device. The PR also includes changes that fix issues in the DML provider. ### Motivation and Context This enables LoRA users to run adapters with DML, which is important for GenAI. Co-authored-by: @PatriceVignola --------- Co-authored-by: Patrice Vignola <vignola.patrice@gmail.com>
This commit is contained in:
parent
35adba21c7
commit
87e8a5dfa8
11 changed files with 137 additions and 79 deletions
|
|
@ -186,7 +186,7 @@ namespace Dml
|
|||
}
|
||||
else
|
||||
{
|
||||
if (!m_context->IsClosed())
|
||||
if (!m_closed)
|
||||
{
|
||||
// Free the underlying allocation once queued work has completed.
|
||||
#ifdef _GAMING_XBOX
|
||||
|
|
|
|||
|
|
@ -46,6 +46,11 @@ namespace Dml
|
|||
|
||||
void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);
|
||||
|
||||
// Flags the allocator as closed by setting m_closed to true; nothing else is touched here.
// NOTE(review): presumably the free path's `if (!m_closed)` check consults this flag so that
// freed allocations are no longer queued for deferred release -- confirm against the .cpp.
void Close()
|
||||
{
|
||||
m_closed = true;
|
||||
}
|
||||
|
||||
public: // onnxruntime::IAllocator
|
||||
void* Alloc(size_t size, AllocatorRoundingMode roundingMode);
|
||||
void* Alloc(size_t size) final;
|
||||
|
|
@ -83,6 +88,7 @@ namespace Dml
|
|||
std::vector<Bucket> m_pool;
|
||||
size_t m_currentAllocationId = 0;
|
||||
uint64_t m_currentResourceId = 0;
|
||||
bool m_closed = false;
|
||||
|
||||
// Unless specifically requested, allocation sizes are not rounded to enable pooling
|
||||
// until SetDefaultRoundingMode is called. This should be done at completion of session
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ namespace Dml
|
|||
// for example, an allocation from BucketizedBufferAllocator attempts to queue a reference
|
||||
// to its underlying D3D resource when freed. Furthermore, these references are unnecessary
|
||||
// since Close() already blocks for scheduled GPU work before clearing m_queuedReferences.
|
||||
if (!m_closing)
|
||||
if (!m_clearingQueue)
|
||||
{
|
||||
QueuedReference queuedReference = {GetLastFenceValue(), object};
|
||||
|
||||
|
|
@ -70,15 +70,15 @@ namespace Dml
|
|||
}
|
||||
}
|
||||
|
||||
void CommandQueue::Close()
|
||||
void CommandQueue::WaitForSignalAndClearQueue()
|
||||
{
|
||||
// Wait for flushed work:
|
||||
assert(!m_closing);
|
||||
m_closing = true;
|
||||
assert(!m_clearingQueue);
|
||||
m_clearingQueue = true;
|
||||
GpuEvent event = GetCurrentCompletionEvent();
|
||||
event.WaitForSignal(m_cpuSyncSpinningEnabled);
|
||||
m_queuedReferences.clear();
|
||||
m_closing = false;
|
||||
m_clearingQueue = false;
|
||||
}
|
||||
|
||||
void CommandQueue::ReleaseCompletedReferences()
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ namespace Dml
|
|||
}
|
||||
#endif
|
||||
|
||||
void Close();
|
||||
void WaitForSignalAndClearQueue();
|
||||
void ReleaseCompletedReferences();
|
||||
|
||||
private:
|
||||
|
|
@ -61,7 +61,7 @@ namespace Dml
|
|||
|
||||
ComPtr<ID3D12Fence> m_fence;
|
||||
uint64_t m_lastFenceValue = 0;
|
||||
bool m_closing = false;
|
||||
bool m_clearingQueue = false;
|
||||
bool m_cpuSyncSpinningEnabled = false;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,10 @@ namespace Dml
|
|||
ID3D12Device* d3d12Device,
|
||||
IDMLDevice* dmlDevice,
|
||||
ID3D12CommandQueue* queue,
|
||||
bool cpuSyncSpinningEnabled,
|
||||
bool keepOpen
|
||||
)
|
||||
bool cpuSyncSpinningEnabled)
|
||||
: m_queue(std::make_shared<CommandQueue>(queue, cpuSyncSpinningEnabled))
|
||||
, m_dmlRecorder(d3d12Device, dmlDevice, m_queue)
|
||||
, m_cpuSyncSpinningEnabled(cpuSyncSpinningEnabled)
|
||||
, m_keepOpen(keepOpen)
|
||||
{
|
||||
ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf())));
|
||||
}
|
||||
|
|
@ -36,8 +33,6 @@ namespace Dml
|
|||
D3D12_RESOURCE_STATES srcState,
|
||||
uint64_t byteCount)
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
|
||||
std::vector<D3D12_RESOURCE_BARRIER> barriers;
|
||||
|
|
@ -84,8 +79,6 @@ namespace Dml
|
|||
_Out_ uint64_t* completionValue
|
||||
)
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue);
|
||||
}
|
||||
|
|
@ -95,7 +88,6 @@ namespace Dml
|
|||
const DML_BINDING_DESC& persistentResourceBinding,
|
||||
const DML_BINDING_DESC& inputArrayBinding)
|
||||
{
|
||||
assert(!m_closed);
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
|
||||
m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding);
|
||||
|
|
@ -107,7 +99,6 @@ namespace Dml
|
|||
gsl::span<const DML_BINDING_DESC> inputBindings,
|
||||
gsl::span<const DML_BINDING_DESC> outputBindings)
|
||||
{
|
||||
assert(!m_closed);
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
|
||||
m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings);
|
||||
|
|
@ -115,7 +106,6 @@ namespace Dml
|
|||
|
||||
void ExecutionContext::AddUAVBarrier()
|
||||
{
|
||||
assert(!m_closed);
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
|
||||
m_dmlRecorder.AddUAVBarrier();
|
||||
|
|
@ -123,7 +113,6 @@ namespace Dml
|
|||
|
||||
void ExecutionContext::ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers)
|
||||
{
|
||||
assert(!m_closed);
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
|
||||
m_dmlRecorder.ResourceBarrier(barriers);
|
||||
|
|
@ -131,7 +120,6 @@ namespace Dml
|
|||
|
||||
void ExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** commandList)
|
||||
{
|
||||
assert(!m_closed);
|
||||
SetCommandRecorder(&m_dmlRecorder);
|
||||
|
||||
// Ensure the descriptor heap is reset to D3D as something external may change it before recording
|
||||
|
|
@ -142,8 +130,6 @@ namespace Dml
|
|||
|
||||
void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder)
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
||||
// If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct
|
||||
// ordering of operations on the command queue.
|
||||
if (m_currentRecorder != newRecorder)
|
||||
|
|
@ -160,8 +146,6 @@ namespace Dml
|
|||
|
||||
void ExecutionContext::Flush()
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
||||
if (!m_currentRecorder || !m_currentRecorder->HasUnsubmittedWork())
|
||||
{
|
||||
// Nothing to flush
|
||||
|
|
@ -180,34 +164,21 @@ namespace Dml
|
|||
|
||||
void ExecutionContext::QueueReference(IUnknown* object)
|
||||
{
|
||||
assert(!m_closed);
|
||||
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
|
||||
// value is the one to signal completion.
|
||||
bool waitForUnsubmittedWork = (m_currentRecorder != nullptr);
|
||||
m_queue->QueueReference(object, waitForUnsubmittedWork);
|
||||
}
|
||||
|
||||
void ExecutionContext::Close()
|
||||
void ExecutionContext::WaitForSignalAndClearQueue()
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
||||
// Discard unflushed work and clear queued references. This prevents the circular reference:
|
||||
// Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel
|
||||
m_queue->Close();
|
||||
|
||||
// Keep the execution context open when requested, e.g. when used through the python API where there's a single context
|
||||
// and single command queue
|
||||
if (!m_keepOpen)
|
||||
{
|
||||
m_currentRecorder = nullptr;
|
||||
m_closed = true;
|
||||
}
|
||||
m_queue->WaitForSignalAndClearQueue();
|
||||
}
|
||||
|
||||
GpuEvent ExecutionContext::GetCurrentCompletionEvent()
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
||||
GpuEvent event = m_queue->GetCurrentCompletionEvent();
|
||||
|
||||
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
|
||||
|
|
@ -223,13 +194,11 @@ namespace Dml
|
|||
|
||||
void ExecutionContext::ReleaseCompletedReferences()
|
||||
{
|
||||
assert(!m_closed);
|
||||
m_queue->ReleaseCompletedReferences();
|
||||
}
|
||||
|
||||
D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const
|
||||
{
|
||||
assert(!m_closed);
|
||||
return m_queue->GetType();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,14 +23,13 @@ namespace Dml
|
|||
ID3D12Device* d3d12Device,
|
||||
IDMLDevice* dmlDevice,
|
||||
ID3D12CommandQueue* queue,
|
||||
bool cpuSyncSpinningEnabled,
|
||||
bool keepOpen);
|
||||
bool cpuSyncSpinningEnabled);
|
||||
|
||||
void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);
|
||||
|
||||
// Waits for flushed work, discards unflushed work, and discards associated references to
|
||||
// prevent circular references. Must be the last call on the object before destruction.
|
||||
void Close();
|
||||
// prevent circular references.
|
||||
void WaitForSignalAndClearQueue();
|
||||
|
||||
// Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition
|
||||
// barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and
|
||||
|
|
@ -87,7 +86,6 @@ namespace Dml
|
|||
|
||||
D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const;
|
||||
bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; }
|
||||
bool IsClosed() const { return m_closed; }
|
||||
|
||||
private:
|
||||
Microsoft::WRL::ComPtr<ID3D12Device> m_d3dDevice;
|
||||
|
|
@ -103,10 +101,6 @@ namespace Dml
|
|||
|
||||
bool m_closed = false;
|
||||
bool m_cpuSyncSpinningEnabled = false;
|
||||
|
||||
// The python API has a global state used for I/O binding where the execution context is shared between session,
|
||||
// so we don't want to close the context when one of the sessions is destroyed
|
||||
bool m_keepOpen = false;
|
||||
};
|
||||
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -106,7 +106,26 @@ namespace Dml
|
|||
// Release the cached command list references before closing the context
|
||||
m_capturedGraphs.clear();
|
||||
|
||||
m_context->Close();
|
||||
// Close the allocator before clearing the command queue to stop it from
|
||||
// appending resources to it in an attempt to keep them alive.
|
||||
if (m_allocator)
|
||||
{
|
||||
m_allocator->Close();
|
||||
}
|
||||
|
||||
// Destroy the allocators. We are closing the execution provider, so from now on the
|
||||
// only thing it will be used for is doing copies via the DataTransfer, which doesn't
|
||||
// require allocating any memory.
|
||||
// TODO: Move the copy functions over to ExecutionContext so that we are able to cleanly
|
||||
// destroy ExecutionProviderImpl, and instead have the DataTransfer keep the context alive.
|
||||
m_allocator = nullptr;
|
||||
m_cpuInputAllocator = nullptr;
|
||||
|
||||
// Wait for all pending commands to be done executing and empty the command queue. This will
|
||||
// force all kernels and resources in flight to get destroyed and, from this point forward,
|
||||
// ExecutionProviderImpl will only be used to execute transfer between resources that are
|
||||
// already existing via the DataTransfer.
|
||||
m_context->WaitForSignalAndClearQueue();
|
||||
}
|
||||
|
||||
void ExecutionProviderImpl::WaitForOutstandingWork()
|
||||
|
|
|
|||
|
|
@ -86,11 +86,11 @@ std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {
|
|||
|
||||
// First, check if an I/O binding API that was used before this session or another session has already created a queue
|
||||
if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, execution_context.GetAddressOf()))) {
|
||||
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true, true);
|
||||
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true);
|
||||
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, execution_context.Get()));
|
||||
}
|
||||
} else {
|
||||
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false);
|
||||
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_);
|
||||
}
|
||||
|
||||
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_);
|
||||
|
|
|
|||
|
|
@ -4,10 +4,9 @@
|
|||
#include "core/session/lora_adapters.h"
|
||||
#include "lora/adapter_format_utils.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "core/framework/data_transfer.h"
|
||||
#include "core/framework/error_code_helper.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/session/allocator_adapters.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
|
|
@ -16,6 +15,15 @@
|
|||
#include "core/providers/cuda/cuda_provider_factory.h"
|
||||
#endif
|
||||
|
||||
#ifdef USE_DML
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "core/providers/dml/dml_provider_factory_creator.h"
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
#endif
|
||||
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
#ifdef USE_CUDA
|
||||
|
|
@ -50,28 +58,56 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) {
|
|||
InitializeParamsValues();
|
||||
}
|
||||
|
||||
static std::unique_ptr<IDataTransfer> GetDataTransfer(const OrtMemoryInfo& mem_info) {
|
||||
namespace {
|
||||
struct DataTransfer {
|
||||
std::unique_ptr<IExecutionProvider> ep;
|
||||
std::unique_ptr<IDataTransfer> data_transfer;
|
||||
|
||||
if (strcmp(mem_info.name, onnxruntime::CPU) == 0) {
|
||||
return data_transfer;
|
||||
Status CopyTensor(const Tensor& src, Tensor& dst) const {
|
||||
return data_transfer->CopyTensor(src, dst);
|
||||
}
|
||||
Status Sync() const {
|
||||
#if USE_DML
|
||||
return ep->Sync();
|
||||
#else
|
||||
return Status::OK();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static Status GetDataTransfer(const OrtMemoryInfo& mem_info, [[maybe_unused]] DataTransfer& dt) {
|
||||
ORT_RETURN_IF(strcmp(mem_info.name, onnxruntime::CPU) == 0, "Expecting on device allocator for LoraAdapter");
|
||||
|
||||
Status status;
|
||||
if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) {
|
||||
#ifdef USE_CUDA
|
||||
auto* cuda_provider_info = TryGetProviderInfo_CUDA();
|
||||
if (cuda_provider_info != nullptr) {
|
||||
data_transfer = cuda_provider_info->CreateGPUDataTransfer();
|
||||
dt.data_transfer = cuda_provider_info->CreateGPUDataTransfer();
|
||||
} else {
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider could not be loaded");
|
||||
}
|
||||
#else
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA provider is not enabled in this build");
|
||||
#endif
|
||||
} else if (strcmp(mem_info.name, onnxruntime::DML) == 0) {
|
||||
#ifdef USE_DML
|
||||
auto ep_factory = onnxruntime::DMLProviderFactoryCreator::Create(ConfigOptions{}, 0, false, false, false);
|
||||
dt.ep = ep_factory->CreateProvider();
|
||||
dt.data_transfer = dt.ep->GetDataTransfer();
|
||||
#else
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "DML provider is not enabled in this build");
|
||||
#endif
|
||||
} else {
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported device allocator");
|
||||
}
|
||||
|
||||
return data_transfer;
|
||||
return status;
|
||||
}
|
||||
|
||||
static Status CreateOrtValueOnDevice(const OrtValue& ort_value_mapped,
|
||||
const AllocatorPtr& device_allocator,
|
||||
const IDataTransfer& data_transfer,
|
||||
const DataTransfer& data_transfer,
|
||||
OrtValue& out) {
|
||||
OrtValue result;
|
||||
const auto& src = ort_value_mapped.Get<Tensor>();
|
||||
|
|
@ -87,12 +123,9 @@ void LoraAdapter::InitializeParamsValues() {
|
|||
ORT_THROW("Adapter is not loaded yet.");
|
||||
}
|
||||
|
||||
std::unique_ptr<IDataTransfer> data_transfer;
|
||||
DataTransfer data_transfer;
|
||||
if (device_allocator_) {
|
||||
data_transfer = GetDataTransfer(device_allocator_->Info());
|
||||
if (data_transfer == nullptr) {
|
||||
ORT_THROW("Data transfer is not available for the specified device allocator, it also must not be a CPU allocator");
|
||||
}
|
||||
ORT_THROW_IF_ERROR(GetDataTransfer(device_allocator_->Info(), data_transfer));
|
||||
}
|
||||
|
||||
const auto* params = adapter_->parameters();
|
||||
|
|
@ -100,12 +133,12 @@ void LoraAdapter::InitializeParamsValues() {
|
|||
std::unordered_map<std::string, Param> params_values;
|
||||
params_values.reserve(params->size());
|
||||
// Re-work in two separate loops due to compiler issues
|
||||
if (data_transfer) {
|
||||
if (device_allocator_) {
|
||||
for (const auto* param : *params) {
|
||||
auto [name, ort_value] = adapters::utils::CreateOrtValueOverLoraParameter(*param);
|
||||
OrtValue ort_value_ondevice;
|
||||
ORT_THROW_IF_ERROR(CreateOrtValueOnDevice(ort_value, device_allocator_,
|
||||
*data_transfer, ort_value_ondevice));
|
||||
data_transfer, ort_value_ondevice));
|
||||
Param lora_param(std::move(ort_value), std::move(ort_value_ondevice));
|
||||
params_values.emplace(std::move(name), std::move(lora_param));
|
||||
}
|
||||
|
|
@ -117,6 +150,10 @@ void LoraAdapter::InitializeParamsValues() {
|
|||
}
|
||||
}
|
||||
|
||||
if (device_allocator_) {
|
||||
ORT_THROW_IF_ERROR(data_transfer.Sync());
|
||||
}
|
||||
|
||||
params_values_.swap(params_values);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) {
|
|||
auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get());
|
||||
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_device_guid, dml_device.Get()));
|
||||
|
||||
context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true, true);
|
||||
context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true);
|
||||
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, context.Get()));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -200,13 +200,11 @@ TEST(LoraAdapterTest, Load) {
|
|||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
TEST(LoraAdapterTest, VerifyDeviceCopy) {
|
||||
TEST(LoraAdapterTest, VerifyCudaDeviceCopy) {
|
||||
auto cpu_ep = DefaultCpuExecutionProvider();
|
||||
auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];
|
||||
auto cuda_ep = DefaultCudaExecutionProvider();
|
||||
auto cuda_allocator = cuda_ep->CreatePreferredAllocators()[0];
|
||||
|
||||
auto gpu_transfer = cuda_ep->GetDataTransfer();
|
||||
auto cuda_allocator = DefaultCudaExecutionProvider()->CreatePreferredAllocators()[0];
|
||||
auto cuda_transfer = DefaultCudaExecutionProvider()->GetDataTransfer();
|
||||
|
||||
auto test_params = GenerateTestParameters<float>()();
|
||||
lora::LoraAdapter adapter(std::move(cuda_allocator));
|
||||
|
|
@ -222,9 +220,9 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) {
|
|||
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());
|
||||
|
||||
Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator);
|
||||
ASSERT_TRUE(gpu_transfer->CanCopy(tensor_device.Location().device,
|
||||
copy.Location().device));
|
||||
ASSERT_STATUS_OK(gpu_transfer->CopyTensor(tensor_device, copy));
|
||||
ASSERT_TRUE(cuda_transfer->CanCopy(tensor_device.Location().device,
|
||||
copy.Location().device));
|
||||
ASSERT_STATUS_OK(cuda_transfer->CopyTensor(tensor_device, copy));
|
||||
|
||||
auto expected_span = tensor_cpu.DataAsSpan<float>();
|
||||
auto copy_span = copy.DataAsSpan<float>();
|
||||
|
|
@ -233,5 +231,40 @@ TEST(LoraAdapterTest, VerifyDeviceCopy) {
|
|||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_DML
|
||||
// Verifies that a LoraAdapter constructed with a DML allocator places its parameters on the
// DML device. Generated float test parameters are loaded into the adapter; then, for every
// parameter, the test checks that the device tensor reports the DML location, that it has the
// same element count as the mapped tensor, and that copying it into a CPU tensor through the
// DML data transfer yields the same values as the mapped tensor.
TEST(LoraAdapterTest, VerifyDmlDeviceCopy) {
|
||||
auto cpu_ep = DefaultCpuExecutionProvider();
|
||||
auto cpu_allocator = cpu_ep->CreatePreferredAllocators()[0];
|
||||
|
||||
auto dml_allocator = DefaultDmlExecutionProvider()->CreatePreferredAllocators()[0];
|
||||
auto dml_transfer = DefaultDmlExecutionProvider()->GetDataTransfer();
|
||||
|
||||
auto test_params = GenerateTestParameters<float>()();
|
||||
lora::LoraAdapter adapter(std::move(dml_allocator));
|
||||
adapter.Load(std::move(test_params));
|
||||
|
||||
auto [begin, end] = adapter.GetParamIterators();
|
||||
for (; begin != end; ++begin) {
|
||||
const auto& [_, param] = *begin;
|
||||
// The device-side value must live in DML memory.
const auto& tensor_device = param.GetDeviceOrMapped().Get<Tensor>();
|
||||
ASSERT_EQ(0, strcmp(tensor_device.Location().name, onnxruntime::DML));
|
||||
|
||||
const auto& tensor_cpu = param.GetMapped().Get<Tensor>();
|
||||
ASSERT_EQ(tensor_cpu.Shape().Size(), tensor_device.Shape().Size());
|
||||
|
||||
// Copy the device tensor back into a CPU-allocated tensor so the contents can be compared.
Tensor copy(tensor_cpu.DataType(), tensor_cpu.Shape(), cpu_allocator);
|
||||
ASSERT_TRUE(dml_transfer->CanCopy(tensor_device.Location().device,
|
||||
copy.Location().device));
|
||||
ASSERT_STATUS_OK(dml_transfer->CopyTensor(tensor_device, copy));
|
||||
|
||||
auto expected_span = tensor_cpu.DataAsSpan<float>();
|
||||
auto copy_span = copy.DataAsSpan<float>();
|
||||
|
||||
ASSERT_EQ(expected_span, copy_span);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue