Follow the undefined Tensor <-> None rule better in torch dispatch (#67793)

Summary:
As per title. In particular, this makes it easier to override backward functions for which the underlying backend returns `None`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67793

Reviewed By: zou3519

Differential Revision: D32242962

Pulled By: albanD

fbshipit-source-id: 6e114def90ee9499161e1303d301ba7fd003ff89
This commit is contained in:
Alban Desmaison 2021-12-02 07:45:35 -08:00 committed by Facebook GitHub Bot
parent 0465f64bb8
commit 28c519961f
4 changed files with 62 additions and 11 deletions

View file

@ -10795,9 +10795,8 @@ dedent """
modelA = torch.jit.script(A())
self.assertEqual(modelA(), 9)
with self.assertRaisesRegexWithHighlight(RuntimeError, "expected value of type Tensor", "self.ignored"):
modelB = torch.jit.script(B())
modelB()
modelB = torch.jit.script(B())
self.assertEqual(modelB(), 9)
def test_addmm_grad(self):
""" This test checks several things:

View file

@ -440,7 +440,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
self.assertEqual(x, None)
def test_enable_python_mode_subclass_autograd_device_check(self) -> None:
class NonWrapperSublass(torch.Tensor):
class NonWrapperSubclass(torch.Tensor):
elem: torch.Tensor
__slots__ = ['elem']
@ -456,25 +456,71 @@ $6 = torch._ops.aten.add_($1, $5)''')
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(e):
return e.elem if isinstance(e, NonWrapperSublass) else e
return e.elem if isinstance(e, NonWrapperSubclass) else e
def wrap(e):
return NonWrapperSublass(e) if isinstance(e, torch.Tensor) else e
return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e
# no_dispatch is only needed if you use enable_python_mode.
# It prevents infinite recursion.
with no_dispatch():
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
logging.getLogger("NonWrapperSublass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
logging.getLogger("NonWrapperSubclass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
return rs
x = NonWrapperSublass(torch.tensor([3.0, 4.0], requires_grad=True))
x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
y = torch.randn(2, requires_grad=True)
z = x * y
self.assertIsInstance(z, NonWrapperSublass)
self.assertIsInstance(z, NonWrapperSubclass)
z.sum().backward(torch.tensor(1))
self.assertEqual(x.grad, y)
self.assertEqual(y.grad, x)
def test_none_wrapping(self):
    # A Tensor subclass whose dispatch deliberately returns None for "add".
    # See LoggingTensor above for more details on the subclass machinery.
    class SubclassWithNone(torch.Tensor):
        @staticmethod
        def __new__(cls, elem, *args, **kwargs):
            # Build a wrapper subclass that mirrors the metadata of `elem`
            # and stashes the real tensor on the instance.
            wrapped = torch.Tensor._make_wrapper_subclass(
                cls, elem.size(),
                dtype=elem.dtype, layout=elem.layout,
                device=elem.device, requires_grad=elem.requires_grad
            )
            wrapped.elem = elem
            return wrapped

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            def unwrap(t):
                if isinstance(t, SubclassWithNone):
                    return t.elem
                return t

            def wrap(t):
                if isinstance(t, torch.Tensor):
                    return SubclassWithNone(t)
                return t

            # no_dispatch is only needed if you use enable_python_mode.
            # It prevents infinite recursion.
            with no_dispatch():
                unwrapped_args = tree_map(unwrap, args)
                unwrapped_kwargs = tree_map(unwrap, kwargs)
                result = tree_map(wrap, func(*unwrapped_args, **unwrapped_kwargs))
            # Simulate user code that returns None from the dispatch hook.
            if func.__name__ == "add":
                return None
            return result

    x = SubclassWithNone(torch.rand(2))
    # Make sure both run without error
    self.assertIsInstance(x * 2, SubclassWithNone)
    self.assertIsNone(x + 2)

    x.requires_grad_()
    out = x.acos().sum()

    # The backward of acos does add then rsqrt so here we make sure that the
    # undefined Tensor generated by the user code is nicely handled.
    # If acos formula changes in the future, this can be replaced by any other
    # function that does add then something in the backward in a composite way
    with self.assertRaisesRegex(RuntimeError, "found an undefined Tensor"):
        out.backward()
if __name__ == '__main__':
run_tests()

View file

@ -42,14 +42,16 @@ C10_EXPORT std::vector<at::DeprecatedTypeProperties*> allCUDATypes() {
namespace {
// Validate that `t` is a defined (non-null) Tensor before use.
// Raises a user-facing error mentioning the argument position and name,
// phrased for Python callers where an undefined Tensor surfaces as None.
const Variable & checked_cast_variable(const Tensor & t, const char * name, int pos) {
  if (!t.defined()) {
    // Post-commit message: the diff rendering had left the superseded
    // pre-commit AT_ERROR line in place alongside this one.
    AT_ERROR("Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
             "for argument #", pos, " '", name, "'");
  }
  return t;
}
// Mutable-reference overload of the defined-Tensor check above; identical
// validation, phrased for Python callers where an undefined Tensor is None.
Variable & checked_cast_variable(Tensor & t, const char * name, int pos) {
  if (!t.defined()) {
    // Post-commit message: the diff rendering had left the superseded
    // pre-commit AT_ERROR line in place alongside this one.
    AT_ERROR("Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
             "for argument #", pos, " '", name, "'");
  }
  return t;
}

View file

@ -26,6 +26,10 @@ void clear_registered_instances(void* ptr) {
IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
switch (type->kind()) {
case TypeKind::TensorType: {
if (obj.ptr() == Py_None) {
// None gets converted to undefined Tensors
return autograd::Variable();
}
auto var = py::cast<autograd::Variable>(obj);
if (var.is_sparse()) {
TORCH_WARN_ONCE(