mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-15 20:50:42 +00:00
Fix to_dlpack Failure on PyTorch-1.10 (#9151)
* workaround to_dlpack fail in new pt version * add torch code link
This commit is contained in:
parent
0888c6cc59
commit
39dc6ea8a3
6 changed files with 36 additions and 18 deletions
|
|
@ -464,7 +464,7 @@
|
|||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "e1e11e0d555c08bec08a6c7773aa777dfcaae9da",
|
||||
"commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c",
|
||||
"repositoryUrl": "https://github.com/dmlc/dlpack.git"
|
||||
},
|
||||
"comments": "dlpack"
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@
|
|||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "e1e11e0d555c08bec08a6c7773aa777dfcaae9da",
|
||||
"commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c",
|
||||
"repositoryUrl": "https://github.com/dmlc/dlpack.git"
|
||||
},
|
||||
"comments": "git submodule at cmake/external/dlpack"
|
||||
|
|
|
|||
2
cmake/external/dlpack
vendored
2
cmake/external/dlpack
vendored
|
|
@ -1 +1 @@
|
|||
Subproject commit e1e11e0d555c08bec08a6c7773aa777dfcaae9da
|
||||
Subproject commit 277508879878e0a5b5b43599b1bea11f66eb3c6c
|
||||
|
|
@ -75,28 +75,28 @@ DLDataType GetDlpackDataType(const OrtValue& ort_value) {
|
|||
return dtype;
|
||||
}
|
||||
|
||||
DLContext GetDlpackContext(const OrtValue& ort_value, const int64_t& device_id) {
|
||||
DLDevice GetDlpackDevice(const OrtValue& ort_value, const int64_t& device_id) {
|
||||
ORT_ENFORCE(ort_value.IsTensor(), "Only OrtValues that are Tensors are currently supported");
|
||||
DLContext ctx;
|
||||
ctx.device_id = static_cast<int>(device_id);
|
||||
DLDevice device;
|
||||
device.device_id = static_cast<int>(device_id);
|
||||
const Tensor& tensor = ort_value.Get<Tensor>();
|
||||
const auto& location = tensor.Location();
|
||||
switch (location.device.Type()) {
|
||||
case OrtDevice::CPU:
|
||||
ctx.device_type = DLDeviceType::kDLCPU;
|
||||
device.device_type = DLDeviceType::kDLCPU;
|
||||
break;
|
||||
case OrtDevice::GPU:
|
||||
#ifdef USE_ROCM
|
||||
ctx.device_type = DLDeviceType::kDLROCM;
|
||||
device.device_type = DLDeviceType::kDLROCM;
|
||||
#else
|
||||
ctx.device_type = DLDeviceType::kDLGPU;
|
||||
device.device_type = DLDeviceType::kDLCUDA;
|
||||
#endif
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Cannot pack tensors on this device.");
|
||||
}
|
||||
|
||||
return ctx;
|
||||
return device;
|
||||
}
|
||||
|
||||
struct OrtDLManagedTensor {
|
||||
|
|
@ -106,13 +106,13 @@ struct OrtDLManagedTensor {
|
|||
|
||||
void DlpackDeleter(DLManagedTensor* arg) { delete static_cast<OrtDLManagedTensor*>(arg->manager_ctx); }
|
||||
|
||||
OrtDevice GetOrtDevice(const DLContext& ctx) {
|
||||
switch (ctx.device_type) {
|
||||
OrtDevice GetOrtDevice(const DLDevice& device) {
|
||||
switch (device.device_type) {
|
||||
case DLDeviceType::kDLCPU:
|
||||
return OrtDevice();
|
||||
case DLDeviceType::kDLGPU:
|
||||
case DLDeviceType::kDLCUDA:
|
||||
case DLDeviceType::kDLROCM:
|
||||
return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(ctx.device_id));
|
||||
return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(device.device_id));
|
||||
default:
|
||||
ORT_THROW("Unsupported device type");
|
||||
}
|
||||
|
|
@ -208,7 +208,7 @@ DLManagedTensor* OrtValueToDlpack(OrtValue& ort_value) {
|
|||
ort_dlmanaged_tensor->tensor.manager_ctx = ort_dlmanaged_tensor;
|
||||
ort_dlmanaged_tensor->tensor.deleter = &DlpackDeleter;
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.data = (tensor.MutableDataRaw());
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.ctx = GetDlpackContext(ort_value, tensor.Location().device.Id());
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.device = GetDlpackDevice(ort_value, tensor.Location().device.Id());
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.ndim = static_cast<int>(tensor.Shape().NumDimensions());
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.dtype = GetDlpackDataType(ort_value);
|
||||
ort_dlmanaged_tensor->tensor.dl_tensor.shape =
|
||||
|
|
@ -221,7 +221,7 @@ DLManagedTensor* OrtValueToDlpack(OrtValue& ort_value) {
|
|||
OrtValue DlpackToOrtValue(DLManagedTensor* dlpack, bool is_bool_tensor) {
|
||||
// ORT only supports contiguous tensor for now.
|
||||
ORT_ENFORCE(IsContiguousTensor(dlpack->dl_tensor), "ORT only supports contiguous tensor for now.");
|
||||
OrtDevice device = GetOrtDevice(dlpack->dl_tensor.ctx);
|
||||
OrtDevice device = GetOrtDevice(dlpack->dl_tensor.device);
|
||||
MLDataType data_type = GetOrtValueDataType(dlpack->dl_tensor.dtype, is_bool_tensor);
|
||||
OrtMemoryInfo info(GetOrtDeviceName(device), OrtDeviceAllocator, device, device.Id());
|
||||
std::unique_ptr<Tensor> p_tensor = std::make_unique<Tensor>(
|
||||
|
|
|
|||
|
|
@ -15,9 +15,18 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
|
|||
from typing import List
|
||||
import types
|
||||
import warnings
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
def _ortvalue_from_torch_tensor(torch_tensor):
|
||||
return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), torch_tensor.dtype == torch.bool)
|
||||
# TODO: Current DLPack doesn't support bool and PyTorch disables converting bool tensor to DLPack in recent commit.
|
||||
# https://github.com/pytorch/pytorch/blob/7e7be526c9d9179f35084e9cca5b5c5ad5172100/aten/src/ATen/DLConvertor.cpp#L41
|
||||
# We need to convert bool tensor to uint8 tensor to work around this.
|
||||
# DLPack is discussing how to support bool type, we can remove this workaround once both DLPack
|
||||
# and PyTorch support bool type.
|
||||
is_bool_tensor = torch_tensor.dtype == torch.bool
|
||||
if is_bool_tensor and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'):
|
||||
torch_tensor = torch_tensor.to(torch.uint8)
|
||||
return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), is_bool_tensor)
|
||||
|
||||
|
||||
def _torch_tensor_from_dl_pack(dlpack, ortvalue, device):
|
||||
|
|
@ -36,6 +45,13 @@ def _torch_tensor_to_dlpack(tensor):
|
|||
if tensor.device.type == 'ort':
|
||||
return C.ort_to_dlpack(tensor)
|
||||
else:
|
||||
# TODO: Current DLPack doesn't support bool and PyTorch disables converting bool tensor to DLPack in recent commit.
|
||||
# https://github.com/pytorch/pytorch/blob/7e7be526c9d9179f35084e9cca5b5c5ad5172100/aten/src/ATen/DLConvertor.cpp#L41
|
||||
# We need to convert bool tensor to uint8 tensor to work around this.
|
||||
# DLPack is discussing how to support bool type, we can remove this workaround once both DLPack
|
||||
# and PyTorch support bool type.
|
||||
if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'):
|
||||
tensor = tensor.to(torch.uint8)
|
||||
return to_dlpack(tensor)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from collections import namedtuple
|
|||
from inspect import signature
|
||||
import tempfile
|
||||
import os
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel, _fallback, _graph_execution_manager
|
||||
import _test_helpers
|
||||
|
|
@ -3827,7 +3828,8 @@ def test_ortmodule_ortmodule_method_attribute_copy():
|
|||
|
||||
assert type(out1.grad_fn).__name__ == '_ORTModuleFunctionBackward'
|
||||
assert type(out2.grad_fn).__name__ == '_ORTModuleFunctionBackward'
|
||||
assert type(out3.grad_fn).__name__ == 'AddmmBackward'
|
||||
assert type(out3.grad_fn).__name__ == 'AddmmBackward0' if LooseVersion(
|
||||
torch.__version__) >= LooseVersion('1.10.0') else 'AddmmBackward'
|
||||
|
||||
@pytest.mark.parametrize("policy_str, policy",[
|
||||
('SKIP_CHECK_DISABLED', _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED),
|
||||
|
|
|
|||
Loading…
Reference in a new issue