Support non-tensor inputs and outputs for checkpointed functions. (#52422)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52422

As mentioned in https://github.com/pytorch/pytorch/issues/52415, `torch.utils.checkpoint` doesn't support checkpointing functions that have non-tensor inputs and outputs. This PR resolves the issue by ensuring the autograd machinery ignores non-tensor inputs and outputs and processes the tensors accordingly.

ghstack-source-id: 124406867

Test Plan:
1) unit tests
2) waitforbuildbot

Reviewed By: albanD

Differential Revision: D26507228

fbshipit-source-id: 0a5a1591570814176185362e83ad18dabd9c84b0
This commit is contained in:
parent: 9d9986fd10
commit: 4fa47e5e7d

7 changed files with 293 additions and 49 deletions
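
For illustration, a minimal sketch of the usage this change enables (the function and names below are hypothetical, not part of the PR):

import torch
from torch.utils.checkpoint import checkpoint

# Post-PR, a checkpointed function may mix tensor and non-tensor
# arguments and results; only the tensors participate in autograd.
def scale_and_tag(x, factor, tag):
    return x * factor, tag, factor

x = torch.rand(4, requires_grad=True)
out, tag, factor = checkpoint(scale_and_tag, x, 2.0, "stage1")
out.sum().backward()  # gradient flows to x; "stage1" and 2.0 pass through
assert x.grad is not None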
test/test_autograd.py
@@ -3,6 +3,7 @@ import gc
 import sys
 import io
 import math
+import random
 import tempfile
 import time
 import threading
@@ -292,6 +293,95 @@ class TestAutograd(TestCase):
         with self.assertRaisesRegex(Exception, "Simulate error on backward pass"):
             t3.sum().backward()
 
+    def test_custom_function_non_tensor_inputs_outputs(self):
+        class MyFunction(Function):
+
+            @staticmethod
+            def forward(ctx, t1, t2, scale, t3):
+                t4 = t1 + t2 * t3
+                t5 = t1 * t2 + t3
+                t4 *= scale
+                t5 *= scale
+
+                # Save scale
+                ctx.scale = scale
+                ctx.save_for_backward(t1, t2, t3)
+                return scale, t4, None, True, t5, "bar", t1
+
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, *grads):
+                # Verify grads
+                self.assertEqual(7, len(grads))
+                self.assertIsNone(grads[0])
+                self.assertIsNone(grads[2])
+                self.assertIsNone(grads[3])
+                self.assertIsNone(grads[5])
+
+                scale = ctx.scale
+                var1, var2, var3 = ctx.saved_tensors
+                return (
+                    grads[1] * scale + grads[4] * var2 * scale + grads[6],
+                    grads[1] * var3 * scale + grads[4] * var1 * scale,
+                    None,
+                    grads[1] * var2 * scale + grads[4] * scale,
+                )
+
+        t1 = torch.rand(10, requires_grad=True)
+        t2 = torch.rand(10, requires_grad=True)
+        t3 = torch.rand(10)
+        scale = random.randint(0, 10)
+        res = MyFunction.apply(t1, t2, scale, t3)
+        self.assertEqual(scale, res[0])
+        self.assertEqual((t1 + t2 * t3) * scale, res[1])
+        self.assertEqual(None, res[2])
+        self.assertEqual(True, res[3])
+        self.assertEqual((t1 * t2 + t3) * scale, res[4])
+        self.assertEqual("bar", res[5])
+        self.assertEqual(t1, res[6])
+
+        # Validate running backward.
+        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
+        self.assertIsNotNone(t1.grad)
+        self.assertIsNotNone(t2.grad)
+        self.assertIsNone(t3.grad)
+
+        # Test gradcheck
+        def foo(t1, t2, t3):
+            res = MyFunction.apply(t1, t2, scale, t3)
+            return res[1], res[4], res[6]
+
+        gradcheck(foo, (t1, t2, t3))
+
+    def test_custom_function_no_tensors(self):
+        class MyFunction(Function):
+
+            @staticmethod
+            def forward(ctx, t1, t2, scale, t3):
+                t4 = t1 + t2 * t3
+                t5 = t1 * t2 + t3
+                t4 *= scale
+                t5 *= scale
+                return scale, t4, None, True, t5, "bar", t1
+
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, *args):
+                return (args[0], args[1], None, args[2])
+
+        t1 = random.random()
+        t2 = random.random()
+        t3 = random.random()
+        scale = random.randint(0, 10)
+        res = MyFunction.apply(t1, t2, scale, t3)
+        self.assertEqual(scale, res[0])
+        self.assertEqual((t1 + t2 * t3) * scale, res[1])
+        self.assertEqual(None, res[2])
+        self.assertEqual(True, res[3])
+        self.assertEqual((t1 * t2 + t3) * scale, res[4])
+        self.assertEqual("bar", res[5])
+        self.assertEqual(t1, res[6])
+
     def test_invalid_gradients(self):
         class MyFunction(Function):
             @staticmethod
test/test_utils.py
@@ -269,6 +269,65 @@ class TestCheckpoint(TestCase):
         out = checkpoint(run_fn, input_var, None)
         out.sum().backward()
 
+    def test_checkpoint_non_tensor_inputs_outputs(self):
+        def foo(t1, t2, scale, t3):
+            t4 = t1 + t2 * t3
+            t5 = t1 * t2 + t3
+            t4 *= scale
+            t5 *= scale
+            return scale, t4, None, True, t5, "bar", t1
+
+        t1 = torch.rand(10, requires_grad=True)
+        t2 = torch.rand(10, requires_grad=True)
+        t3 = torch.rand(10)
+        scale = random.randint(0, 10)
+        res = checkpoint(foo, t1, t2, scale, t3)
+        self.assertEqual(scale, res[0])
+        self.assertEqual((t1 + t2 * t3) * scale, res[1])
+        self.assertEqual(None, res[2])
+        self.assertEqual(True, res[3])
+        self.assertEqual((t1 * t2 + t3) * scale, res[4])
+        self.assertEqual("bar", res[5])
+        self.assertEqual(t1, res[6])
+
+        # Validate running backward.
+        res[1].sum().backward(retain_graph=True)
+        res[4].sum().backward(retain_graph=True)
+        res[6].sum().backward()
+        with self.assertRaisesRegex(RuntimeError, "Trying to backward through the graph a second time"):
+            res[6].sum().backward()
+        t1_grad = t1.grad
+        t2_grad = t2.grad
+
+        # Reset grads, run without checkpoint and validate we receive same grads.
+        t1.grad = None
+        t2.grad = None
+        res = foo(t1, t2, scale, t3)
+        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
+        self.assertEqual(t1.grad, t1_grad)
+        self.assertEqual(t2.grad, t2_grad)
+
+    def test_checkpoint_no_tensors(self):
+        def foo(t1, t2, scale, t3):
+            t4 = t1 + t2 * t3
+            t5 = t1 * t2 + t3
+            t4 *= scale
+            t5 *= scale
+            return scale, t4, None, True, t5, "bar", t1
+
+        t1 = random.random()
+        t2 = random.random()
+        t3 = random.random()
+        scale = random.randint(0, 10)
+        res = checkpoint(foo, t1, t2, scale, t3)
+        self.assertEqual(scale, res[0])
+        self.assertEqual((t1 + t2 * t3) * scale, res[1])
+        self.assertEqual(None, res[2])
+        self.assertEqual(True, res[3])
+        self.assertEqual((t1 * t2 + t3) * scale, res[4])
+        self.assertEqual("bar", res[5])
+        self.assertEqual(t1, res[6])
+
     def test_checkpoint_partial_grad(self):
         def run_fn(tensor1, tensor2):
             # tensor 2 is used for other application logic
torch/autograd/function.py
@@ -173,8 +173,8 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
         It must accept a context ctx as the first argument, followed by any
         number of arguments (tensors or other types).
 
-        The context can be used to store tensors that can be then retrieved
-        during the backward pass.
+        The context can be used to store arbitrary data that can be then
+        retrieved during the backward pass.
         """
         raise NotImplementedError("You must implement the forward function for custom"
                                   " autograd.Function.")
@@ -186,10 +186,13 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
         This function is to be overridden by all subclasses.
 
         It must accept a context :attr:`ctx` as the first argument, followed by
-        as many outputs did :func:`forward` return, and it should return as many
-        tensors, as there were inputs to :func:`forward`. Each argument is the
-        gradient w.r.t the given output, and each returned value should be the
-        gradient w.r.t. the corresponding input.
+        as many outputs as the :func:`forward` returned (None will be passed in
+        for non tensor outputs of the forward function),
+        and it should return as many tensors, as there were inputs to
+        :func:`forward`. Each argument is the gradient w.r.t the given output,
+        and each returned value should be the gradient w.r.t. the
+        corresponding input. If an input is not a Tensor or is a Tensor not
+        requiring grads, you can just pass None as a gradient for that input.
 
         The context can be used to retrieve tensors saved during the forward
         pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
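
To make the revised contract concrete, a minimal sketch of a custom Function with a non-tensor input and a non-tensor output (the class and names are illustrative, assuming the post-PR behavior):

import torch
from torch.autograd import Function

class ScaleByConstant(Function):
    @staticmethod
    def forward(ctx, x, k):  # k is a plain float, not a Tensor
        ctx.k = k
        return x * k, "ok"   # the second output is a non-tensor

    @staticmethod
    def backward(ctx, grad_out, grad_status):
        # grad_status is None: non-tensor outputs receive None gradients.
        # Return None as the gradient for the non-tensor input k.
        return grad_out * ctx.k, None

x = torch.rand(3, requires_grad=True)
y, status = ScaleByConstant.apply(x, 2.5)
y.sum().backward()  # x.grad is 2.5 everywhere; status == "ok"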
torch/csrc/autograd/custom_function.cpp
@@ -9,18 +9,26 @@ VariableInfo::VariableInfo(const Variable& var)
   , device(var.device())
   , scalar_type(var.scalar_type())
   , size(var.sizes().vec())
-  , requires_grad(var.requires_grad()) {
+  , requires_grad(var.requires_grad())
+  , is_empty(false) {
 }
 
+VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {}
+
 Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
-  return at::zeros(size,
-                   at::TensorOptions(scalar_type).device(device).layout(layout));
+  if (is_empty) {
+    // Return undefined tensor.
+    return at::Tensor();
+  } else {
+    return at::zeros(
+        size, at::TensorOptions(scalar_type).device(device).layout(layout));
+  }
 }
 
-variable_list _wrap_outputs(const variable_list &input_vars,
+std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_vars,
   const std::unordered_set<at::TensorImpl*> &non_differentiable,
   const std::unordered_set<at::TensorImpl*> &dirty_inputs,
-  const at::ArrayRef<Variable> raw_outputs,
+  const at::ArrayRef<c10::optional<Variable>> raw_outputs,
   const std::shared_ptr<Node> &cdata) {
 
   std::unordered_set<at::TensorImpl*> inputs;
@@ -96,16 +104,26 @@ variable_list _wrap_outputs(const variable_list &input_vars,
     }
   };
 
-  std::vector<torch::autograd::Variable> outputs;
+  std::vector<c10::optional<Variable>> outputs;
   std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
   outputs.reserve(num_outputs);
   int num_diff_outputs = 0;
 
   for (auto i = 0; i < num_outputs; ++i) {
-    Variable var = raw_outputs[i];
+    // For outputs that are not tensors, put a placeholder undefined input.
+    if (!raw_outputs[i].has_value()) {
+      if (cdata) {
+        auto output_nr = cdata->add_input_metadata(Node::undefined_input());
+        AT_ASSERT(i == (int)output_nr);
+      }
+      outputs.emplace_back();
+      continue;
+    }
 
-    auto out_tensor_impl = raw_outputs[i].unsafeGetTensorImpl();
+    Variable var = raw_outputs[i].value();
+
+    auto out_tensor_impl = var.unsafeGetTensorImpl();
     bool is_input = inputs.count(out_tensor_impl) > 0;
     bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
     bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0
@@ -139,9 +157,11 @@ variable_list _wrap_outputs(const variable_list &input_vars,
   // See NOTE [ View + Inplace detection ] for more details
   if (num_diff_outputs > 1) {
     for (auto& var: outputs) {
-      auto diff_view_meta = impl::get_view_autograd_meta(var);
-      if (diff_view_meta) {
-        diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
+      if (var.has_value()) {
+        auto diff_view_meta = impl::get_view_autograd_meta(var.value());
+        if (diff_view_meta) {
+          diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
+        }
       }
     }
   }
torch/csrc/autograd/custom_function.h
@@ -8,11 +8,11 @@
 
 namespace torch { namespace autograd {
 
-TORCH_API variable_list _wrap_outputs(
+TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
   const variable_list &input_vars,
   const std::unordered_set<at::TensorImpl*> &non_differentiable,
   const std::unordered_set<at::TensorImpl*> &dirty_inputs,
-  const at::ArrayRef<Variable> raw_outputs,
+  const at::ArrayRef<c10::optional<Variable>> raw_outputs,
   const std::shared_ptr<Node> &cdata);
 
 TORCH_API void check_variable_result(const Variable& original,
@@ -137,6 +137,7 @@ private:
 };
 
 struct TORCH_API VariableInfo {
+  explicit VariableInfo();
   explicit VariableInfo(const Variable& var);
 
   Variable zeros(at::OptionalDeviceGuard& device_guard) const;
@@ -146,6 +147,7 @@ struct TORCH_API VariableInfo {
   at::ScalarType scalar_type = at::kFloat;
   std::vector<int64_t> size;
   bool requires_grad;
+  bool is_empty;
 };
 
 // CppNode<T> is the Node in the autograd graph that represents the user defined
@@ -194,10 +196,30 @@ inline void extract_vars(std::vector<bool> &is_var, variable_list& list, Args&&.
 }
 
 template <typename T>
-typename std::enable_if<std::is_same<T, variable_list>::value, T&>::type to_output_type(variable_list& output_list) { return output_list; }
+typename std::enable_if<std::is_same<T, variable_list>::value, T>::type to_output_type(
+    std::vector<c10::optional<Variable>>& output_list) {
+  variable_list result;
+  std::transform(output_list.begin(), output_list.end(), std::back_inserter(result),
+    [](const c10::optional<Variable>& var) { return *var; });
+  return result;
+}
 
 template <typename T>
-typename std::enable_if<std::is_same<T, Variable>::value, T>::type to_output_type(variable_list& output_list) { return output_list[0]; }
+typename std::enable_if<std::is_same<T, Variable>::value, T>::type to_output_type(
+    std::vector<c10::optional<Variable>>& output_list) {
+  return *output_list[0];
+}
+
+inline std::vector<c10::optional<Variable>> to_optional(Variable& output) {
+  return std::vector<c10::optional<Variable>>{output};
+}
+
+inline std::vector<c10::optional<Variable>> to_optional(variable_list& output) {
+  std::vector<c10::optional<Variable>> result;
+  std::transform(output.begin(), output.end(), std::back_inserter(result),
+    [](const Variable& var) { return var; });
+  return result;
+}
 
 template<class T>
 template<typename X, typename... Args>
@@ -229,12 +251,19 @@ auto Function<T>::apply(Args&&... args) -> std::enable_if_t<std::is_same<X,T>::v
     outputs = T::forward(&node->ctx_, std::forward<Args>(args)...);
   }
 
-  auto wrapped_outputs = _wrap_outputs(input_vars, node->ctx_.get_non_differentiable(), node->ctx_.get_and_bump_dirty(), outputs, is_executable ? node : nullptr);
+  auto wrapped_outputs = _wrap_outputs(
+      input_vars,
+      node->ctx_.get_non_differentiable(),
+      node->ctx_.get_and_bump_dirty(),
+      to_optional(outputs),
+      is_executable ? node : nullptr);
 
   node->output_info_.reserve(wrapped_outputs.size());
   for (auto& output : wrapped_outputs) {
-    if (is_executable) {
-      node->output_info_.emplace_back(output);
+    if (is_executable && output.has_value()) {
+      node->output_info_.emplace_back(output.value());
+    } else if (is_executable) {
+      node->output_info_.emplace_back();
     }
   }
 
torch/csrc/autograd/python_function.cpp
@@ -21,6 +21,7 @@
 #include <torch/csrc/jit/frontend/tracer.h>
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/python/python_tracer.h>
+#include <torch/csrc/jit/python/pybind_utils.h>
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/DynamicTypes.h>
 #include <torch/csrc/Exceptions.h>
@@ -362,30 +363,39 @@ static void _wrap_outputs(const std::shared_ptr<PyNode>& cdata, THPFunction *sel
     self->output_info.reserve(num_outputs);
   }
 
-  auto as_variable = [&](PyObject* obj, int i) -> Variable {
-    if (THPVariable_Check(obj)) {
-      return ((THPVariable*)obj)->cdata;
-    }
-    throw TypeError("%s.forward: expected Tensor or tuple of Tensor (got %s) for return value %d",
-        Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
-  };
-
   auto non_differentiable = _parse_non_differentiable(self);
   auto dirty_inputs = _mark_dirty(self);
 
-  std::vector<Variable> raw_output_vars;
+  std::vector<c10::optional<Variable>> raw_output_vars;
   raw_output_vars.reserve(num_outputs);
   for(int i = 0; i < num_outputs; ++i){
     PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
-    raw_output_vars.push_back(as_variable(obj,i));
+    // Only process tensors as outputs for autograd purposes.
+    if (THPVariable_Check(obj)) {
+      raw_output_vars.emplace_back(((THPVariable*)obj)->cdata);
+    } else {
+      raw_output_vars.emplace_back();
+    }
   }
 
+  // Wrap only the tensor outputs.
   auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs, raw_output_vars, cdata_if_executable);
 
   for (int i = 0; i < num_outputs; i++) {
-    if (is_executable) {
-      self->output_info.emplace_back(wrapped_outputs[i]);
+    PyObject* obj = PyTuple_GetItem(raw_output, i);
+    // Keep the non-tensor outputs as is.
+    if (!THPVariable_Check(obj)) {
+      if (is_executable) {
+        self->output_info.emplace_back();
+      }
+      Py_INCREF(obj);
+      PyTuple_SetItem(outputs, i, obj);
+    } else {
+      if (is_executable) {
+        self->output_info.emplace_back(*wrapped_outputs[i]);
+      }
+      PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
     }
-    PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(wrapped_outputs[i]));
   }
 }
@@ -410,7 +420,7 @@ static void _save_variables(const std::shared_ptr<PyNode>& cdata_ptr, THPFunctio
       bool is_output = variable->cdata.grad_fn().get() == cdata_ptr.get();
       self->saved_variables.emplace_back(variable->cdata, is_output);
     } else {
-      throw TypeError(
+      throw torch::TypeError(
           "save_for_backward can only save variables, but argument %d is of "
           "type %s", i, Py_TYPE(obj)->tp_name);
     }
@@ -492,7 +502,7 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
 }
 
 static void _assert_not_tracing(const char* name, const variable_list& input_vars) {
-  if (tracer::isTracing()) {
+  if (jit::tracer::isTracing()) {
     std::ostringstream oss;
     oss << "Attempted to trace " << name;
     oss << ", but tracing of legacy functions is not supported";
@@ -557,11 +567,14 @@ static void _trace_post_record(
     node = unpacked;
   }
   for (int i = 0; i < num_outputs; ++i) {
-    auto var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i);
-    Value* value = node->outputs()[i];
-    if (var->cdata.defined()) {
-      value->inferTypeFrom(var->cdata);
-      jit::tracer::setValueTrace(var->cdata, value);
+    PyObject* obj = PyTuple_GET_ITEM(output_objects, i);
+    if (THPVariable_Check(obj)) {
+      auto var = (THPVariable*)obj;
+      Value* value = node->outputs()[i];
+      if (var->cdata.defined()) {
+        value->inferTypeFrom(var->cdata);
+        jit::tracer::setValueTrace(var->cdata, value);
+      }
     }
   }
 }
torch/utils/checkpoint.py
@@ -70,7 +70,22 @@ class CheckpointFunction(torch.autograd.Function):
             if torch.cuda._initialized:
                 ctx.had_cuda_in_fwd = True
                 ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
-        ctx.save_for_backward(*args)
+
+        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
+        # to be filled out during the backward.
+        ctx.inputs = []
+        ctx.tensor_indices = []
+        tensor_inputs = []
+        for i, arg in enumerate(args):
+            if torch.is_tensor(arg):
+                tensor_inputs.append(arg)
+                ctx.tensor_indices.append(i)
+                ctx.inputs.append(None)
+            else:
+                ctx.inputs.append(arg)
+
+        ctx.save_for_backward(*tensor_inputs)
+
         with torch.no_grad():
             outputs = run_function(*args)
         return outputs
@@ -82,7 +97,15 @@ class CheckpointFunction(torch.autograd.Function):
                 "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                 " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                 " argument.")
-        inputs = ctx.saved_tensors
+        # Copy the list to avoid modifying original list.
+        inputs = list(ctx.inputs)
+        tensor_indices = ctx.tensor_indices
+        tensors = ctx.saved_tensors
+
+        # Fill in inputs with appropriate saved tensors.
+        for i, idx in enumerate(tensor_indices):
+            inputs[idx] = tensors[i]
+
         # Stash the surrounding rng state, and mimic the state that was
         # present at this time during forward. Restore the surrounding state
        # when we're done.
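
The bookkeeping above can be exercised in isolation; a standalone sketch of the placeholder/index round-trip (hypothetical variable names, outside the actual class):

import torch

# Forward side: tensors go through save_for_backward, everything else is
# stashed directly, with None placeholders marking the tensor slots.
args = (torch.rand(2), 3, "tag", torch.rand(2))
inputs, tensor_indices, tensor_inputs = [], [], []
for i, arg in enumerate(args):
    if torch.is_tensor(arg):
        tensor_inputs.append(arg)
        tensor_indices.append(i)
        inputs.append(None)
    else:
        inputs.append(arg)

# Backward side: saved tensors are spliced back into their original slots.
restored = list(inputs)
for i, idx in enumerate(tensor_indices):
    restored[idx] = tensor_inputs[i]
assert all(torch.equal(a, b) if torch.is_tensor(a) else a == b
           for a, b in zip(args, restored))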
@@ -94,7 +117,7 @@ class CheckpointFunction(torch.autograd.Function):
                 torch.set_rng_state(ctx.fwd_cpu_state)
                 if ctx.had_cuda_in_fwd:
                     set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
-            detached_inputs = detach_variable(inputs)
+            detached_inputs = detach_variable(tuple(inputs))
             with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
                 outputs = ctx.run_function(*detached_inputs)
 
@@ -105,7 +128,7 @@ class CheckpointFunction(torch.autograd.Function):
         outputs_with_grad = []
         args_with_grad = []
         for i in range(len(outputs)):
-            if outputs[i].requires_grad:
+            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                 outputs_with_grad.append(outputs[i])
                 args_with_grad.append(args[i])
         if len(outputs_with_grad) == 0:
@@ -113,8 +136,9 @@ class CheckpointFunction(torch.autograd.Function):
                 "none of output has requires_grad=True,"
                 " this checkpoint() is not necessary")
         torch.autograd.backward(outputs_with_grad, args_with_grad)
-        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
+        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                       for inp in detached_inputs)
+
         return (None, None) + grads
 
 
@@ -135,6 +159,12 @@ def checkpoint(function, *args, **kwargs):
     :attr:`function` again, now tracking the intermediate activations, and then
     the gradients are calculated using these activation values.
 
+    The output of :attr:`function` can contain non-Tensor values and gradient
+    recording is only performed for the Tensor values. Note that if the output
+    consists of nested structures (ex: custom objects, lists, dicts etc.)
+    consisting of Tensors, these Tensors nested in custom structures will not
+    be considered as part of autograd.
+
     .. warning::
         Checkpointing currently only supports :func:`torch.autograd.backward`
         and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
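
The nested-structure caveat added to the docstring can be demonstrated with a short sketch (hypothetical function, assuming the post-PR behavior):

import torch
from torch.utils.checkpoint import checkpoint

def returns_nested(x):
    y = x * 2
    return y, {"hidden": y * 3}  # a Tensor hidden inside a dict

x = torch.rand(4, requires_grad=True)
direct, nested = checkpoint(returns_nested, x)
direct.sum().backward()  # tracked: direct is a top-level Tensor output
# The dict passes through as an opaque non-tensor output; the Tensor inside
# it was produced under no_grad during forward and is not part of autograd.
assert nested["hidden"].requires_grad is False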