mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
make device normalization more generic in faketensor (#102519)
Fixes #ISSUE_NUMBER make the device normalization more generic in faketensor to support devices like "cuda", "foo" and so on. Pull Request resolved: https://github.com/pytorch/pytorch/pull/102519 Approved by: https://github.com/albanD
This commit is contained in:
parent
85efacee07
commit
9d20b47e47
2 changed files with 21 additions and 6 deletions
|
|
@ -51,6 +51,9 @@ class DummyModule(object):
|
|||
def is_available():
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def current_device():
|
||||
return 0
|
||||
|
||||
@unittest.skipIf(IS_ARM64, "Does not work on arm")
|
||||
class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
|
|
@ -373,6 +376,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
finally:
|
||||
torch.foo.FloatStorage = None
|
||||
|
||||
def test_open_device_faketensor():
|
||||
torch.utils.rename_privateuse1_backend('foo')
|
||||
# register foo module, torch.foo
|
||||
torch._register_device_module('foo', DummyModule)
|
||||
with torch._subclasses.fake_tensor.FakeTensorMode.push():
|
||||
a = torch.empty(1, device="foo")
|
||||
b = torch.empty(1, device="foo:0")
|
||||
result = a + b
|
||||
|
||||
test_base_device_registration()
|
||||
test_before_common_registration()
|
||||
test_common_registration()
|
||||
|
|
@ -386,6 +398,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
test_open_device_serialization()
|
||||
test_open_device_storage_resize()
|
||||
test_open_device_storage_type()
|
||||
test_open_device_faketensor()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -938,15 +938,17 @@ class FakeTensor(torch.Tensor):
|
|||
# on
|
||||
if not fake_mode.allow_meta:
|
||||
assert device.type != "meta"
|
||||
# normalize cuda device.
|
||||
# normalize device.
|
||||
if device.type == "cuda":
|
||||
init_cuda_context()
|
||||
if device.index is None:
|
||||
device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
|
||||
# normalize hpu device.
|
||||
if device.type == "hpu" and device.index is None:
|
||||
device = torch.device(f"hpu:{torch.hpu.current_device()}")
|
||||
if (
|
||||
device.type in ["cuda", "hpu", torch._C._get_privateuse1_backend_name()]
|
||||
and device.index is None
|
||||
):
|
||||
device = torch.device(
|
||||
f"{device.type}:{getattr(torch, device.type).current_device()}"
|
||||
)
|
||||
self.fake_device = device # type: ignore[attr-defined]
|
||||
self.fake_mode = fake_mode # type: ignore[attr-defined]
|
||||
self.constant = constant # type: ignore[attr-defined]
|
||||
|
|
|
|||
Loading…
Reference in a new issue