mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Follow the undefined Tensor <-> None rule better in torch dispatch (#67793)
Summary: As per title. This in particular allows to more easily override backward function for which the underlying backend returns `None` Pull Request resolved: https://github.com/pytorch/pytorch/pull/67793 Reviewed By: zou3519 Differential Revision: D32242962 Pulled By: albanD fbshipit-source-id: 6e114def90ee9499161e1303d301ba7fd003ff89
This commit is contained in:
parent
0465f64bb8
commit
28c519961f
4 changed files with 62 additions and 11 deletions
|
|
@ -10795,9 +10795,8 @@ dedent """
|
|||
modelA = torch.jit.script(A())
|
||||
self.assertEqual(modelA(), 9)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "expected value of type Tensor", "self.ignored"):
|
||||
modelB = torch.jit.script(B())
|
||||
modelB()
|
||||
modelB = torch.jit.script(B())
|
||||
self.assertEqual(modelB(), 9)
|
||||
|
||||
def test_addmm_grad(self):
|
||||
""" This test checks several things:
|
||||
|
|
|
|||
|
|
@ -440,7 +440,7 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
|||
self.assertEqual(x, None)
|
||||
|
||||
def test_enable_python_mode_subclass_autograd_device_check(self) -> None:
|
||||
class NonWrapperSublass(torch.Tensor):
|
||||
class NonWrapperSubclass(torch.Tensor):
|
||||
elem: torch.Tensor
|
||||
|
||||
__slots__ = ['elem']
|
||||
|
|
@ -456,25 +456,71 @@ $6 = torch._ops.aten.add_($1, $5)''')
|
|||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
def unwrap(e):
|
||||
return e.elem if isinstance(e, NonWrapperSublass) else e
|
||||
return e.elem if isinstance(e, NonWrapperSubclass) else e
|
||||
|
||||
def wrap(e):
|
||||
return NonWrapperSublass(e) if isinstance(e, torch.Tensor) else e
|
||||
return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
# no_dispatch is only needed if you use enable_python_mode.
|
||||
# It prevents infinite recursion.
|
||||
with no_dispatch():
|
||||
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
|
||||
logging.getLogger("NonWrapperSublass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
|
||||
logging.getLogger("NonWrapperSubclass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
|
||||
return rs
|
||||
|
||||
x = NonWrapperSublass(torch.tensor([3.0, 4.0], requires_grad=True))
|
||||
x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
|
||||
y = torch.randn(2, requires_grad=True)
|
||||
z = x * y
|
||||
self.assertIsInstance(z, NonWrapperSublass)
|
||||
self.assertIsInstance(z, NonWrapperSubclass)
|
||||
z.sum().backward(torch.tensor(1))
|
||||
self.assertEqual(x.grad, y)
|
||||
self.assertEqual(y.grad, x)
|
||||
|
||||
def test_none_wrapping(self):
|
||||
# A Tensor subclass that returns None when doing add
|
||||
# See LoggingTensor above for more details on the subclass
|
||||
class SubclassWithNone(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, elem, *args, **kwargs):
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls, elem.size(),
|
||||
dtype=elem.dtype, layout=elem.layout,
|
||||
device=elem.device, requires_grad=elem.requires_grad
|
||||
)
|
||||
r.elem = elem
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
def unwrap(e):
|
||||
return e.elem if isinstance(e, SubclassWithNone) else e
|
||||
|
||||
def wrap(e):
|
||||
return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
# no_dispatch is only needed if you use enable_python_mode.
|
||||
# It prevents infinite recursion.
|
||||
with no_dispatch():
|
||||
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
|
||||
if func.__name__ == "add":
|
||||
return None
|
||||
else:
|
||||
return rs
|
||||
|
||||
x = SubclassWithNone(torch.rand(2))
|
||||
# Make sure both run without error
|
||||
self.assertIsInstance(x * 2, SubclassWithNone)
|
||||
self.assertIsNone(x + 2)
|
||||
|
||||
x.requires_grad_()
|
||||
out = x.acos().sum()
|
||||
|
||||
# The backward of acos does add then rsqrt so here we make sure that the
|
||||
# undefined Tensor generated by the user code is nicely handled.
|
||||
# If acos formula changes in the future, this can be replaced by any other
|
||||
# function that does add then something in the backward in a composite way
|
||||
with self.assertRaisesRegex(RuntimeError, "found an undefined Tensor"):
|
||||
out.backward()
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -42,14 +42,16 @@ C10_EXPORT std::vector<at::DeprecatedTypeProperties*> allCUDATypes() {
|
|||
namespace {
|
||||
const Variable & checked_cast_variable(const Tensor & t, const char * name, int pos) {
|
||||
if (!t.defined()) {
|
||||
AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'");
|
||||
AT_ERROR("Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
|
||||
"for argument #", pos, " '", name, "'");
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
Variable & checked_cast_variable(Tensor & t, const char * name, int pos) {
|
||||
if (!t.defined()) {
|
||||
AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor for argument #", pos, " '", name, "'");
|
||||
AT_ERROR("Expected a proper Tensor but got None (or an undefined Tensor in C++) ",
|
||||
"for argument #", pos, " '", name, "'");
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,10 @@ void clear_registered_instances(void* ptr) {
|
|||
IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
|
||||
switch (type->kind()) {
|
||||
case TypeKind::TensorType: {
|
||||
if (obj.ptr() == Py_None) {
|
||||
// None gets converted to undefined Tensors
|
||||
return autograd::Variable();
|
||||
}
|
||||
auto var = py::cast<autograd::Variable>(obj);
|
||||
if (var.is_sparse()) {
|
||||
TORCH_WARN_ONCE(
|
||||
|
|
|
|||
Loading…
Reference in a new issue