From 97342ae04bc7393acede4f8140bb7dd8fba32155 Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Fri, 20 Jan 2023 22:16:54 -0500
Subject: [PATCH] Fix python tensor hooks behavior on inplace (#92734)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92734
Approved by: https://github.com/albanD
---
 c10/core/impl/PyInterpreter.cpp         |   4 +
 c10/core/impl/PyInterpreter.h           |   2 +
 test/test_autograd.py                   | 110 ++++++++++++++++++++++++
 torch/csrc/autograd/python_variable.cpp |  15 ++++
 torch/csrc/autograd/variable.cpp        |  12 ++-
 5 files changed, 140 insertions(+), 3 deletions(-)

diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp
index 8c29f13f3e5..0e251538e14 100644
--- a/c10/core/impl/PyInterpreter.cpp
+++ b/c10/core/impl/PyInterpreter.cpp
@@ -91,6 +91,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
   void trace_gpu_device_synchronization() const override {}
   void trace_gpu_stream_synchronization(uintptr_t stream) const override {}
   void trace_gpu_event_synchronization(uintptr_t event) const override {}
+
+  void reset_backward_hooks(const TensorImpl* self) const override {
+    PANIC(reset_backward_hooks);
+  };
 };
 
 void PyInterpreter::disarm() noexcept {
diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h
index 3f4a5735138..e329fbca267 100644
--- a/c10/core/impl/PyInterpreter.h
+++ b/c10/core/impl/PyInterpreter.h
@@ -183,6 +183,8 @@ struct C10_API PyInterpreterVTable {
   virtual void trace_gpu_device_synchronization() const = 0;
   virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0;
   virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0;
+
+  virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
 };
 
 struct C10_API PyInterpreter {
diff --git a/test/test_autograd.py b/test/test_autograd.py
index c0ad32ad0b9..b1084306a4b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1476,6 +1476,116 @@ class TestAutograd(TestCase):
         self.assertEqual(view.grad, view2.grad)
         self.assertEqual(view.grad, torch.tensor([1.]))
 
+    def test_tensor_hooks_inplace(self):
+        # Check that the second hook gets registered to the new version of tensor
+        count1 = [0]
+        count2 = [0]
+
+        def fn1(grad):
+            count1[0] += 1
+            # x2 from mul, x2 from fn2
+            self.assertEqual(grad, torch.tensor([4.]))
+            return grad * 2
+
+        def fn2(grad):
+            count2[0] += 1
+            self.assertEqual(grad, torch.tensor([1.]))
+            return grad * 2
+
+        a = torch.tensor([1.], requires_grad=True)
+        b = a.clone()
+        b.register_hook(fn1)
+        b.mul_(2)
+        b.register_hook(fn2)
+        b.sum().backward()
+        self.assertEqual(count1[0], 1)
+        self.assertEqual(count2[0], 1)
+        self.assertEqual(a.grad, torch.tensor([8.]))
+
+        count3 = [0]
+
+        def fn3(grad):
+            count3[0] += 1
+            self.assertEqual(grad, torch.tensor([4.]))
+            return grad * 2
+
+        a = torch.tensor([1.], requires_grad=True)
+        b = a.clone()
+        b.register_hook(fn3)
+        # Inplace multiple times is OK
+        b.mul_(2)
+        b.mul_(2)
+        b.sum().backward()
+        self.assertEqual(count1[0], 1)
+        self.assertEqual(a.grad, torch.tensor([8.]))
+
+    def test_tensor_hooks_inplace_multiple_outputs(self):
+        class DoubleMul(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x * 2, x * 3
+
+            @staticmethod
+            def backward(ctx, g1, g2):
+                return g1 * 2 + g2 * 3
+
+        var_mean = partial(torch.var_mean, dim=0)
+
+        for fn in (DoubleMul.apply, var_mean):
+            counts = [0, 0, 0]
+
+            def fn0(grad):
+                counts[0] += 1
+                self.assertEqual(grad, torch.ones_like(out1) * 2)
+
+            def fn1(grad):
+                counts[1] += 1
+                self.assertEqual(grad, torch.ones_like(out1) * 3)
+
+            def fn2(grad):
+                counts[2] += 1
+                self.assertEqual(grad, torch.ones_like(out1))
+
+            b = torch.rand(3, 3, requires_grad=True)
+            out1, out2 = fn(b)
+            out1.register_hook(fn0)
+            out2.register_hook(fn1)
+            # node refers to two hook dicts
+            # out1 no longer points to its old hook dict
+            out1.mul_(2)
+            # fn2 is registered to out1's new hook dict
+            out1.register_hook(fn2)
+            (out1 + out2 * 3).sum().backward()
+            self.assertEqual(counts, [1, 1, 1])
+
+    def test_tensor_hooks_inplace_over_view(self):
+        # There might be a better UX here, but this is the way it is now
+        count = [0]
+
+        def fn0(grad):
+            self.fail()
+
+        def fn1(grad):
+            self.fail()
+
+        def fn2(grad):
+            count[0] += 1
+            self.assertEqual(grad, torch.tensor([1.]))
+
+        base = torch.tensor([1.], requires_grad=True).clone()
+        view = base[:]
+        view2 = base[:]
+        view.register_hook(fn0)
+        view2.register_hook(fn1)
+        view.mul_(2)
+        # We need to explicitly trigger an update to view to update its grad_fn
+        view2.grad_fn
+        view2.register_hook(fn2)
+        (view + view2).sum().backward()
+        # The hooks originally registered to view are not fired, one must explicitly
+        # trigger an update to the view's grad_fn, and then register a new hook
+        self.assertEqual(count[0], 1)
+
     def test_retain_grad_cycle(self):
         x = torch.ones(5, 5, requires_grad=True)
 
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 973022360fc..a25147cd5b7 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -277,6 +277,8 @@ struct ConcretePyInterpreterVTable final
     CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event);
   }
 
+  void reset_backward_hooks(const TensorImpl* self) const override;
+
   static ConcretePyInterpreterVTable* instance() {
     static ConcretePyInterpreterVTable s;
     return &s;
@@ -2815,4 +2817,17 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
   END_HANDLE_TH_ERRORS_PYBIND
 }
 
+void ConcretePyInterpreterVTable::reset_backward_hooks(
+    const c10::TensorImpl* self) const {
+  pybind11::gil_scoped_acquire gil;
+  at::impl::MaybeSetTLSOnEntryGuard guard;
+  HANDLE_TH_ERRORS
+  Tensor self_t = Tensor(
+      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
+          unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
+  auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
+  PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None);
+  END_HANDLE_TH_ERRORS_PYBIND
+}
+
 } // anonymous namespace
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 8212e174846..18a1e0f85d3 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -156,7 +156,7 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
   return get_autograd_meta(self);
 }
 
-void update_cpp_hooks_on_new_gradfn(
+void update_tensor_hooks_on_new_gradfn(
     const at::TensorBase& self,
     const std::shared_ptr<torch::autograd::Node>& new_fn) {
   // This function is called whenever the grad_fn of the tensor is
@@ -173,6 +173,11 @@ void update_cpp_hooks_on_new_gradfn(
   // old grad_fn so hooks registered to the older version of the tensor
   // will continue to be active.
   meta->cpp_hooks_list_ = nullptr;
+  const c10::impl::PyInterpreter* interp =
+      self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter();
+  if (interp) {
+    (*interp)->reset_backward_hooks(self.unsafeGetTensorImpl());
+  }
   // (2) If there is a retains_grad hook registered, move that from the
   // old cpp_hooks_list_ to the new one
   if (self.retains_grad()) {
@@ -215,7 +220,8 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
   set_gradient_edge(self, std::move(gradient_edge));
 
   // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
-  torch::autograd::impl::update_cpp_hooks_on_new_gradfn(self, self.grad_fn());
+  torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
+      self, self.grad_fn());
 }
 
 void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
@@ -728,7 +734,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
     }
     diff_view_meta->set_attr_version(current_version);
 
-    torch::autograd::impl::update_cpp_hooks_on_new_gradfn(
+    torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
         self, diff_view_meta->grad_fn_);
   }
   return diff_view_meta->grad_fn_;
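
For illustration, a minimal standalone sketch of the behavior this patch establishes, mirroring test_tensor_hooks_inplace above: a hook registered before an in-place op stays attached to the old grad_fn and sees the gradient of the pre-inplace value, while a hook registered afterwards attaches to the rebased grad_fn. The tensor values, factors, and printed results below are illustrative, not part of the patch.

import torch

a = torch.tensor([1.], requires_grad=True)
b = a.clone()

# Registered before mul_: this hook stays with the clone's grad_fn, so it
# sees the gradient flowing into the pre-inplace version of b (tensor([2.])).
b.register_hook(lambda grad: print("pre-inplace hook:", grad))

b.mul_(2)

# Registered after mul_: rebase_history() has reset the Python-side
# _backward_hooks (via reset_backward_hooks), so this hook attaches to the
# new grad_fn and sees the gradient of the post-inplace b (tensor([1.])).
b.register_hook(lambda grad: print("post-inplace hook:", grad))

b.sum().backward()
print(a.grad)  # tensor([2.])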