Revert "[FakeTensor] Use the device of the meta tensor for fallback kernel (#80193)"

This reverts commit 93e70c5973.

Reverted https://github.com/pytorch/pytorch/pull/80193 on behalf of https://github.com/b0noI due to broken test: https://github.com/pytorch/pytorch/runs/7035945243?check_suite_focus=true
This commit is contained in:
PyTorch MergeBot 2022-06-24 14:53:37 +00:00
parent 845021db2c
commit 35268bdc2a
2 changed files with 13 additions and 29 deletions

View file

@@ -11,7 +11,6 @@ from torch._subclasses.fake_tensor import (
DynamicOutputShapeException,
)
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch import nn
import unittest
import torch._prims as prims
import copy
@@ -138,20 +137,20 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_cpu_fallback(self):
with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=False)):
with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=False)):
filters = torch.randn(8, 4, 3, 3).cuda()
inputs = torch.randn(1, 4, 5, 5).cuda()
with self.assertRaises(NotImplementedError):
torch.nn.functional.conv2d(inputs, filters, padding=1)
with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=True)):
with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=True)):
# intentionally bad inputs
filters = torch.randn(8, 20, 3, 3).cuda()
inputs = torch.randn(1, 7, 10, 5).cuda()
with self.assertRaises(RuntimeError):
torch.nn.functional.conv2d(inputs, filters, padding=1)
with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=True)):
with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=True)):
filters = torch.randn(8, 4, 3, 3).cuda()
inputs = torch.randn(1, 4, 5, 5).cuda()
@@ -159,24 +158,9 @@ class FakeTensorTest(TestCase):
self.assertEqual(out.device.type, "cuda")
self.assertEqual(list(out.size()), [1, 8, 5, 5])
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_fallback_memory_prop(self):
m = nn.Conv2d(16, 33, 3, stride=2, device="cuda", dtype=torch.half)
m = m.to(memory_format=torch.channels_last)
mode = FakeTensorMode(inner=None)
# TODO: module.to() doesn't work because it assigns .data, which is ignored
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
mod_copied = copy.deepcopy(m)
with enable_torch_dispatch_mode(mode):
input = torch.rand(20, 16, 50, 100, dtype=torch.half, device="cuda").to(memory_format=torch.channels_last)
out = mod_copied(input)
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
self.checkType(out, "cuda", [20, 33, 24, 49])
def test_data_dependent_operator(self):
with enable_torch_dispatch_mode(
FakeTensorMode(inner=None, allow_fallback_kernels=False)
FakeTensorMode(inner=None, allow_cpu_fallback=False)
):
x = torch.rand([10, 10])

View file

@@ -393,8 +393,8 @@ class FakeTensor(torch.Tensor):
class FakeTensorMode(TorchDispatchMode):
def __init__(self, allow_fallback_kernels=True):
self.allow_fallback_kernels = allow_fallback_kernels
def __init__(self, allow_cpu_fallback=True):
self.allow_cpu_fallback = allow_cpu_fallback
self.fake_tensor_converter = FakeTensorConverter()
# [in_kernel_invocation]
@@ -475,9 +475,9 @@ class FakeTensorMode(TorchDispatchMode):
try:
r = func(*args, **kwargs)
except NotImplementedError as not_implemented_error:
if not self.allow_fallback_kernels:
if not self.allow_cpu_fallback:
raise not_implemented_error
r = run_fallback_kernel(func, args, kwargs, not_implemented_error)
r = run_cpu_fallback(func, args, kwargs, not_implemented_error)
# TODO: handle non-kwarg devices
assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@@ -493,16 +493,16 @@ class FakeTensorMode(TorchDispatchMode):
def from_tensor(self, tensor):
return self.fake_tensor_converter(self, tensor)
def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
with no_dispatch():
def to_real_tensor(e):
def to_cpu(e):
if isinstance(e, FakeTensor):
return torch.zeros_like(e, device=e.fake_device)
return torch.zeros_like(e, device="cpu")
return e
try:
args = tree_map(to_real_tensor, args)
kwargs = tree_map(to_real_tensor, kwargs)
args = tree_map(to_cpu, args)
kwargs = tree_map(to_cpu, kwargs)
r = func(*args, **kwargs)
except Exception as new_exception: