diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 549f10df207..d5893f23426 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -51,6 +51,9 @@ class DummyModule(object): def is_available(): return True + @staticmethod + def current_device(): + return 0 @unittest.skipIf(IS_ARM64, "Does not work on arm") class TestCppExtensionOpenRgistration(common.TestCase): @@ -373,6 +376,15 @@ class TestCppExtensionOpenRgistration(common.TestCase): finally: torch.foo.FloatStorage = None + def test_open_device_faketensor(): + torch.utils.rename_privateuse1_backend('foo') + # register foo module, torch.foo + torch._register_device_module('foo', DummyModule) + with torch._subclasses.fake_tensor.FakeTensorMode.push(): + a = torch.empty(1, device="foo") + b = torch.empty(1, device="foo:0") + result = a + b + test_base_device_registration() test_before_common_registration() test_common_registration() @@ -386,6 +398,7 @@ class TestCppExtensionOpenRgistration(common.TestCase): test_open_device_serialization() test_open_device_storage_resize() test_open_device_storage_type() + test_open_device_faketensor() if __name__ == "__main__": diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index dd822e46ca9..e0b419b5e8c 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -938,15 +938,17 @@ class FakeTensor(torch.Tensor): # on if not fake_mode.allow_meta: assert device.type != "meta" - # normalize cuda device. + # normalize device. if device.type == "cuda": init_cuda_context() - if device.index is None: - device = torch.device(f"cuda:{torch.cuda.current_device()}") - # normalize hpu device. - if device.type == "hpu" and device.index is None: - device = torch.device(f"hpu:{torch.hpu.current_device()}") + if ( + device.type in ["cuda", "hpu", torch._C._get_privateuse1_backend_name()] + and device.index is None + ): + device = torch.device( + f"{device.type}:{getattr(torch, device.type).current_device()}" + ) self.fake_device = device # type: ignore[attr-defined] self.fake_mode = fake_mode # type: ignore[attr-defined] self.constant = constant # type: ignore[attr-defined]