Fix to_dlpack Failure on PyTorch-1.10 (#9151)

* workaround to_dlpack fail in new pt version

* add torch code link
This commit is contained in:
Vincent Wang 2021-09-24 09:48:07 +08:00 committed by GitHub
parent 0888c6cc59
commit 39dc6ea8a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 18 deletions

View file

@ -464,7 +464,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "e1e11e0d555c08bec08a6c7773aa777dfcaae9da",
"commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c",
"repositoryUrl": "https://github.com/dmlc/dlpack.git"
},
"comments": "dlpack"

View file

@ -184,7 +184,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "e1e11e0d555c08bec08a6c7773aa777dfcaae9da",
"commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c",
"repositoryUrl": "https://github.com/dmlc/dlpack.git"
},
"comments": "git submodule at cmake/external/dlpack"

@ -1 +1 @@
Subproject commit e1e11e0d555c08bec08a6c7773aa777dfcaae9da
Subproject commit 277508879878e0a5b5b43599b1bea11f66eb3c6c

View file

@ -75,28 +75,28 @@ DLDataType GetDlpackDataType(const OrtValue& ort_value) {
return dtype;
}
DLContext GetDlpackContext(const OrtValue& ort_value, const int64_t& device_id) {
DLDevice GetDlpackDevice(const OrtValue& ort_value, const int64_t& device_id) {
ORT_ENFORCE(ort_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
DLContext ctx;
ctx.device_id = static_cast<int>(device_id);
DLDevice device;
device.device_id = static_cast<int>(device_id);
const Tensor& tensor = ort_value.Get<Tensor>();
const auto& location = tensor.Location();
switch (location.device.Type()) {
case OrtDevice::CPU:
ctx.device_type = DLDeviceType::kDLCPU;
device.device_type = DLDeviceType::kDLCPU;
break;
case OrtDevice::GPU:
#ifdef USE_ROCM
ctx.device_type = DLDeviceType::kDLROCM;
device.device_type = DLDeviceType::kDLROCM;
#else
ctx.device_type = DLDeviceType::kDLGPU;
device.device_type = DLDeviceType::kDLCUDA;
#endif
break;
default:
ORT_THROW("Cannot pack tensors on this device.");
}
return ctx;
return device;
}
struct OrtDLManagedTensor {
@ -106,13 +106,13 @@ struct OrtDLManagedTensor {
void DlpackDeleter(DLManagedTensor* arg) { delete static_cast<OrtDLManagedTensor*>(arg->manager_ctx); }
OrtDevice GetOrtDevice(const DLContext& ctx) {
switch (ctx.device_type) {
OrtDevice GetOrtDevice(const DLDevice& device) {
switch (device.device_type) {
case DLDeviceType::kDLCPU:
return OrtDevice();
case DLDeviceType::kDLGPU:
case DLDeviceType::kDLCUDA:
case DLDeviceType::kDLROCM:
return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(ctx.device_id));
return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(device.device_id));
default:
ORT_THROW("Unsupported device type");
}
@ -208,7 +208,7 @@ DLManagedTensor* OrtValueToDlpack(OrtValue& ort_value) {
ort_dlmanaged_tensor->tensor.manager_ctx = ort_dlmanaged_tensor;
ort_dlmanaged_tensor->tensor.deleter = &DlpackDeleter;
ort_dlmanaged_tensor->tensor.dl_tensor.data = (tensor.MutableDataRaw());
ort_dlmanaged_tensor->tensor.dl_tensor.ctx = GetDlpackContext(ort_value, tensor.Location().device.Id());
ort_dlmanaged_tensor->tensor.dl_tensor.device = GetDlpackDevice(ort_value, tensor.Location().device.Id());
ort_dlmanaged_tensor->tensor.dl_tensor.ndim = static_cast<int>(tensor.Shape().NumDimensions());
ort_dlmanaged_tensor->tensor.dl_tensor.dtype = GetDlpackDataType(ort_value);
ort_dlmanaged_tensor->tensor.dl_tensor.shape =
@ -221,7 +221,7 @@ DLManagedTensor* OrtValueToDlpack(OrtValue& ort_value) {
OrtValue DlpackToOrtValue(DLManagedTensor* dlpack, bool is_bool_tensor) {
// ORT only supports contiguous tensor for now.
ORT_ENFORCE(IsContiguousTensor(dlpack->dl_tensor), "ORT only supports contiguous tensor for now.");
OrtDevice device = GetOrtDevice(dlpack->dl_tensor.ctx);
OrtDevice device = GetOrtDevice(dlpack->dl_tensor.device);
MLDataType data_type = GetOrtValueDataType(dlpack->dl_tensor.dtype, is_bool_tensor);
OrtMemoryInfo info(GetOrtDeviceName(device), OrtDeviceAllocator, device, device.Id());
std::unique_ptr<Tensor> p_tensor = std::make_unique<Tensor>(

View file

@ -15,9 +15,18 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
from typing import List
import types
import warnings
from distutils.version import LooseVersion
def _ortvalue_from_torch_tensor(torch_tensor):
return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), torch_tensor.dtype == torch.bool)
# TODO: Current DLPack doesn't support bool and PyTorch disables converting bool tensor to DLPack in recent commit.
# https://github.com/pytorch/pytorch/blob/7e7be526c9d9179f35084e9cca5b5c5ad5172100/aten/src/ATen/DLConvertor.cpp#L41
# We need to convert bool tensor to uint8 tensor to workaround this.
# DLPack is discussing how to support bool type, we can remove this workaround once both DLPack
# and PyTorch support bool type.
is_bool_tensor = torch_tensor.dtype == torch.bool
if is_bool_tensor and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'):
torch_tensor = torch_tensor.to(torch.uint8)
return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), is_bool_tensor)
def _torch_tensor_from_dl_pack(dlpack, ortvalue, device):
@ -36,6 +45,13 @@ def _torch_tensor_to_dlpack(tensor):
if tensor.device.type == 'ort':
return C.ort_to_dlpack(tensor)
else:
# TODO: Current DLPack doesn't support bool and PyTorch disables converting bool tensor to DLPack in recent commit.
# https://github.com/pytorch/pytorch/blob/7e7be526c9d9179f35084e9cca5b5c5ad5172100/aten/src/ATen/DLConvertor.cpp#L41
# We need to convert bool tensor to uint8 tensor to workaround this.
# DLPack is discussing how to support bool type, we can remove this workaround once both DLPack
# and PyTorch support bool type.
if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'):
tensor = tensor.to(torch.uint8)
return to_dlpack(tensor)

View file

@ -19,6 +19,7 @@ from collections import namedtuple
from inspect import signature
import tempfile
import os
from distutils.version import LooseVersion
from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel, _fallback, _graph_execution_manager
import _test_helpers
@ -3827,7 +3828,8 @@ def test_ortmodule_ortmodule_method_attribute_copy():
assert type(out1.grad_fn).__name__ == '_ORTModuleFunctionBackward'
assert type(out2.grad_fn).__name__ == '_ORTModuleFunctionBackward'
assert type(out3.grad_fn).__name__ == 'AddmmBackward'
assert type(out3.grad_fn).__name__ == ('AddmmBackward0' if LooseVersion(
    torch.__version__) >= LooseVersion('1.10.0') else 'AddmmBackward')
@pytest.mark.parametrize("policy_str, policy",[
('SKIP_CHECK_DISABLED', _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED),