make device normalization more generic in faketensor (#102519)

Fixes the hard-coded, per-backend device normalization in FakeTensor.
Make the device normalization in FakeTensor generic so that it supports "cuda", "hpu", and custom privateuse1 backends (e.g. "foo") through a single code path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102519
Approved by: https://github.com/albanD
This commit is contained in:
shibo19 2023-06-04 01:44:21 +00:00 committed by PyTorch MergeBot
parent 85efacee07
commit 9d20b47e47
2 changed files with 21 additions and 6 deletions

View file

@@ -51,6 +51,9 @@ class DummyModule(object):
def is_available():
return True
@staticmethod
def current_device():
return 0
@unittest.skipIf(IS_ARM64, "Does not work on arm")
class TestCppExtensionOpenRgistration(common.TestCase):
@@ -373,6 +376,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
finally:
torch.foo.FloatStorage = None
def test_open_device_faketensor():
    """FakeTensorMode should accept the custom privateuse1 backend.

    Allocating fake tensors must work both with the bare device string
    ("foo") and with an explicit index ("foo:0"), and arithmetic between
    the two must succeed.
    """
    # Expose the privateuse1 dispatch key under the custom name "foo".
    torch.utils.rename_privateuse1_backend('foo')
    # Install the stub device module as torch.foo so FakeTensor can
    # query getattr(torch, "foo").current_device() during normalization.
    torch._register_device_module('foo', DummyModule)
    with torch._subclasses.fake_tensor.FakeTensorMode.push():
        lhs = torch.empty(1, device="foo")
        rhs = torch.empty(1, device="foo:0")
        _ = lhs + rhs
test_base_device_registration()
test_before_common_registration()
test_common_registration()
@@ -386,6 +398,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
test_open_device_serialization()
test_open_device_storage_resize()
test_open_device_storage_type()
test_open_device_faketensor()
if __name__ == "__main__":

View file

@@ -938,15 +938,17 @@ class FakeTensor(torch.Tensor):
# on
if not fake_mode.allow_meta:
assert device.type != "meta"
# normalize cuda device.
# normalize device.
if device.type == "cuda":
init_cuda_context()
if device.index is None:
device = torch.device(f"cuda:{torch.cuda.current_device()}")
# normalize hpu device.
if device.type == "hpu" and device.index is None:
device = torch.device(f"hpu:{torch.hpu.current_device()}")
if (
device.type in ["cuda", "hpu", torch._C._get_privateuse1_backend_name()]
and device.index is None
):
device = torch.device(
f"{device.type}:{getattr(torch, device.type).current_device()}"
)
self.fake_device = device # type: ignore[attr-defined]
self.fake_mode = fake_mode # type: ignore[attr-defined]
self.constant = constant # type: ignore[attr-defined]