diff --git a/onnxruntime/core/framework/allocatormgr.cc b/onnxruntime/core/framework/allocatormgr.cc index a38d89a9e2..f4258d5a6a 100644 --- a/onnxruntime/core/framework/allocatormgr.cc +++ b/onnxruntime/core/framework/allocatormgr.cc @@ -29,4 +29,9 @@ AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, int device_id return AllocatorPtr(std::move(device_allocator)); } +DeviceAllocatorRegistry& DeviceAllocatorRegistry::Instance() { + static DeviceAllocatorRegistry s_instance; + return s_instance; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/allocatormgr.h b/onnxruntime/core/framework/allocatormgr.h index aa346fc52f..3985fd4b66 100644 --- a/onnxruntime/core/framework/allocatormgr.h +++ b/onnxruntime/core/framework/allocatormgr.h @@ -18,4 +18,25 @@ struct DeviceAllocatorRegistrationInfo { AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, int device_id = 0); +class DeviceAllocatorRegistry { + public: + void RegisterDeviceAllocator(std::string&& name, DeviceAllocatorFactory factory, size_t max_mem, + OrtMemType mem_type = OrtMemTypeDefault) { + DeviceAllocatorRegistrationInfo info({mem_type, factory, max_mem}); + device_allocator_registrations_.emplace(std::move(name), std::move(info)); + } + + const std::map& AllRegistrations() const { + return device_allocator_registrations_; + } + + static DeviceAllocatorRegistry& Instance(); + + private: + DeviceAllocatorRegistry() = default; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeviceAllocatorRegistry); + + std::map device_allocator_registrations_; +}; + } // namespace onnxruntime