diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index db45908a2d..b1714a8220 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -223,28 +223,4 @@ namespace Dml { m_defaultRoundingMode = roundingMode; } - - CPUAllocator::CPUAllocator(OrtMemType memType) - : onnxruntime::IAllocator( - OrtMemoryInfo( - "DML CPU", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0), - 0, - memType - ) - ) - { - } - - void* CPUAllocator::Alloc(size_t size) - { - return onnxruntime::AllocatorDefaultAlloc(size); - } - - void CPUAllocator::Free(void* p) - { - return onnxruntime::AllocatorDefaultFree(p); - } - } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 043853ccae..cb6fc165a9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -239,7 +239,9 @@ namespace Dml std::make_unique(m_d3d12Device.Get())); m_context->SetAllocator(m_allocator); // CPU Allocator used to create buffers for the MemcpyFromHost, Shape and Size operators. - m_cpuInputAllocator = std::make_shared(OrtMemType::OrtMemTypeCPUInput); + OrtMemoryInfo memoryInfo(onnxruntime::CPU, OrtAllocatorType::OrtDeviceAllocator); + memoryInfo.mem_type = ::OrtMemType::OrtMemTypeCPUInput; + m_cpuInputAllocator = std::make_shared(memoryInfo); } return std::vector{m_allocator, m_cpuInputAllocator,}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 29961288a5..c20969250f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -25,7 +25,6 @@ namespace Dml class ReadbackHeap; class ExecutionContext; class BucketizedBufferAllocator; - class CPUAllocator; class ExecutionProvider; class ExecutionProviderImpl : public WRL::Base m_uploadHeap; std::unique_ptr m_readbackHeap; std::shared_ptr m_allocator; - std::shared_ptr m_cpuInputAllocator; + std::shared_ptr m_cpuInputAllocator; std::shared_ptr m_kernelRegistry; std::shared_ptr m_internalRegInfoMap; mutable uint64_t m_partitionKernelPrefixVal = 0;