Fix Memory Leak from DlpackToOrtValue (#8029)

This commit is contained in:
Vincent Wang 2021-06-11 15:48:13 +08:00 committed by GitHub
parent d02de9c1bc
commit 2f2aaf2cf6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -223,13 +223,17 @@ OrtValue DlpackToOrtValue(DLManagedTensor* dlpack, bool is_bool_tensor) {
ORT_ENFORCE(IsContiguousTensor(dlpack->dl_tensor), "ORT only supports contiguous tensor for now.");
OrtDevice device = GetOrtDevice(dlpack->dl_tensor.ctx);
MLDataType data_type = GetOrtValueDataType(dlpack->dl_tensor.dtype, is_bool_tensor);
std::function<void(void*)> deleter = [dlpack](void*) { dlpack->deleter((dlpack)); };
OrtMemoryInfo info(GetOrtDeviceName(device), OrtDeviceAllocator, device, device.Id());
std::unique_ptr<Tensor> p_tensor = std::make_unique<Tensor>(
data_type, TensorShape(dlpack->dl_tensor.shape, static_cast<size_t>(dlpack->dl_tensor.ndim)),
dlpack->dl_tensor.data, info);
OrtValue ort_value;
std::function<void(void*)> deleter = [dlpack](void* p) {
dlpack->deleter(dlpack);
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc()(p);
};
ort_value.Init(p_tensor.release(), DataTypeImpl::GetType<Tensor>(), deleter);
return ort_value;
}