mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Register Saved Tensors hooks (#60663)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60663 Test Plan: Imported from OSS Reviewed By: soulitzer Differential Revision: D29466223 fbshipit-source-id: 65dc3a935c18a0e6b93a37e24543c696e6ae0321
This commit is contained in:
parent
94965212e5
commit
ee5a97de11
8 changed files with 88 additions and 7 deletions
|
|
@ -4808,7 +4808,13 @@ for shape in [(1,), ()]:
|
|||
with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
|
||||
saved[1].register_hooks(lambda x: x, lambda x: x)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
|
||||
saved[0].register_hooks(lambda x: x)
|
||||
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
|
||||
saved[0].register_hooks(1, 1)
|
||||
saved[0].register_hooks(lambda x: x, lambda x: x)
|
||||
with self.assertRaisesRegex(RuntimeError, "already been set"):
|
||||
saved[0].register_hooks(lambda x: x, lambda x: x)
|
||||
y.sum().backward()
|
||||
|
||||
# Using a reference to the SavedTensor object after the
|
||||
|
|
|
|||
|
|
@ -644,6 +644,7 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/autograd/functions/init.cpp",
|
||||
"torch/csrc/autograd/init.cpp",
|
||||
"torch/csrc/autograd/python_anomaly_mode.cpp",
|
||||
"torch/csrc/autograd/python_saved_variable_hooks.cpp",
|
||||
"torch/csrc/autograd/python_cpp_function.cpp",
|
||||
"torch/csrc/autograd/python_engine.cpp",
|
||||
"torch/csrc/autograd/python_function.cpp",
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
#include <torch/csrc/autograd/python_function.h>
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
#include <torch/csrc/autograd/saved_variable.h>
|
||||
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
|
||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
||||
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
|
||||
#include <torch/csrc/utils/pycfunction_helpers.h>
|
||||
|
|
@ -269,8 +270,9 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
|
|||
TORCH_CHECK(false, "Trying to create a SavedTensor object from Python is forbidden.");
|
||||
}))
|
||||
.def("register_hooks", [](torch::autograd::SavedVariable &s, py::function &pack_hook, py::function &unpack_hook) {
|
||||
s.register_hooks();
|
||||
});
|
||||
// Because we use a py::object, pybind will increment the refcount of the hook functions for us
|
||||
s.register_hooks(std::make_unique<torch::autograd::PySavedVariableHooks>(pack_hook, unpack_hook));
|
||||
});
|
||||
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
|
|
|||
28
torch/csrc/autograd/python_saved_variable_hooks.cpp
Normal file
28
torch/csrc/autograd/python_saved_variable_hooks.cpp
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
PySavedVariableHooks::PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook) :
|
||||
// steals the reference (we will decref ourselves)
|
||||
pack_hook_(pack_hook.release().ptr()),
|
||||
unpack_hook_(unpack_hook.release().ptr()) {}
|
||||
|
||||
// NOLINTNEXTLINE(clang-diagnostic-unused-parameter)
|
||||
void PySavedVariableHooks::call_pack_hook(at::Tensor &tensor) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Hooks are not implemented yet");
|
||||
}
|
||||
|
||||
at::Tensor PySavedVariableHooks::call_unpack_hook() {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Hooks are not implemented yet");
|
||||
}
|
||||
|
||||
PySavedVariableHooks::~PySavedVariableHooks() {
|
||||
// If python is already dead, leak the wrapped python objects
|
||||
if (Py_IsInitialized()) {
|
||||
py::gil_scoped_acquire gil;
|
||||
Py_XDECREF(pack_hook_);
|
||||
Py_XDECREF(unpack_hook_);
|
||||
}
|
||||
}
|
||||
}}
|
||||
25
torch/csrc/autograd/python_saved_variable_hooks.h
Normal file
25
torch/csrc/autograd/python_saved_variable_hooks.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/csrc/autograd/python_variable.h>
|
||||
#include <torch/csrc/autograd/saved_variable_hooks.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/THP_export.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct PySavedVariableHooks : public SavedVariableHooks {
|
||||
PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook);
|
||||
void call_pack_hook(at::Tensor &tensor) override;
|
||||
at::Tensor call_unpack_hook() override;
|
||||
~PySavedVariableHooks() override;
|
||||
|
||||
private:
|
||||
PyObject* pack_hook_;
|
||||
PyObject* unpack_hook_;
|
||||
};
|
||||
|
||||
}}
|
||||
|
|
@ -153,7 +153,7 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
|
|||
return var;
|
||||
}
|
||||
|
||||
void SavedVariable::register_hooks() {
|
||||
void SavedVariable::register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks) {
|
||||
if (!data_.defined()) {
|
||||
if (!was_default_constructed_) {
|
||||
TORCH_CHECK(false,
|
||||
|
|
@ -167,14 +167,18 @@ void SavedVariable::register_hooks() {
|
|||
"Calling register_hooks on a saved tensor with value None is forbidden");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(!hooks_,
|
||||
"Calling register_hooks on a saved tensor whose hooks have already been set. "
|
||||
"Hint: only one pair of hooks is allowed at a time.");
|
||||
hooks_ = std::move(hooks);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
const char* ERR_BACKWARD_TWICE =
|
||||
"Trying to backward through the graph a second time (or directly access saved "
|
||||
"variables after they have already been freed). Saved intermediate values "
|
||||
"tensors after they have already been freed). Saved intermediate values "
|
||||
"of the graph are freed when you call .backward() or autograd.grad(). Specify "
|
||||
"retain_graph=True if you need to backward through the graph a second time or "
|
||||
"if you need to access saved variables after calling backward.";
|
||||
"if you need to access saved tensors after calling backward.";
|
||||
|
||||
}} // namespace torch::autograd
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <torch/csrc/autograd/forward_grad.h>
|
||||
#include <torch/csrc/autograd/saved_variable_hooks.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
|
|
@ -37,8 +38,7 @@ class TORCH_API SavedVariable {
|
|||
/// circular reference.
|
||||
Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;
|
||||
|
||||
// Temporarily a no op
|
||||
void register_hooks();
|
||||
void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);
|
||||
|
||||
void reset_data() {
|
||||
return data_.reset();
|
||||
|
|
@ -71,6 +71,8 @@ class TORCH_API SavedVariable {
|
|||
std::weak_ptr<Node> weak_grad_fn_;
|
||||
c10::VariableVersion version_counter_;
|
||||
|
||||
std::unique_ptr<SavedVariableHooks> hooks_;
|
||||
|
||||
uint32_t saved_version_ = 0;
|
||||
uint32_t output_nr_ = 0;
|
||||
bool was_default_constructed_ = true;
|
||||
|
|
|
|||
13
torch/csrc/autograd/saved_variable_hooks.h
Normal file
13
torch/csrc/autograd/saved_variable_hooks.h
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace torch { namespace autograd {
|
||||
|
||||
struct TORCH_API SavedVariableHooks {
|
||||
virtual void call_pack_hook(at::Tensor &tensor) = 0;
|
||||
virtual at::Tensor call_unpack_hook() = 0;
|
||||
virtual ~SavedVariableHooks() = default;
|
||||
};
|
||||
|
||||
}}
|
||||
Loading…
Reference in a new issue