From 28c519961fee288db33dbb48d32aee2fbb1d6e3d Mon Sep 17 00:00:00 2001
From: Alban Desmaison
Date: Thu, 2 Dec 2021 07:45:35 -0800
Subject: [PATCH] Follow the undefined Tensor <-> None rule better in torch
 dispatch (#67793)

Summary:
As per title. In particular, this makes it easier to override a backward
function for which the underlying backend returns `None`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67793

Reviewed By: zou3519

Differential Revision: D32242962

Pulled By: albanD

fbshipit-source-id: 6e114def90ee9499161e1303d301ba7fd003ff89
---
 test/test_jit.py                           |  5 +-
 test/test_python_dispatch.py               | 58 +++++++++++++++++++---
 torch/csrc/autograd/VariableTypeManual.cpp |  6 ++-
 torch/csrc/jit/python/pybind_utils.cpp     |  4 ++
 4 files changed, 62 insertions(+), 11 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 1b1126dcc4d..81c8671de2e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10795,9 +10795,8 @@ dedent """
         modelA = torch.jit.script(A())
         self.assertEqual(modelA(), 9)

-        with self.assertRaisesRegexWithHighlight(RuntimeError, "expected value of type Tensor", "self.ignored"):
-            modelB = torch.jit.script(B())
-            modelB()
+        modelB = torch.jit.script(B())
+        self.assertEqual(modelB(), 9)

     def test_addmm_grad(self):
         """ This test checks several things:
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 82099089130..b016da0d08e 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -440,7 +440,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
         self.assertEqual(x, None)

     def test_enable_python_mode_subclass_autograd_device_check(self) -> None:
-        class NonWrapperSublass(torch.Tensor):
+        class NonWrapperSubclass(torch.Tensor):
             elem: torch.Tensor

             __slots__ = ['elem']
@@ -456,25 +456,71 @@ $6 = torch._ops.aten.add_($1, $5)''')
             @classmethod
             def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 def unwrap(e):
-                    return e.elem if isinstance(e, NonWrapperSublass) else e
+                    return e.elem if isinstance(e, NonWrapperSubclass) else e

                 def wrap(e):
-                    return NonWrapperSublass(e) if isinstance(e, torch.Tensor) else e
+                    return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e

                 # no_dispatch is only needed if you use enable_python_mode.
                 # It prevents infinite recursion.
                 with no_dispatch():
                     rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
-                logging.getLogger("NonWrapperSublass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
+                logging.getLogger("NonWrapperSubclass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
                 return rs

-        x = NonWrapperSublass(torch.tensor([3.0, 4.0], requires_grad=True))
+        x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
         y = torch.randn(2, requires_grad=True)
         z = x * y
-        self.assertIsInstance(z, NonWrapperSublass)
+        self.assertIsInstance(z, NonWrapperSubclass)
         z.sum().backward(torch.tensor(1))
         self.assertEqual(x.grad, y)
         self.assertEqual(y.grad, x)

+    def test_none_wrapping(self):
+        # A Tensor subclass that returns None when doing add
+        # See LoggingTensor above for more details on the subclass
+        class SubclassWithNone(torch.Tensor):
+            @staticmethod
+            def __new__(cls, elem, *args, **kwargs):
+                r = torch.Tensor._make_wrapper_subclass(
+                    cls, elem.size(),
+                    dtype=elem.dtype, layout=elem.layout,
+                    device=elem.device, requires_grad=elem.requires_grad
+                )
+                r.elem = elem
+                return r
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                def unwrap(e):
+                    return e.elem if isinstance(e, SubclassWithNone) else e
+
+                def wrap(e):
+                    return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e
+
+                # no_dispatch is only needed if you use enable_python_mode.
+                # It prevents infinite recursion.
+                with no_dispatch():
+                    rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+                if func.__name__ == "add":
+                    return None
+                else:
+                    return rs
+
+        x = SubclassWithNone(torch.rand(2))
+        # Make sure both run without error
+        self.assertIsInstance(x * 2, SubclassWithNone)
+        self.assertIsNone(x + 2)
+
+        x.requires_grad_()
+        out = x.acos().sum()
+
+        # The backward of acos does add then rsqrt so here we make sure that the
+        # undefined Tensor generated by the user code is nicely handled.
+        # If acos formula changes in the future, this can be replaced by any other
+        # function that does add then something in the backward in a composite way
+        with self.assertRaisesRegex(RuntimeError, "found an undefined Tensor"):
+            out.backward()
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp
index 7489b98709c..531bcd36420 100644
--- a/torch/csrc/autograd/VariableTypeManual.cpp
+++ b/torch/csrc/autograd/VariableTypeManual.cpp
@@ -42,14 +42,16 @@ C10_EXPORT std::vector<at::DeprecatedTypeProperties*> allCUDATypes() {
 namespace {
 const Variable & checked_cast_variable(const Tensor & t, const char * name, int pos) {
   if (!t.defined()) {
-    AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'");
+    AT_ERROR("Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
+             "for argument #", pos, " '", name, "'");
   }
   return t;
 }

 Variable & checked_cast_variable(Tensor & t, const char * name, int pos) {
   if (!t.defined()) {
-    AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'");
+    AT_ERROR("Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
+             "for argument #", pos, " '", name, "'");
   }
   return t;
 }
diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp
index 673b139bd75..a43af423bf4 100644
--- a/torch/csrc/jit/python/pybind_utils.cpp
+++ b/torch/csrc/jit/python/pybind_utils.cpp
@@ -26,6 +26,10 @@ void clear_registered_instances(void* ptr) {
 IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
   switch (type->kind()) {
     case TypeKind::TensorType: {
+      if (obj.ptr() == Py_None) {
+        // None gets converted to undefined Tensors
+        return autograd::Variable();
+      }
       auto var = py::cast<autograd::Variable>(obj);
       if (var.is_sparse()) {
         TORCH_WARN_ONCE(
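--
A minimal standalone sketch of the use case called out in the summary
(returning `None` from a `__torch_dispatch__` override where the op schema
declares a Tensor), assuming a build with this patch applied. It mirrors the
new test_none_wrapping test above; the subclass name NoneOnAdd is made up for
illustration, and tree_map comes from torch.utils._pytree.

import torch
from torch.utils._pytree import tree_map


class NoneOnAdd(torch.Tensor):
    """Wrapper subclass whose dispatch answers every `add` call with None."""

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=elem.requires_grad)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, NoneOnAdd) else e

        def wrap(e):
            return NoneOnAdd(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args),
                                 **tree_map(unwrap, kwargs or {})))
        # The point of this patch: returning None where the schema says
        # Tensor is now accepted and becomes an undefined Tensor in C++,
        # instead of failing the pybind cast in toIValue.
        return None if func.__name__ == "add" else rs


x = NoneOnAdd(torch.rand(2))
print(x + 2)  # None: our override dropped the backend's add result

x.requires_grad_()
out = x.acos().sum()
# acos's backward performs an add internally, so the None returned by our
# override becomes an undefined gradient Tensor; per test_none_wrapping,
# out.backward() then raises RuntimeError: ... found an undefined Tensor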