diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 0c844003eb1..64a8d091049 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -121,6 +121,9 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) { case DeviceType::MAIA: ctx.device_type = DLDeviceType::kDLMAIA; break; + case DeviceType::PrivateUse1: + ctx.device_type = DLDeviceType::kDLExtDev; + break; default: TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str()); } @@ -149,6 +152,8 @@ static Device getATenDevice(const DLDevice& ctx, void* data) { return at::detail::getXPUHooks().getDeviceFromPtr(data); case DLDeviceType::kDLMAIA: return at::Device(DeviceType::MAIA, static_cast(ctx.device_id)); + case DLDeviceType::kDLExtDev: + return at::Device(DeviceType::PrivateUse1, static_cast(ctx.device_id)); default: TORCH_CHECK( false, "Unsupported device_type: ", std::to_string(ctx.device_type)); @@ -287,7 +292,7 @@ DLManagedTensor* toDLPack(const Tensor& src) { atDLMTensor->tensor.deleter = &deleter; atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); c10::DeviceIndex device_id = 0; - if (src.is_cuda()) { + if (src.is_cuda() || src.is_privateuseone()) { device_id = src.get_device(); } atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id); diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 2cf69211b40..ba90aad46fe 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -644,6 +644,15 @@ class TestCppExtensionOpenRgistration(common.TestCase): with torch.serialization.skip_data(): torch.save(sd, f) + def test_open_device_dlpack(self): + t = torch.randn(2, 3).to("foo") + capsule = torch.utils.dlpack.to_dlpack(t) + t1 = torch.from_dlpack(capsule) + self.assertTrue(t1.device == t.device) + t = t.to("cpu") + t1 = t1.to("cpu") + self.assertEqual(t, t1) + if __name__ == "__main__": common.run_tests() diff --git a/torch/_tensor.py b/torch/_tensor.py index 0ee0a6cd282..18c309ec876 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1637,6 +1637,8 @@ class Tensor(torch._C.TensorBase): device_type = DLDeviceType.kDLCPU elif self.device.type == "xpu": device_type = DLDeviceType.kDLOneAPI + elif self.device.type == "privateuse1": + device_type = DLDeviceType.kDLExtDev else: raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") return (device_type, idx)