From 35268bdc2a6f0be613b9fa0b5e3da6ae68ece67f Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Fri, 24 Jun 2022 14:53:37 +0000
Subject: [PATCH] Revert "[FakeTensor] Use the device of the meta tensor for
 fallback kernel (#80193)"

This reverts commit 93e70c5973bd2ad9eeee1469dd6ac0076943d9aa.

Reverted https://github.com/pytorch/pytorch/pull/80193 on behalf of https://github.com/b0noI due to broken test: https://github.com/pytorch/pytorch/runs/7035945243?check_suite_focus=true
---
 test/test_fake_tensor.py         | 24 ++++--------------------
 torch/_subclasses/fake_tensor.py | 18 +++++++++---------
 2 files changed, 13 insertions(+), 29 deletions(-)

diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 11ef36ce98d..3252ef14cfa 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -11,7 +11,6 @@ from torch._subclasses.fake_tensor import (
     DynamicOutputShapeException,
 )
 from torch.utils._python_dispatch import enable_torch_dispatch_mode
-from torch import nn
 import unittest
 import torch._prims as prims
 import copy
@@ -138,20 +137,20 @@ class FakeTensorTest(TestCase):
 
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_cpu_fallback(self):
-        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=False)):
+        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=False)):
             filters = torch.randn(8, 4, 3, 3).cuda()
             inputs = torch.randn(1, 4, 5, 5).cuda()
             with self.assertRaises(NotImplementedError):
                 torch.nn.functional.conv2d(inputs, filters, padding=1)
 
-        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=True)):
+        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=True)):
             # intentionally bad inputs
             filters = torch.randn(8, 20, 3, 3).cuda()
             inputs = torch.randn(1, 7, 10, 5).cuda()
             with self.assertRaises(RuntimeError):
                 torch.nn.functional.conv2d(inputs, filters, padding=1)
 
-        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=True)):
+        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=True)):
             filters = torch.randn(8, 4, 3, 3).cuda()
             inputs = torch.randn(1, 4, 5, 5).cuda()
 
@@ -159,24 +158,9 @@ class FakeTensorTest(TestCase):
             self.assertEqual(out.device.type, "cuda")
             self.assertEqual(list(out.size()), [1, 8, 5, 5])
 
-    @unittest.skipIf(not RUN_CUDA, "requires cuda")
-    def test_fallback_memory_prop(self):
-        m = nn.Conv2d(16, 33, 3, stride=2, device="cuda", dtype=torch.half)
-        m = m.to(memory_format=torch.channels_last)
-        mode = FakeTensorMode(inner=None)
-        # TODO: module.to() doesn't work because it assigns .data, which is ignored
-        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
-            mod_copied = copy.deepcopy(m)
-
-        with enable_torch_dispatch_mode(mode):
-            input = torch.rand(20, 16, 50, 100, dtype=torch.half, device="cuda").to(memory_format=torch.channels_last)
-            out = mod_copied(input)
-        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
-        self.checkType(out, "cuda", [20, 33, 24, 49])
-
     def test_data_dependent_operator(self):
         with enable_torch_dispatch_mode(
-            FakeTensorMode(inner=None, allow_fallback_kernels=False)
+            FakeTensorMode(inner=None, allow_cpu_fallback=False)
         ):
             x = torch.rand([10, 10])
 
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 3961ac4fdda..5b979324a86 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -393,8 +393,8 @@ class FakeTensor(torch.Tensor):
 
 
 class FakeTensorMode(TorchDispatchMode):
-    def __init__(self, allow_fallback_kernels=True):
-        self.allow_fallback_kernels = allow_fallback_kernels
+    def __init__(self, allow_cpu_fallback=True):
+        self.allow_cpu_fallback = allow_cpu_fallback
         self.fake_tensor_converter = FakeTensorConverter()
 
         # [in_kernel_invocation]
@@ -475,9 +475,9 @@ class FakeTensorMode(TorchDispatchMode):
             try:
                 r = func(*args, **kwargs)
             except NotImplementedError as not_implemented_error:
-                if not self.allow_fallback_kernels:
+                if not self.allow_cpu_fallback:
                     raise not_implemented_error
-                r = run_fallback_kernel(func, args, kwargs, not_implemented_error)
+                r = run_cpu_fallback(func, args, kwargs, not_implemented_error)
 
         # TODO: handle non-kwarg devices
         assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@@ -493,16 +493,16 @@ class FakeTensorMode(TorchDispatchMode):
     def from_tensor(self, tensor):
         return self.fake_tensor_converter(self, tensor)
 
-def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
+def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
     with no_dispatch():
-        def to_real_tensor(e):
+        def to_cpu(e):
             if isinstance(e, FakeTensor):
-                return torch.zeros_like(e, device=e.fake_device)
+                return torch.zeros_like(e, device="cpu")
             return e
         try:
-            args = tree_map(to_real_tensor, args)
-            kwargs = tree_map(to_real_tensor, kwargs)
+            args = tree_map(to_cpu, args)
+            kwargs = tree_map(to_cpu, kwargs)
             r = func(*args, **kwargs)
         except Exception as new_exception:
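
For context, a minimal sketch of the CPU-fallback pattern this revert restores: fake inputs are materialized as CPU zero tensors via tree_map, the real CPU kernel runs, and the outputs carry the shapes the fallback is meant to infer. The toy Fake class and cpu_fallback helper below are hypothetical stand-ins for FakeTensor and run_cpu_fallback (the real code runs under no_dispatch() inside FakeTensorMode), so the example runs on any PyTorch build without CUDA.

    import torch
    from torch.utils._pytree import tree_map

    class Fake:
        # Hypothetical stand-in for FakeTensor: metadata only, no real storage.
        def __init__(self, shape, dtype=torch.float32):
            self.shape, self.dtype = shape, dtype

    def cpu_fallback(func, args, kwargs):
        # Mirrors run_cpu_fallback after the revert: every fake input
        # becomes a CPU zero tensor, then the real kernel runs on CPU.
        def to_cpu(e):
            if isinstance(e, Fake):
                return torch.zeros(e.shape, dtype=e.dtype, device="cpu")
            return e
        args = tree_map(to_cpu, args)
        kwargs = tree_map(to_cpu, kwargs)
        return func(*args, **kwargs)

    out = cpu_fallback(torch.nn.functional.conv2d,
                       (Fake((1, 4, 5, 5)), Fake((8, 4, 3, 3))),
                       {"padding": 1})
    print(out.shape)  # torch.Size([1, 8, 5, 5]), matching test_cpu_fallback above

The behavioral difference from the reverted #80193 shows in the last hunk: after the revert, the fallback materializes inputs with device="cpu" instead of the fake tensor's original fake_device.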