Fix python tensor hooks behavior on inplace (#92734)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92734
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer 2023-01-20 22:16:54 -05:00 committed by PyTorch MergeBot
parent de69cedf98
commit 97342ae04b
5 changed files with 140 additions and 3 deletions

View file

@ -91,6 +91,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
void trace_gpu_device_synchronization() const override {}
void trace_gpu_stream_synchronization(uintptr_t stream) const override {}
void trace_gpu_event_synchronization(uintptr_t event) const override {}
// No-op vtable entry: a Python interpreter is not attached, so servicing a
// backward-hook reset is impossible — PANIC rather than silently skipping it.
void reset_backward_hooks(const TensorImpl* self) const override {
  PANIC(reset_backward_hooks);
};
};
void PyInterpreter::disarm() noexcept {

View file

@ -183,6 +183,8 @@ struct C10_API PyInterpreterVTable {
virtual void trace_gpu_device_synchronization() const = 0;
virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0;
virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0;
virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
};
struct C10_API PyInterpreter {

View file

@ -1476,6 +1476,116 @@ class TestAutograd(TestCase):
self.assertEqual(view.grad, view2.grad)
self.assertEqual(view.grad, torch.tensor([1.]))
def test_tensor_hooks_inplace(self):
# Check that the second hook gets registered to the new version of tensor
count1 = [0]
count2 = [0]
def fn1(grad):
count1[0] += 1
# x2 from mul, x2 from fn2
self.assertEqual(grad, torch.tensor([4.]))
return grad * 2
def fn2(grad):
count2[0] += 1
self.assertEqual(grad, torch.tensor([1.]))
return grad * 2
a = torch.tensor([1.], requires_grad=True)
b = a.clone()
b.register_hook(fn1)
b.mul_(2)
b.register_hook(fn2)
b.sum().backward()
self.assertEqual(count1[0], 1)
self.assertEqual(count2[0], 1)
self.assertEqual(a.grad, torch.tensor([8.]))
count3 = [0]
def fn3(grad):
count3[0] += 1
self.assertEqual(grad, torch.tensor([4.]))
return grad * 2
a = torch.tensor([1.], requires_grad=True)
b = a.clone()
b.register_hook(fn3)
# Inplace multiple times is OK
b.mul_(2)
b.mul_(2)
b.sum().backward()
self.assertEqual(count1[0], 1)
self.assertEqual(a.grad, torch.tensor([8.]))
def test_tensor_hooks_inplace_multiple_outputs(self):
class DoubleMul(Function):
@staticmethod
def forward(ctx, x):
return x * 2, x * 3
@staticmethod
def backward(ctx, g1, g2):
return g1 * 2 + g2 * 3
var_mean = partial(torch.var_mean, dim=0)
for fn in (DoubleMul.apply, var_mean):
counts = [0, 0, 0]
def fn0(grad):
counts[0] += 1
self.assertEqual(grad, torch.ones_like(out1) * 2)
def fn1(grad):
counts[1] += 1
self.assertEqual(grad, torch.ones_like(out1) * 3)
def fn2(grad):
counts[2] += 1
self.assertEqual(grad, torch.ones_like(out1))
b = torch.rand(3, 3, requires_grad=True)
out1, out2 = fn(b)
out1.register_hook(fn0)
out2.register_hook(fn1)
# node refers to two hook dicts
# out1 no longer no longer points to its old hook dict
out1.mul_(2)
# fn2 is registered to out1's new hook dict
out1.register_hook(fn2)
(out1 + out2 * 3).sum().backward()
self.assertEqual(counts, [1, 1, 1])
def test_tensor_hooks_inplace_over_view(self):
# There might be a better UX here, but this is the way it is now
count = [0]
def fn0(grad):
self.fail()
def fn1(grad):
self.fail()
def fn2(grad):
count[0] += 1
self.assertEqual(grad, torch.tensor([1.]))
base = torch.tensor([1.], requires_grad=True).clone()
view = base[:]
view2 = base[:]
view.register_hook(fn0)
view2.register_hook(fn1)
view.mul_(2)
# We need to explicitly trigger an update to view to update its grad_fn
view2.grad_fn
view2.register_hook(fn2)
(view + view2).sum().backward()
# The hooks originally registered to view are not fired, one must explicitly
# trigger an update to the view's grad_fn, and then register a new hook
self.assertEqual(count[0], 1)
def test_retain_grad_cycle(self):
x = torch.ones(5, 5, requires_grad=True)

View file

@ -277,6 +277,8 @@ struct ConcretePyInterpreterVTable final
CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event);
}
void reset_backward_hooks(const TensorImpl* self) const override;
static ConcretePyInterpreterVTable* instance() {
static ConcretePyInterpreterVTable s;
return &s;
@ -2815,4 +2817,17 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
END_HANDLE_TH_ERRORS_PYBIND
}
// Resets the Python-level backward hooks on `self` by setting the wrapped
// tensor's `_backward_hooks` attribute to Py_None. Presumably invoked when an
// in-place op gives the tensor a new grad_fn, so stale hooks do not carry over
// to the new version — confirm against the caller in variable.cpp.
void ConcretePyInterpreterVTable::reset_backward_hooks(
    const c10::TensorImpl* self) const {
  // Touching Python state: take the GIL first.
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  // Rebuild an owning Tensor handle from the raw TensorImpl*; the const_cast
  // is needed because the vtable passes a const pointer while Tensor wants a
  // mutable impl. Assumes the caller keeps `self` alive for the duration.
  Tensor self_t = Tensor(
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
          unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  // reinterpret_steal: THPVariable_Wrap returns a new reference we now own.
  auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
  // NOTE(review): the int return of PyObject_SetAttrString is ignored; on
  // failure a Python error would be left set — confirm this cannot fail here
  // or route it through the error macros below.
  PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None);
  END_HANDLE_TH_ERRORS_PYBIND
}
} // anonymous namespace

View file

@ -156,7 +156,7 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
return get_autograd_meta(self);
}
void update_cpp_hooks_on_new_gradfn(
void update_tensor_hooks_on_new_gradfn(
const at::TensorBase& self,
const std::shared_ptr<torch::autograd::Node>& new_fn) {
// This function is called whenever the grad_fn of the tensor is
@ -173,6 +173,11 @@ void update_cpp_hooks_on_new_gradfn(
// old grad_fn so hooks registered to the older version of the tensor
// will continue to be active.
meta->cpp_hooks_list_ = nullptr;
const c10::impl::PyInterpreter* interp =
self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter();
if (interp) {
(*interp)->reset_backward_hooks(self.unsafeGetTensorImpl());
}
// (2) If there is a retains_grad hook registered, move that from the
// old cpp_hooks_list_ to the new one
if (self.retains_grad()) {
@ -215,7 +220,8 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
set_gradient_edge(self, std::move(gradient_edge));
// Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
torch::autograd::impl::update_cpp_hooks_on_new_gradfn(self, self.grad_fn());
torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
self, self.grad_fn());
}
void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
@ -728,7 +734,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
}
diff_view_meta->set_attr_version(current_version);
torch::autograd::impl::update_cpp_hooks_on_new_gradfn(
torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
self, diff_view_meta->grad_fn_);
}
return diff_view_meta->grad_fn_;