Replace "DML CPU" Allocator with onnxruntime::CpuAllocator (#21818)

### Description
Replace "DML CPU" Allocator with onnxruntime::CpuAllocator

### Motivation and Context
This allocator is being ignored by ORTExtensions and causes CPU memory
to be treated as non-CPU memory and crash in SentencepieceTokenizer.

In general it seems like this allocator is not used and can be handled
just fine by the default allocator.

---------

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
This commit is contained in:
Sheil Kumar 2024-08-23 10:35:57 -07:00 committed by GitHub
parent 5726318ec0
commit 44dcc3aafd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 4 additions and 27 deletions

View file

@ -223,28 +223,4 @@ namespace Dml
{
m_defaultRoundingMode = roundingMode;
}
CPUAllocator::CPUAllocator(OrtMemType memType)
: onnxruntime::IAllocator(
OrtMemoryInfo(
"DML CPU",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0),
0,
memType
)
)
{
}
void* CPUAllocator::Alloc(size_t size)
{
return onnxruntime::AllocatorDefaultAlloc(size);
}
void CPUAllocator::Free(void* p)
{
return onnxruntime::AllocatorDefaultFree(p);
}
} // namespace Dml

View file

@ -239,7 +239,9 @@ namespace Dml
std::make_unique<DmlCommittedResourceAllocator>(m_d3d12Device.Get()));
m_context->SetAllocator(m_allocator);
// CPU Allocator used to create buffers for the MemcpyFromHost, Shape and Size operators.
m_cpuInputAllocator = std::make_shared<CPUAllocator>(OrtMemType::OrtMemTypeCPUInput);
OrtMemoryInfo memoryInfo(onnxruntime::CPU, OrtAllocatorType::OrtDeviceAllocator);
memoryInfo.mem_type = ::OrtMemType::OrtMemTypeCPUInput;
m_cpuInputAllocator = std::make_shared<onnxruntime::CPUAllocator>(memoryInfo);
}
return std::vector<onnxruntime::AllocatorPtr>{m_allocator, m_cpuInputAllocator,};

View file

@ -25,7 +25,6 @@ namespace Dml
class ReadbackHeap;
class ExecutionContext;
class BucketizedBufferAllocator;
class CPUAllocator;
class ExecutionProvider;
class ExecutionProviderImpl : public WRL::Base<Dml::IExecutionProvider,
@ -213,7 +212,7 @@ namespace Dml
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
std::unique_ptr<ReadbackHeap> m_readbackHeap;
std::shared_ptr<BucketizedBufferAllocator> m_allocator;
std::shared_ptr<CPUAllocator> m_cpuInputAllocator;
std::shared_ptr<onnxruntime::IAllocator> m_cpuInputAllocator;
std::shared_ptr<onnxruntime::KernelRegistry> m_kernelRegistry;
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap> m_internalRegInfoMap;
mutable uint64_t m_partitionKernelPrefixVal = 0;