From 39dc6ea8a32d7dcdbeea45967452118ab16dcefc Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Fri, 24 Sep 2021 09:48:07 +0800 Subject: [PATCH] Fix to_dlpack Failure on PyTorch-1.10 (#9151) * workaround to_dlpack fail in new pt version * add torch code link --- cgmanifests/cgmanifest.json | 2 +- cgmanifests/submodules/cgmanifest.json | 2 +- cmake/external/dlpack | 2 +- onnxruntime/core/dlpack/dlpack_converter.cc | 26 +++++++++---------- .../python/training/ortmodule/_utils.py | 18 ++++++++++++- .../python/orttraining_test_ortmodule_api.py | 4 ++- 6 files changed, 36 insertions(+), 18 deletions(-) diff --git a/cgmanifests/cgmanifest.json b/cgmanifests/cgmanifest.json index c8ff5d7c74..ac0bf5dec5 100644 --- a/cgmanifests/cgmanifest.json +++ b/cgmanifests/cgmanifest.json @@ -464,7 +464,7 @@ "component": { "type": "git", "git": { - "commitHash": "e1e11e0d555c08bec08a6c7773aa777dfcaae9da", + "commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c", "repositoryUrl": "https://github.com/dmlc/dlpack.git" }, "comments": "dlpack" diff --git a/cgmanifests/submodules/cgmanifest.json b/cgmanifests/submodules/cgmanifest.json index eb914f26df..41c43a6ff1 100644 --- a/cgmanifests/submodules/cgmanifest.json +++ b/cgmanifests/submodules/cgmanifest.json @@ -184,7 +184,7 @@ "component": { "type": "git", "git": { - "commitHash": "e1e11e0d555c08bec08a6c7773aa777dfcaae9da", + "commitHash": "277508879878e0a5b5b43599b1bea11f66eb3c6c", "repositoryUrl": "https://github.com/dmlc/dlpack.git" }, "comments": "git submodule at cmake/external/dlpack" diff --git a/cmake/external/dlpack b/cmake/external/dlpack index e1e11e0d55..2775088798 160000 --- a/cmake/external/dlpack +++ b/cmake/external/dlpack @@ -1 +1 @@ -Subproject commit e1e11e0d555c08bec08a6c7773aa777dfcaae9da +Subproject commit 277508879878e0a5b5b43599b1bea11f66eb3c6c diff --git a/onnxruntime/core/dlpack/dlpack_converter.cc b/onnxruntime/core/dlpack/dlpack_converter.cc index 36564c3116..8937d5467e 100644 --- 
a/onnxruntime/core/dlpack/dlpack_converter.cc +++ b/onnxruntime/core/dlpack/dlpack_converter.cc @@ -75,28 +75,28 @@ DLDataType GetDlpackDataType(const OrtValue& ort_value) { return dtype; } -DLContext GetDlpackContext(const OrtValue& ort_value, const int64_t& device_id) { +DLDevice GetDlpackDevice(const OrtValue& ort_value, const int64_t& device_id) { ORT_ENFORCE(ort_value.IsTensor(), "Only OrtValues that are Tensors are currently supported"); - DLContext ctx; - ctx.device_id = static_cast<int>(device_id); + DLDevice device; + device.device_id = static_cast<int>(device_id); const Tensor& tensor = ort_value.Get<Tensor>(); const auto& location = tensor.Location(); switch (location.device.Type()) { case OrtDevice::CPU: - ctx.device_type = DLDeviceType::kDLCPU; + device.device_type = DLDeviceType::kDLCPU; break; case OrtDevice::GPU: #ifdef USE_ROCM - ctx.device_type = DLDeviceType::kDLROCM; + device.device_type = DLDeviceType::kDLROCM; #else - ctx.device_type = DLDeviceType::kDLGPU; + device.device_type = DLDeviceType::kDLCUDA; #endif break; default: ORT_THROW("Cannot pack tensors on this device."); } - return ctx; + return device; } struct OrtDLManagedTensor { @@ -106,13 +106,13 @@ struct OrtDLManagedTensor { void DlpackDeleter(DLManagedTensor* arg) { delete static_cast<OrtDLManagedTensor*>(arg->manager_ctx); } -OrtDevice GetOrtDevice(const DLContext& ctx) { - switch (ctx.device_type) { +OrtDevice GetOrtDevice(const DLDevice& device) { + switch (device.device_type) { case DLDeviceType::kDLCPU: return OrtDevice(); - case DLDeviceType::kDLGPU: + case DLDeviceType::kDLCUDA: case DLDeviceType::kDLROCM: - return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(ctx.device_id)); + return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(device.device_id)); default: ORT_THROW("Unsupported device type"); } } @@ -208,7 +208,7 @@ DLManagedTensor* OrtValueToDlpack(OrtValue& ort_value) { ort_dlmanaged_tensor->tensor.manager_ctx = ort_dlmanaged_tensor; 
ort_dlmanaged_tensor->tensor.deleter = &DlpackDeleter; ort_dlmanaged_tensor->tensor.dl_tensor.data = (tensor.MutableDataRaw()); - ort_dlmanaged_tensor->tensor.dl_tensor.ctx = GetDlpackContext(ort_value, tensor.Location().device.Id()); + ort_dlmanaged_tensor->tensor.dl_tensor.device = GetDlpackDevice(ort_value, tensor.Location().device.Id()); ort_dlmanaged_tensor->tensor.dl_tensor.ndim = static_cast<int>(tensor.Shape().NumDimensions()); ort_dlmanaged_tensor->tensor.dl_tensor.dtype = GetDlpackDataType(ort_value); ort_dlmanaged_tensor->tensor.dl_tensor.shape = @@ -221,7 +221,7 @@ DLManagedTensor* OrtValueToDlpack(OrtValue& ort_value) { OrtValue DlpackToOrtValue(DLManagedTensor* dlpack, bool is_bool_tensor) { // ORT only supports contiguous tensor for now. ORT_ENFORCE(IsContiguousTensor(dlpack->dl_tensor), "ORT only supports contiguous tensor for now."); - OrtDevice device = GetOrtDevice(dlpack->dl_tensor.ctx); + OrtDevice device = GetOrtDevice(dlpack->dl_tensor.device); MLDataType data_type = GetOrtValueDataType(dlpack->dl_tensor.dtype, is_bool_tensor); OrtMemoryInfo info(GetOrtDeviceName(device), OrtDeviceAllocator, device, device.Id()); std::unique_ptr<Tensor> p_tensor = std::make_unique<Tensor>( diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index ec3af45b0a..50161e8e77 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -15,9 +15,18 @@ from torch.utils.dlpack import from_dlpack, to_dlpack from typing import List import types import warnings +from distutils.version import LooseVersion def _ortvalue_from_torch_tensor(torch_tensor): - return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), torch_tensor.dtype == torch.bool) + # TODO: Current DLPack doesn't support bool and PyTorch disables converting bool tensor to DLPack in recent commit. 
+ # https://github.com/pytorch/pytorch/blob/7e7be526c9d9179f35084e9cca5b5c5ad5172100/aten/src/ATen/DLConvertor.cpp#L41 + # We need to convert bool tensor to uint8 tensor to work around this. + # DLPack is discussing how to support bool type, we can remove this workaround once both DLPack + # and PyTorch support bool type. + is_bool_tensor = torch_tensor.dtype == torch.bool + if is_bool_tensor and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'): + torch_tensor = torch_tensor.to(torch.uint8) + return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), is_bool_tensor) def _torch_tensor_from_dl_pack(dlpack, ortvalue, device): @@ -36,6 +45,13 @@ def _torch_tensor_to_dlpack(tensor): if tensor.device.type == 'ort': return C.ort_to_dlpack(tensor) else: + # TODO: Current DLPack doesn't support bool and PyTorch disables converting bool tensor to DLPack in recent commit. + # https://github.com/pytorch/pytorch/blob/7e7be526c9d9179f35084e9cca5b5c5ad5172100/aten/src/ATen/DLConvertor.cpp#L41 + # We need to convert bool tensor to uint8 tensor to work around this. + # DLPack is discussing how to support bool type, we can remove this workaround once both DLPack + # and PyTorch support bool type. 
+ if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'): + tensor = tensor.to(torch.uint8) return to_dlpack(tensor) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index ef4d894f96..a111176073 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -19,6 +19,7 @@ from collections import namedtuple from inspect import signature import tempfile import os +from distutils.version import LooseVersion from onnxruntime.training.ortmodule import ORTModule, _utils, _io, DebugOptions, LogLevel, _fallback, _graph_execution_manager import _test_helpers @@ -3827,7 +3828,8 @@ def test_ortmodule_ortmodule_method_attribute_copy(): assert type(out1.grad_fn).__name__ == '_ORTModuleFunctionBackward' assert type(out2.grad_fn).__name__ == '_ORTModuleFunctionBackward' - assert type(out3.grad_fn).__name__ == 'AddmmBackward' + assert type(out3.grad_fn).__name__ == ('AddmmBackward0' if LooseVersion( + torch.__version__) >= LooseVersion('1.10.0') else 'AddmmBackward') @pytest.mark.parametrize("policy_str, policy",[ ('SKIP_CHECK_DISABLED', _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED),