Revert "Remove DeviceAllocatorRegistry class (#2451)"

This reverts commit 6a88d03621.
This commit is contained in:
Changming Sun 2019-11-25 16:59:34 -08:00
parent 4cac18f666
commit 1ca0e0866e
2 changed files with 26 additions and 0 deletions

View file

@ -29,4 +29,9 @@ AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, int device_id
return AllocatorPtr(std::move(device_allocator));
}
DeviceAllocatorRegistry& DeviceAllocatorRegistry::Instance() {
static DeviceAllocatorRegistry s_instance;
return s_instance;
}
} // namespace onnxruntime

View file

@ -18,4 +18,25 @@ struct DeviceAllocatorRegistrationInfo {
AllocatorPtr CreateAllocator(DeviceAllocatorRegistrationInfo info, int device_id = 0);
class DeviceAllocatorRegistry {
public:
void RegisterDeviceAllocator(std::string&& name, DeviceAllocatorFactory factory, size_t max_mem,
OrtMemType mem_type = OrtMemTypeDefault) {
DeviceAllocatorRegistrationInfo info({mem_type, factory, max_mem});
device_allocator_registrations_.emplace(std::move(name), std::move(info));
}
const std::map<std::string, DeviceAllocatorRegistrationInfo>& AllRegistrations() const {
return device_allocator_registrations_;
}
static DeviceAllocatorRegistry& Instance();
private:
DeviceAllocatorRegistry() = default;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DeviceAllocatorRegistry);
std::map<std::string, DeviceAllocatorRegistrationInfo> device_allocator_registrations_;
};
} // namespace onnxruntime