Register Saved Tensors hooks (#60663)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60663

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D29466223

fbshipit-source-id: 65dc3a935c18a0e6b93a37e24543c696e6ae0321
Authored by Victor Quach on 2021-07-15 08:07:56 -07:00; committed by Facebook GitHub Bot
parent 94965212e5
commit ee5a97de11
8 changed files with 88 additions and 7 deletions

File: test/test_autograd.py

@@ -4808,7 +4808,13 @@ for shape in [(1,), ()]:
             with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
                 saved[1].register_hooks(lambda x: x, lambda x: x)
+            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+                saved[0].register_hooks(lambda x: x)
+            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+                saved[0].register_hooks(1, 1)
+            saved[0].register_hooks(lambda x: x, lambda x: x)
+            with self.assertRaisesRegex(RuntimeError, "already been set"):
+                saved[0].register_hooks(lambda x: x, lambda x: x)
             y.sum().backward()
             # Using a reference to the SavedTensor object after the

File: tools/build_variables.bzl

@@ -644,6 +644,7 @@ libtorch_python_core_sources = [
     "torch/csrc/autograd/functions/init.cpp",
     "torch/csrc/autograd/init.cpp",
     "torch/csrc/autograd/python_anomaly_mode.cpp",
+    "torch/csrc/autograd/python_saved_variable_hooks.cpp",
     "torch/csrc/autograd/python_cpp_function.cpp",
     "torch/csrc/autograd/python_engine.cpp",
     "torch/csrc/autograd/python_function.cpp",

File: torch/csrc/autograd/init.cpp

@@ -11,6 +11,7 @@
 #include <torch/csrc/autograd/python_function.h>
 #include <torch/csrc/autograd/function.h>
 #include <torch/csrc/autograd/saved_variable.h>
+#include <torch/csrc/autograd/python_saved_variable_hooks.h>
 #include <torch/csrc/autograd/utils/wrap_outputs.h>
 #include <torch/csrc/autograd/utils/python_arg_parsing.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>
@@ -269,8 +270,9 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
       TORCH_CHECK(false, "Trying to create a SavedTensor object from Python is forbidden.");
     }))
     .def("register_hooks", [](torch::autograd::SavedVariable &s, py::function &pack_hook, py::function &unpack_hook) {
-      s.register_hooks();
-    });
+      // Because we use a py::object, pybind will increment the refcount of the hook functions for us
+      s.register_hooks(std::make_unique<torch::autograd::PySavedVariableHooks>(pack_hook, unpack_hook));
+    });
 
   Py_RETURN_TRUE;
 }
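The comment in the new binding refers to pybind11's reference counting: the `py::function` arguments arrive owning strong references, and `PySavedVariableHooks` takes those references over by calling `release()` in its constructor (see the new files below). A minimal standalone sketch of the same ownership pattern, using a hypothetical `CallbackHolder` class that is not part of this commit:

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical holder mirroring the PySavedVariableHooks ownership scheme:
// it steals the py::function's strong reference and decrefs it itself.
struct CallbackHolder {
  explicit CallbackHolder(py::function& fn)
      : fn_(fn.release().ptr()) {}  // steal the reference; we decref ourselves

  ~CallbackHolder() {
    // If the interpreter has already shut down, deliberately leak instead.
    if (Py_IsInitialized()) {
      py::gil_scoped_acquire gil;
      Py_XDECREF(fn_);
    }
  }

 private:
  PyObject* fn_;  // raw pointer, manually refcounted
};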

File: torch/csrc/autograd/python_saved_variable_hooks.cpp (new)

@@ -0,0 +1,28 @@
+#include <torch/csrc/autograd/python_saved_variable_hooks.h>
+
+namespace py = pybind11;
+
+namespace torch { namespace autograd {
+
+PySavedVariableHooks::PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook) :
+    // steals the reference (we will decref ourselves)
+    pack_hook_(pack_hook.release().ptr()),
+    unpack_hook_(unpack_hook.release().ptr()) {}
+
+// NOLINTNEXTLINE(clang-diagnostic-unused-parameter)
+void PySavedVariableHooks::call_pack_hook(at::Tensor &tensor) {
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "Hooks are not implemented yet");
+}
+
+at::Tensor PySavedVariableHooks::call_unpack_hook() {
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "Hooks are not implemented yet");
+}
+
+PySavedVariableHooks::~PySavedVariableHooks() {
+  // If python is already dead, leak the wrapped python objects
+  if (Py_IsInitialized()) {
+    py::gil_scoped_acquire gil;
+    Py_XDECREF(pack_hook_);
+    Py_XDECREF(unpack_hook_);
+  }
+}
+
+}}
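Both hook entry points are deliberately stubbed in this commit; only registration works. For orientation, a sketch of what later bodies could plausibly look like, assuming a hypothetical `data_` member on `PySavedVariableHooks` to hold the packed Python object, plus the `THPVariable_Wrap`/`THPVariable_Unpack` helpers from `python_variable.h`. None of this is part of the commit:

// Hypothetical future body (not in this commit): wrap the tensor for
// Python, call the user's pack hook, and keep a strong reference to
// whatever it returns in a (hypothetical) data_ member.
void PySavedVariableHooks::call_pack_hook(at::Tensor &tensor) {
  py::gil_scoped_acquire gil;
  auto wrapped = py::reinterpret_steal<py::object>(THPVariable_Wrap(tensor));
  auto pack = py::reinterpret_borrow<py::function>(pack_hook_);
  data_ = pack(wrapped).release().ptr();
}

// Hypothetical future body: feed the packed object back through the
// user's unpack hook and convert the result to a tensor.
at::Tensor PySavedVariableHooks::call_unpack_hook() {
  py::gil_scoped_acquire gil;
  auto unpack = py::reinterpret_borrow<py::function>(unpack_hook_);
  auto obj = unpack(py::reinterpret_borrow<py::object>(data_));
  return THPVariable_Unpack(obj.ptr());
}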

File: torch/csrc/autograd/python_saved_variable_hooks.h (new)

@@ -0,0 +1,25 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+#include <torch/csrc/autograd/python_variable.h>
+#include <torch/csrc/autograd/saved_variable_hooks.h>
+#include <torch/csrc/python_headers.h>
+#include <torch/csrc/THP_export.h>
+#include <ATen/ATen.h>
+
+namespace py = pybind11;
+
+namespace torch { namespace autograd {
+
+struct PySavedVariableHooks : public SavedVariableHooks {
+  PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook);
+  void call_pack_hook(at::Tensor &tensor) override;
+  at::Tensor call_unpack_hook() override;
+  ~PySavedVariableHooks() override;
+
+ private:
+  PyObject* pack_hook_;
+  PyObject* unpack_hook_;
+};
+
+}}

File: torch/csrc/autograd/saved_variable.cpp

@@ -153,7 +153,7 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
   return var;
 }
 
-void SavedVariable::register_hooks() {
+void SavedVariable::register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks) {
   if (!data_.defined()) {
     if (!was_default_constructed_) {
       TORCH_CHECK(false,
@@ -167,14 +167,18 @@ void SavedVariable::register_hooks() {
         "Calling register_hooks on a saved tensor with value None is forbidden");
     }
   }
+  TORCH_CHECK(!hooks_,
+      "Calling register_hooks on a saved tensor whose hooks have already been set. "
+      "Hint: only one pair of hooks is allowed at a time.");
+  hooks_ = std::move(hooks);
 }
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 const char* ERR_BACKWARD_TWICE =
     "Trying to backward through the graph a second time (or directly access saved "
-    "variables after they have already been freed). Saved intermediate values "
+    "tensors after they have already been freed). Saved intermediate values "
     "of the graph are freed when you call .backward() or autograd.grad(). Specify "
    "retain_graph=True if you need to backward through the graph a second time or "
-    "if you need to access saved variables after calling backward.";
+    "if you need to access saved tensors after calling backward.";
 }} // namespace torch::autograd
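At this point `hooks_` is stored but never consulted, so registering hooks changes nothing about packing or unpacking yet. A plausible follow-up, sketched here under assumptions rather than taken from this commit, would have `unpack()` prefer the hooks when present; the `unpack_from_data` helper below is hypothetical, standing in for the existing `data_`-based body:

// Hypothetical dispatch (not in this commit): prefer registered hooks.
Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
  if (hooks_) {
    // The user's unpack hook rebuilds the tensor from whatever the pack
    // hook produced at save time (a CPU copy, a file handle, etc.).
    return hooks_->call_unpack_hook();
  }
  // Otherwise fall back to the pre-existing path that reconstructs the
  // Variable from data_, saved_version_, version_counter_, and so on.
  return unpack_from_data(std::move(saved_for));  // hypothetical helper
}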

File: torch/csrc/autograd/saved_variable.h

@@ -2,6 +2,7 @@
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/autograd/forward_grad.h>
+#include <torch/csrc/autograd/saved_variable_hooks.h>
 
 #include <ATen/ATen.h>
@@ -37,8 +38,7 @@ class TORCH_API SavedVariable {
   /// circular reference.
   Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;
 
-  // Temporarily a no op
-  void register_hooks();
+  void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);
 
   void reset_data() {
     return data_.reset();
@@ -71,6 +71,8 @@ class TORCH_API SavedVariable {
   std::weak_ptr<Node> weak_grad_fn_;
   c10::VariableVersion version_counter_;
 
+  std::unique_ptr<SavedVariableHooks> hooks_;
+
   uint32_t saved_version_ = 0;
   uint32_t output_nr_ = 0;
   bool was_default_constructed_ = true;

File: torch/csrc/autograd/saved_variable_hooks.h (new)

@@ -0,0 +1,13 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace torch { namespace autograd {
+
+struct TORCH_API SavedVariableHooks {
+  virtual void call_pack_hook(at::Tensor &tensor) = 0;
+  virtual at::Tensor call_unpack_hook() = 0;
+  virtual ~SavedVariableHooks() = default;
+};
+
+}}
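The interface makes the contract explicit: `call_pack_hook` returns void, so any packed payload must live on the hooks object itself, and `call_unpack_hook` must reproduce the tensor from that stored state alone. To make this concrete, a minimal sketch of a hypothetical implementation (no such class exists in this commit) that offloads the saved tensor to CPU at pack time and moves it back on unpack:

#include <torch/csrc/autograd/saved_variable_hooks.h>

namespace torch { namespace autograd {

// Hypothetical example subclass: trade GPU memory for transfer time by
// keeping only a CPU copy of the saved tensor between pack and unpack.
struct CpuOffloadHooks : public SavedVariableHooks {
  void call_pack_hook(at::Tensor &tensor) override {
    device_ = tensor.device();      // remember where the tensor lived
    packed_ = tensor.to(at::kCPU);  // keep only the CPU copy alive
  }

  at::Tensor call_unpack_hook() override {
    return packed_.to(device_);     // restore to the original device
  }

 private:
  at::Tensor packed_;
  at::Device device_{at::kCPU};
};

}}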