Fix python tensor hooks behavior on inplace (#92734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92734 Approved by: https://github.com/albanD
parent de69cedf98
commit 97342ae04b

5 changed files with 140 additions and 3 deletions
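In short: hooks registered with Tensor.register_hook before an in-place operation stay attached to the old version of the tensor, while hooks registered afterwards attach to the new version. A condensed sketch of that behavior (not part of the diff; it mirrors the new test_tensor_hooks_inplace test below, and the gradient values assume this toy graph):

import torch

a = torch.tensor([1.], requires_grad=True)
b = a.clone()

# Registered before mul_: stays on the old version of b, so it sees the
# gradient after mul_'s backward (x2) and after the later hook (x2).
b.register_hook(lambda grad: grad * 2)   # receives tensor([4.]) here
b.mul_(2)
# Registered after mul_: attaches to the new version of b.
b.register_hook(lambda grad: grad * 2)   # receives tensor([1.]) here

b.sum().backward()
print(a.grad)  # tensor([8.])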
@@ -91,6 +91,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
   void trace_gpu_device_synchronization() const override {}
   void trace_gpu_stream_synchronization(uintptr_t stream) const override {}
   void trace_gpu_event_synchronization(uintptr_t event) const override {}
+
+  void reset_backward_hooks(const TensorImpl* self) const override {
+    PANIC(reset_backward_hooks);
+  };
 };
 
 void PyInterpreter::disarm() noexcept {
@@ -183,6 +183,8 @@ struct C10_API PyInterpreterVTable {
   virtual void trace_gpu_device_synchronization() const = 0;
   virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0;
   virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0;
+
+  virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
 };
 
 struct C10_API PyInterpreter {
@@ -1476,6 +1476,116 @@ class TestAutograd(TestCase):
         self.assertEqual(view.grad, view2.grad)
         self.assertEqual(view.grad, torch.tensor([1.]))
 
+    def test_tensor_hooks_inplace(self):
+        # Check that the second hook gets registered to the new version of tensor
+        count1 = [0]
+        count2 = [0]
+
+        def fn1(grad):
+            count1[0] += 1
+            # x2 from mul, x2 from fn2
+            self.assertEqual(grad, torch.tensor([4.]))
+            return grad * 2
+
+        def fn2(grad):
+            count2[0] += 1
+            self.assertEqual(grad, torch.tensor([1.]))
+            return grad * 2
+
+        a = torch.tensor([1.], requires_grad=True)
+        b = a.clone()
+        b.register_hook(fn1)
+        b.mul_(2)
+        b.register_hook(fn2)
+        b.sum().backward()
+        self.assertEqual(count1[0], 1)
+        self.assertEqual(count2[0], 1)
+        self.assertEqual(a.grad, torch.tensor([8.]))
+
+        count3 = [0]
+
+        def fn3(grad):
+            count3[0] += 1
+            self.assertEqual(grad, torch.tensor([4.]))
+            return grad * 2
+
+        a = torch.tensor([1.], requires_grad=True)
+        b = a.clone()
+        b.register_hook(fn3)
+        # Inplace multiple times is OK
+        b.mul_(2)
+        b.mul_(2)
+        b.sum().backward()
+        self.assertEqual(count1[0], 1)
+        self.assertEqual(a.grad, torch.tensor([8.]))
+
+    def test_tensor_hooks_inplace_multiple_outputs(self):
+        class DoubleMul(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x * 2, x * 3
+
+            @staticmethod
+            def backward(ctx, g1, g2):
+                return g1 * 2 + g2 * 3
+
+        var_mean = partial(torch.var_mean, dim=0)
+
+        for fn in (DoubleMul.apply, var_mean):
+            counts = [0, 0, 0]
+
+            def fn0(grad):
+                counts[0] += 1
+                self.assertEqual(grad, torch.ones_like(out1) * 2)
+
+            def fn1(grad):
+                counts[1] += 1
+                self.assertEqual(grad, torch.ones_like(out1) * 3)
+
+            def fn2(grad):
+                counts[2] += 1
+                self.assertEqual(grad, torch.ones_like(out1))
+
+            b = torch.rand(3, 3, requires_grad=True)
+            out1, out2 = fn(b)
+            out1.register_hook(fn0)
+            out2.register_hook(fn1)
+            # node refers to two hook dicts
+            # out1 no longer points to its old hook dict
+            out1.mul_(2)
+            # fn2 is registered to out1's new hook dict
+            out1.register_hook(fn2)
+            (out1 + out2 * 3).sum().backward()
+            self.assertEqual(counts, [1, 1, 1])
+
+    def test_tensor_hooks_inplace_over_view(self):
+        # There might be a better UX here, but this is the way it is now
+        count = [0]
+
+        def fn0(grad):
+            self.fail()
+
+        def fn1(grad):
+            self.fail()
+
+        def fn2(grad):
+            count[0] += 1
+            self.assertEqual(grad, torch.tensor([1.]))
+
+        base = torch.tensor([1.], requires_grad=True).clone()
+        view = base[:]
+        view2 = base[:]
+        view.register_hook(fn0)
+        view2.register_hook(fn1)
+        view.mul_(2)
+        # We need to explicitly trigger an update to view2 to update its grad_fn
+        view2.grad_fn
+        view2.register_hook(fn2)
+        (view + view2).sum().backward()
+        # The hooks originally registered to view are not fired; one must explicitly
+        # trigger an update to the view's grad_fn, and then register a new hook
+        self.assertEqual(count[0], 1)
+
     def test_retain_grad_cycle(self):
         x = torch.ones(5, 5, requires_grad=True)
 
@@ -277,6 +277,8 @@ struct ConcretePyInterpreterVTable final
     CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event);
   }
 
+  void reset_backward_hooks(const TensorImpl* self) const override;
+
   static ConcretePyInterpreterVTable* instance() {
     static ConcretePyInterpreterVTable s;
     return &s;
@@ -2815,4 +2817,17 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
   END_HANDLE_TH_ERRORS_PYBIND
 }
 
+void ConcretePyInterpreterVTable::reset_backward_hooks(
+    const c10::TensorImpl* self) const {
+  pybind11::gil_scoped_acquire gil;
+  at::impl::MaybeSetTLSOnEntryGuard guard;
+  HANDLE_TH_ERRORS
+  Tensor self_t = Tensor(
+      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
+          unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
+  auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
+  PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None);
+  END_HANDLE_TH_ERRORS_PYBIND
+}
+
 } // anonymous namespace
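The observable Python-side effect of reset_backward_hooks is that the tensor's internal _backward_hooks dict (the attribute the diff sets to None above) is cleared whenever the grad_fn is rebased, so later register_hook calls start a fresh dict on the new grad_fn. A rough sketch of that effect, shown only for illustration since _backward_hooks is an implementation detail:

import torch

t = torch.tensor([1.0], requires_grad=True).clone()  # non-leaf, so in-place ops are allowed
t.register_hook(lambda g: g)
print(type(t._backward_hooks))  # OrderedDict holding the hook (internal detail)

t.mul_(2)  # rebase_history() -> update_tensor_hooks_on_new_gradfn() -> reset_backward_hooks()
print(t._backward_hooks)        # expected: None after this change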
@@ -156,7 +156,7 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
   return get_autograd_meta(self);
 }
 
-void update_cpp_hooks_on_new_gradfn(
+void update_tensor_hooks_on_new_gradfn(
     const at::TensorBase& self,
     const std::shared_ptr<torch::autograd::Node>& new_fn) {
   // This function is called whenever the grad_fn of the tensor is
@@ -173,6 +173,11 @@ void update_cpp_hooks_on_new_gradfn(
   // old grad_fn so hooks registered to the older version of the tensor
   // will continue to be active.
   meta->cpp_hooks_list_ = nullptr;
+  const c10::impl::PyInterpreter* interp =
+      self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter();
+  if (interp) {
+    (*interp)->reset_backward_hooks(self.unsafeGetTensorImpl());
+  }
   // (2) If there is a retains_grad hook registered, move that from the
   // old cpp_hooks_list_ to the new one
   if (self.retains_grad()) {
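Item (2) in the comment above is what keeps Tensor.retain_grad working across in-place ops: the retains_grad hook is moved onto the new version of the tensor rather than being left behind on the old grad_fn. A small sketch of that behavior (not part of this diff; values assume this toy graph):

import torch

a = torch.tensor([1.0], requires_grad=True)
b = a.clone()
b.retain_grad()   # installs a cpp hook that stores b's gradient
b.mul_(2)         # b gets a new grad_fn; the retains_grad hook follows it
b.sum().backward()
print(b.grad)     # tensor([1.]) -- gradient w.r.t. the current (post-mul_) version of b
print(a.grad)     # tensor([2.]) -- mul_'s backward still scales the gradient flowing to a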
@@ -215,7 +220,8 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
 
   set_gradient_edge(self, std::move(gradient_edge));
   // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
-  torch::autograd::impl::update_cpp_hooks_on_new_gradfn(self, self.grad_fn());
+  torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
+      self, self.grad_fn());
 }
 
 void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
@@ -728,7 +734,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
       }
       diff_view_meta->set_attr_version(current_version);
 
-      torch::autograd::impl::update_cpp_hooks_on_new_gradfn(
+      torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
          self, diff_view_meta->grad_fn_);
     }
     return diff_view_meta->grad_fn_;