Revert "[FakeTensor] Use the device of the meta tensor for fallback kernel (#80193)"
This reverts commit 93e70c5973.
Reverted https://github.com/pytorch/pytorch/pull/80193 on behalf of https://github.com/b0noI due to broken test: https://github.com/pytorch/pytorch/runs/7035945243?check_suite_focus=true
parent 845021db2c
commit 35268bdc2a
2 changed files with 13 additions and 29 deletions
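For readers skimming the diff below: the reverted patch renamed the FakeTensorMode constructor flag (allow_cpu_fallback became allow_fallback_kernels; the old name is restored here) and made the fallback path materialize stand-in tensors on the fake tensor's original device (e.fake_device) instead of always on CPU. The underlying mechanic is just zeros_like with a device override; a minimal standalone sketch, using a plain meta tensor as a stand-in for the internal FakeTensor class (which is backed by a meta tensor):

import torch

# A meta tensor carries only shape and dtype metadata, no data, much like a FakeTensor.
meta = torch.empty(2, 3, dtype=torch.half, device="meta")

# Behaviour restored by this revert: the fallback always materializes its stand-in on CPU.
# The reverted patch passed the fake tensor's original device (e.fake_device) instead;
# "cpu" keeps this sketch runnable on any machine.
stand_in = torch.zeros_like(meta, device="cpu")

print(stand_in.shape, stand_in.dtype, stand_in.device)  # torch.Size([2, 3]) torch.float16 cpu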
test/test_fake_tensor.py
@@ -11,7 +11,6 @@ from torch._subclasses.fake_tensor import (
     DynamicOutputShapeException,
 )
 from torch.utils._python_dispatch import enable_torch_dispatch_mode
-from torch import nn
 import unittest
 import torch._prims as prims
 import copy
@@ -138,20 +137,20 @@ class FakeTensorTest(TestCase):
 
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_cpu_fallback(self):
-        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=False)):
+        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=False)):
             filters = torch.randn(8, 4, 3, 3).cuda()
             inputs = torch.randn(1, 4, 5, 5).cuda()
             with self.assertRaises(NotImplementedError):
                 torch.nn.functional.conv2d(inputs, filters, padding=1)
 
-        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=True)):
+        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=True)):
             # intentionally bad inputs
             filters = torch.randn(8, 20, 3, 3).cuda()
             inputs = torch.randn(1, 7, 10, 5).cuda()
             with self.assertRaises(RuntimeError):
                 torch.nn.functional.conv2d(inputs, filters, padding=1)
 
-        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_fallback_kernels=True)):
+        with enable_torch_dispatch_mode(FakeTensorMode(inner=None, allow_cpu_fallback=True)):
             filters = torch.randn(8, 4, 3, 3).cuda()
             inputs = torch.randn(1, 4, 5, 5).cuda()
 
@@ -159,24 +158,9 @@ class FakeTensorTest(TestCase):
             self.assertEqual(out.device.type, "cuda")
             self.assertEqual(list(out.size()), [1, 8, 5, 5])
 
-    @unittest.skipIf(not RUN_CUDA, "requires cuda")
-    def test_fallback_memory_prop(self):
-        m = nn.Conv2d(16, 33, 3, stride=2, device="cuda", dtype=torch.half)
-        m = m.to(memory_format=torch.channels_last)
-        mode = FakeTensorMode(inner=None)
-        # TODO: module.to() doesn't work because it assigns .data, which is ignored
-        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
-            mod_copied = copy.deepcopy(m)
-
-        with enable_torch_dispatch_mode(mode):
-            input = torch.rand(20, 16, 50, 100, dtype=torch.half, device="cuda").to(memory_format=torch.channels_last)
-            out = mod_copied(input)
-            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
-            self.checkType(out, "cuda", [20, 33, 24, 49])
-
     def test_data_dependent_operator(self):
         with enable_torch_dispatch_mode(
-            FakeTensorMode(inner=None, allow_fallback_kernels=False)
+            FakeTensorMode(inner=None, allow_cpu_fallback=False)
         ):
             x = torch.rand([10, 10])
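The hunk above also deletes test_fallback_memory_prop, which checked that a channels_last CUDA convolution run under FakeTensorMode still reports a channels_last output of the right shape once the fallback kernel has run. The property being checked is easy to reproduce on real tensors; a small sketch on plain CPU tensors, outside any fake mode, using the same shapes as the deleted test:

import torch

# Conv2d(16, 33, 3, stride=2) over a (20, 16, 50, 100) input yields (20, 33, 24, 49):
# H_out = (50 - 3) // 2 + 1 = 24 and W_out = (100 - 3) // 2 + 1 = 49.
conv = torch.nn.Conv2d(16, 33, 3, stride=2).to(memory_format=torch.channels_last)
x = torch.rand(20, 16, 50, 100).to(memory_format=torch.channels_last)
out = conv(x)

print(list(out.shape))                                       # [20, 33, 24, 49]
print(out.is_contiguous(memory_format=torch.channels_last))  # expected: True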
torch/_subclasses/fake_tensor.py
@@ -393,8 +393,8 @@ class FakeTensor(torch.Tensor):
 
 
 class FakeTensorMode(TorchDispatchMode):
-    def __init__(self, allow_fallback_kernels=True):
-        self.allow_fallback_kernels = allow_fallback_kernels
+    def __init__(self, allow_cpu_fallback=True):
+        self.allow_cpu_fallback = allow_cpu_fallback
         self.fake_tensor_converter = FakeTensorConverter()
 
         # [in_kernel_invocation]
@@ -475,9 +475,9 @@ class FakeTensorMode(TorchDispatchMode):
         try:
             r = func(*args, **kwargs)
         except NotImplementedError as not_implemented_error:
-            if not self.allow_fallback_kernels:
+            if not self.allow_cpu_fallback:
                 raise not_implemented_error
-            r = run_fallback_kernel(func, args, kwargs, not_implemented_error)
+            r = run_cpu_fallback(func, args, kwargs, not_implemented_error)
 
         # TODO: handle non-kwarg devices
         assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@@ -493,16 +493,16 @@ class FakeTensorMode(TorchDispatchMode):
     def from_tensor(self, tensor):
         return self.fake_tensor_converter(self, tensor)
 
-def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
+def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
     with no_dispatch():
-        def to_real_tensor(e):
+        def to_cpu(e):
             if isinstance(e, FakeTensor):
-                return torch.zeros_like(e, device=e.fake_device)
+                return torch.zeros_like(e, device="cpu")
             return e
 
         try:
-            args = tree_map(to_real_tensor, args)
-            kwargs = tree_map(to_real_tensor, kwargs)
+            args = tree_map(to_cpu, args)
+            kwargs = tree_map(to_cpu, kwargs)
 
             r = func(*args, **kwargs)
         except Exception as new_exception:
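The run_cpu_fallback helper restored above walks args and kwargs with tree_map, swapping every FakeTensor for a real CPU tensor of zeros so the actual kernel can run; in the reverted patch the same conversion targeted e.fake_device rather than "cpu". A self-contained sketch of that conversion pattern, with meta tensors standing in for the internal FakeTensor class and made-up argument values:

import torch
from torch.utils._pytree import tree_map

def to_cpu(e):
    # Materialize data-less (meta) tensors as real CPU zeros; leave other leaves alone.
    if isinstance(e, torch.Tensor) and e.is_meta:
        return torch.zeros_like(e, device="cpu")
    return e

args = (torch.empty(1, 4, 5, 5, device="meta"), 0.5)
kwargs = {"weight": torch.empty(8, 4, 3, 3, device="meta")}

args = tree_map(to_cpu, args)
kwargs = tree_map(to_cpu, kwargs)

print(args[0].device, kwargs["weight"].device)  # cpu cpu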