Support non-tensor inputs and outputs for checkpointed functions. (#52422)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52422

As mentioned in https://github.com/pytorch/pytorch/issues/52415,
`torch.utils.checkpoint` doesn't support checkpointing for functions which have
non-tensor inputs and outputs.

This PR resolves this issue by ensuring the autograd machinery ignores the
non-tensor inputs and outputs and processes the tensors accordingly.
ghstack-source-id: 124406867

Test Plan:
1) unit test
2) waitforbuildbot

Reviewed By: albanD

Differential Revision: D26507228

fbshipit-source-id: 0a5a1591570814176185362e83ad18dabd9c84b0
This commit is contained in:
Pritam Damania 2021-03-19 21:26:07 -07:00 committed by Facebook GitHub Bot
parent 9d9986fd10
commit 4fa47e5e7d
7 changed files with 293 additions and 49 deletions

View file

@ -3,6 +3,7 @@ import gc
import sys
import io
import math
import random
import tempfile
import time
import threading
@ -292,6 +293,95 @@ class TestAutograd(TestCase):
with self.assertRaisesRegex(Exception, "Simulate error on backward pass"):
t3.sum().backward()
def test_custom_function_non_tensor_inputs_outputs(self):
    """Custom autograd.Function mixing tensor and non-tensor values.

    Checks that backward receives ``None`` gradient slots for every
    non-tensor forward output, that ``None`` may be returned as the
    gradient of a non-tensor input, and that gradcheck passes.
    """
    test_case = self  # captured so the inner backward can assert

    class ScaledCombine(Function):
        @staticmethod
        def forward(ctx, t1, t2, scale, t3):
            out_a = t1 + t2 * t3
            out_b = t1 * t2 + t3
            out_a *= scale
            out_b *= scale
            # Non-tensor state lives directly on ctx; tensors go through
            # save_for_backward so autograd can track them.
            ctx.scale = scale
            ctx.save_for_backward(t1, t2, t3)
            return scale, out_a, None, True, out_b, "bar", t1

        @staticmethod
        @once_differentiable
        def backward(ctx, *grads):
            # One gradient slot per forward output; the non-tensor
            # outputs (scale, None, True, "bar") must arrive as None.
            test_case.assertEqual(7, len(grads))
            for pos in (0, 2, 3, 5):
                test_case.assertIsNone(grads[pos])
            scale = ctx.scale
            var1, var2, var3 = ctx.saved_tensors
            return (
                grads[1] * scale + grads[4] * var2 * scale + grads[6],
                grads[1] * var3 * scale + grads[4] * var1 * scale,
                None,  # gradient for the non-tensor `scale` input
                grads[1] * var2 * scale + grads[4] * scale,
            )

    t1 = torch.rand(10, requires_grad=True)
    t2 = torch.rand(10, requires_grad=True)
    t3 = torch.rand(10)
    scale = random.randint(0, 10)

    res = ScaledCombine.apply(t1, t2, scale, t3)
    expected = (scale, (t1 + t2 * t3) * scale, None, True,
                (t1 * t2 + t3) * scale, "bar", t1)
    for want, got in zip(expected, res):
        self.assertEqual(want, got)

    # Validate running backward.
    torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
    self.assertIsNotNone(t1.grad)
    self.assertIsNotNone(t2.grad)
    self.assertIsNone(t3.grad)

    # Test gradcheck against the tensor outputs only.
    def foo(t1, t2, t3):
        out = ScaledCombine.apply(t1, t2, scale, t3)
        return out[1], out[4], out[6]

    gradcheck(foo, (t1, t2, t3))
def test_custom_function_no_tensors(self):
    """Custom autograd.Function applied to plain Python scalars only.

    No input or output is a tensor, so autograd never engages; apply()
    must still run forward and pass the non-tensor values through.
    """

    class ScalarCombine(Function):
        @staticmethod
        def forward(ctx, t1, t2, scale, t3):
            out_a = (t1 + t2 * t3) * scale
            out_b = (t1 * t2 + t3) * scale
            return scale, out_a, None, True, out_b, "bar", t1

        @staticmethod
        @once_differentiable
        def backward(ctx, *args):
            return (args[0], args[1], None, args[2])

    t1, t2, t3 = random.random(), random.random(), random.random()
    scale = random.randint(0, 10)

    res = ScalarCombine.apply(t1, t2, scale, t3)
    expected = (scale, (t1 + t2 * t3) * scale, None, True,
                (t1 * t2 + t3) * scale, "bar", t1)
    for want, got in zip(expected, res):
        self.assertEqual(want, got)
def test_invalid_gradients(self):
class MyFunction(Function):
@staticmethod

View file

@ -269,6 +269,65 @@ class TestCheckpoint(TestCase):
out = checkpoint(run_fn, input_var, None)
out.sum().backward()
def test_checkpoint_non_tensor_inputs_outputs(self):
    """checkpoint() on a function with non-tensor inputs and outputs.

    Verifies forward passthrough of non-tensor values, backward through
    the tensor outputs, graph-freed error on a second backward, and that
    gradients match the un-checkpointed run.
    """

    def run_fn(t1, t2, scale, t3):
        t4 = t1 + t2 * t3
        t5 = t1 * t2 + t3
        t4 *= scale
        t5 *= scale
        return scale, t4, None, True, t5, "bar", t1

    t1 = torch.rand(10, requires_grad=True)
    t2 = torch.rand(10, requires_grad=True)
    t3 = torch.rand(10)
    scale = random.randint(0, 10)

    res = checkpoint(run_fn, t1, t2, scale, t3)
    expected = (scale, (t1 + t2 * t3) * scale, None, True,
                (t1 * t2 + t3) * scale, "bar", t1)
    for want, got in zip(expected, res):
        self.assertEqual(want, got)

    # Validate running backward; the graph survives while retained and
    # is freed after the last (non-retaining) call.
    res[1].sum().backward(retain_graph=True)
    res[4].sum().backward(retain_graph=True)
    res[6].sum().backward()
    with self.assertRaisesRegex(RuntimeError, "Trying to backward through the graph a second time"):
        res[6].sum().backward()
    t1_grad, t2_grad = t1.grad, t2.grad

    # Reset grads, run without checkpoint and validate we receive same grads.
    t1.grad = None
    t2.grad = None
    res = run_fn(t1, t2, scale, t3)
    torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
    self.assertEqual(t1.grad, t1_grad)
    self.assertEqual(t2.grad, t2_grad)
def test_checkpoint_no_tensors(self):
    """checkpoint() where no input or output is a tensor.

    With only Python scalars involved there is nothing to checkpoint,
    but the call must still run the function and return its values.
    """

    def run_fn(t1, t2, scale, t3):
        out_a = (t1 + t2 * t3) * scale
        out_b = (t1 * t2 + t3) * scale
        return scale, out_a, None, True, out_b, "bar", t1

    t1, t2, t3 = random.random(), random.random(), random.random()
    scale = random.randint(0, 10)

    res = checkpoint(run_fn, t1, t2, scale, t3)
    expected = (scale, (t1 + t2 * t3) * scale, None, True,
                (t1 * t2 + t3) * scale, "bar", t1)
    for want, got in zip(expected, res):
        self.assertEqual(want, got)
def test_checkpoint_partial_grad(self):
def run_fn(tensor1, tensor2):
# tensor 2 is used for other application logic

View file

@ -173,8 +173,8 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
It must accept a context ctx as the first argument, followed by any
number of arguments (tensors or other types).
The context can be used to store tensors that can be then retrieved
during the backward pass.
The context can be used to store arbitrary data that can be then
retrieved during the backward pass.
"""
raise NotImplementedError("You must implement the forward function for custom"
" autograd.Function.")
@ -186,10 +186,13 @@ class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixi
This function is to be overridden by all subclasses.
It must accept a context :attr:`ctx` as the first argument, followed by
as many outputs did :func:`forward` return, and it should return as many
tensors, as there were inputs to :func:`forward`. Each argument is the
gradient w.r.t the given output, and each returned value should be the
gradient w.r.t. the corresponding input.
as many outputs as the :func:`forward` returned (None will be passed in
for non tensor outputs of the forward function),
and it should return as many tensors, as there were inputs to
:func:`forward`. Each argument is the gradient w.r.t the given output,
and each returned value should be the gradient w.r.t. the
corresponding input. If an input is not a Tensor or is a Tensor not
requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward
pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple

View file

@ -9,18 +9,26 @@ VariableInfo::VariableInfo(const Variable& var)
, device(var.device())
, scalar_type(var.scalar_type())
, size(var.sizes().vec())
, requires_grad(var.requires_grad()) {
, requires_grad(var.requires_grad())
, is_empty(false) {
}
VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {}
Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
return at::zeros(size,
at::TensorOptions(scalar_type).device(device).layout(layout));
if (is_empty) {
// Return undefined tensor.
return at::Tensor();
} else {
return at::zeros(
size, at::TensorOptions(scalar_type).device(device).layout(layout));
}
}
variable_list _wrap_outputs(const variable_list &input_vars,
std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_vars,
const std::unordered_set<at::TensorImpl*> &non_differentiable,
const std::unordered_set<at::TensorImpl*> &dirty_inputs,
const at::ArrayRef<Variable> raw_outputs,
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
const std::shared_ptr<Node> &cdata) {
std::unordered_set<at::TensorImpl*> inputs;
@ -96,16 +104,26 @@ variable_list _wrap_outputs(const variable_list &input_vars,
}
};
std::vector<torch::autograd::Variable> outputs;
std::vector<c10::optional<Variable>> outputs;
std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
outputs.reserve(num_outputs);
int num_diff_outputs = 0;
for (auto i = 0; i < num_outputs; ++i) {
Variable var = raw_outputs[i];
// For outputs that are not tensors, put a placeholder undefined input.
if (!raw_outputs[i].has_value()) {
if (cdata) {
auto output_nr = cdata->add_input_metadata(Node::undefined_input());
AT_ASSERT(i == (int)output_nr);
}
outputs.emplace_back();
continue;
}
auto out_tensor_impl = raw_outputs[i].unsafeGetTensorImpl();
Variable var = raw_outputs[i].value();
auto out_tensor_impl = var.unsafeGetTensorImpl();
bool is_input = inputs.count(out_tensor_impl) > 0;
bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0
@ -139,9 +157,11 @@ variable_list _wrap_outputs(const variable_list &input_vars,
// See NOTE [ View + Inplace detection ] for more details
if (num_diff_outputs > 1) {
for (auto& var: outputs) {
auto diff_view_meta = impl::get_view_autograd_meta(var);
if (diff_view_meta) {
diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
if (var.has_value()) {
auto diff_view_meta = impl::get_view_autograd_meta(var.value());
if (diff_view_meta) {
diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
}
}
}
}

View file

@ -8,11 +8,11 @@
namespace torch { namespace autograd {
TORCH_API variable_list _wrap_outputs(
TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
const variable_list &input_vars,
const std::unordered_set<at::TensorImpl*> &non_differentiable,
const std::unordered_set<at::TensorImpl*> &dirty_inputs,
const at::ArrayRef<Variable> raw_outputs,
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
const std::shared_ptr<Node> &cdata);
TORCH_API void check_variable_result(const Variable& original,
@ -137,6 +137,7 @@ private:
};
struct TORCH_API VariableInfo {
explicit VariableInfo();
explicit VariableInfo(const Variable& var);
Variable zeros(at::OptionalDeviceGuard& device_guard) const;
@ -146,6 +147,7 @@ struct TORCH_API VariableInfo {
at::ScalarType scalar_type = at::kFloat;
std::vector<int64_t> size;
bool requires_grad;
bool is_empty;
};
// CppNode<T> is the Node in the autograd graph that represents the user defined
@ -194,10 +196,30 @@ inline void extract_vars(std::vector<bool> &is_var, variable_list& list, Args&&.
}
template <typename T>
typename std::enable_if<std::is_same<T, variable_list>::value, T&>::type to_output_type(variable_list& output_list) { return output_list; }
typename std::enable_if<std::is_same<T, variable_list>::value, T>::type to_output_type(
std::vector<c10::optional<Variable>>& output_list) {
variable_list result;
std::transform(output_list.begin(), output_list.end(), std::back_inserter(result),
[](const c10::optional<Variable>& var) { return *var; });
return result;
}
template <typename T>
typename std::enable_if<std::is_same<T, Variable>::value, T>::type to_output_type(variable_list& output_list) { return output_list[0]; }
typename std::enable_if<std::is_same<T, Variable>::value, T>::type to_output_type(
std::vector<c10::optional<Variable>>& output_list) {
return *output_list[0];
}
// Lift a single Variable forward() result into the optional-Variable list
// form consumed by _wrap_outputs (which uses nullopt slots for non-tensor
// outputs; a lone Variable is always present, so the slot is engaged).
inline std::vector<c10::optional<Variable>> to_optional(Variable& output) {
return std::vector<c10::optional<Variable>>{output};
}
// Overload for a variable_list result: copies each Variable into an engaged
// optional slot, one per output, preserving order for _wrap_outputs.
inline std::vector<c10::optional<Variable>> to_optional(variable_list& output) {
std::vector<c10::optional<Variable>> result;
std::transform(output.begin(), output.end(), std::back_inserter(result),
[](const Variable& var) { return var; });
return result;
}
template<class T>
template<typename X, typename... Args>
@ -229,12 +251,19 @@ auto Function<T>::apply(Args&&... args) -> std::enable_if_t<std::is_same<X,T>::v
outputs = T::forward(&node->ctx_, std::forward<Args>(args)...);
}
auto wrapped_outputs = _wrap_outputs(input_vars, node->ctx_.get_non_differentiable(), node->ctx_.get_and_bump_dirty(), outputs, is_executable ? node : nullptr);
auto wrapped_outputs = _wrap_outputs(
input_vars,
node->ctx_.get_non_differentiable(),
node->ctx_.get_and_bump_dirty(),
to_optional(outputs),
is_executable ? node : nullptr);
node->output_info_.reserve(wrapped_outputs.size());
for (auto& output : wrapped_outputs) {
if (is_executable) {
node->output_info_.emplace_back(output);
if (is_executable && output.has_value()) {
node->output_info_.emplace_back(output.value());
} else if (is_executable) {
node->output_info_.emplace_back();
}
}

View file

@ -21,6 +21,7 @@
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
@ -362,30 +363,39 @@ static void _wrap_outputs(const std::shared_ptr<PyNode>& cdata, THPFunction *sel
self->output_info.reserve(num_outputs);
}
auto as_variable = [&](PyObject* obj, int i) -> Variable {
if (THPVariable_Check(obj)) {
return ((THPVariable*)obj)->cdata;
}
throw TypeError("%s.forward: expected Tensor or tuple of Tensor (got %s) for return value %d",
Py_TYPE(self)->tp_name, Py_TYPE(obj)->tp_name, i);
};
auto non_differentiable = _parse_non_differentiable(self);
auto dirty_inputs = _mark_dirty(self);
std::vector<Variable> raw_output_vars;
std::vector<c10::optional<Variable>> raw_output_vars;
raw_output_vars.reserve(num_outputs);
for(int i = 0; i < num_outputs; ++i){
PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
raw_output_vars.push_back(as_variable(obj,i));
// Only process tensors as outputs for autograd purposes.
if (THPVariable_Check(obj)) {
raw_output_vars.emplace_back(((THPVariable*)obj)->cdata);
} else {
raw_output_vars.emplace_back();
}
}
// Wrap only the tensor outputs.
auto wrapped_outputs = _wrap_outputs(input_vars, non_differentiable, dirty_inputs, raw_output_vars, cdata_if_executable);
for (int i = 0; i < num_outputs; i++) {
if (is_executable) {
self->output_info.emplace_back(wrapped_outputs[i]);
PyObject* obj = PyTuple_GetItem(raw_output, i);
// Keep the non-tensor outputs as is.
if (!THPVariable_Check(obj)) {
if (is_executable) {
self->output_info.emplace_back();
}
Py_INCREF(obj);
PyTuple_SetItem(outputs, i, obj);
} else {
if (is_executable) {
self->output_info.emplace_back(*wrapped_outputs[i]);
}
PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
}
PyTuple_SET_ITEM(outputs, i, THPVariable_Wrap(wrapped_outputs[i]));
}
}
@ -410,7 +420,7 @@ static void _save_variables(const std::shared_ptr<PyNode>& cdata_ptr, THPFunctio
bool is_output = variable->cdata.grad_fn().get() == cdata_ptr.get();
self->saved_variables.emplace_back(variable->cdata, is_output);
} else {
throw TypeError(
throw torch::TypeError(
"save_for_backward can only save variables, but argument %d is of "
"type %s", i, Py_TYPE(obj)->tp_name);
}
@ -492,7 +502,7 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
}
static void _assert_not_tracing(const char* name, const variable_list& input_vars) {
if (tracer::isTracing()) {
if (jit::tracer::isTracing()) {
std::ostringstream oss;
oss << "Attempted to trace " << name;
oss << ", but tracing of legacy functions is not supported";
@ -557,11 +567,14 @@ static void _trace_post_record(
node = unpacked;
}
for (int i = 0; i < num_outputs; ++i) {
auto var = (THPVariable*)PyTuple_GET_ITEM(output_objects, i);
Value* value = node->outputs()[i];
if (var->cdata.defined()) {
value->inferTypeFrom(var->cdata);
jit::tracer::setValueTrace(var->cdata, value);
PyObject* obj = PyTuple_GET_ITEM(output_objects, i);
if (THPVariable_Check(obj)) {
auto var = (THPVariable*)obj;
Value* value = node->outputs()[i];
if (var->cdata.defined()) {
value->inferTypeFrom(var->cdata);
jit::tracer::setValueTrace(var->cdata, value);
}
}
}
}

View file

@ -70,7 +70,22 @@ class CheckpointFunction(torch.autograd.Function):
if torch.cuda._initialized:
ctx.had_cuda_in_fwd = True
ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
ctx.save_for_backward(*args)
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for i, arg in enumerate(args):
if torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
with torch.no_grad():
outputs = run_function(*args)
return outputs
@ -82,7 +97,15 @@ class CheckpointFunction(torch.autograd.Function):
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
" argument.")
inputs = ctx.saved_tensors
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors
# Fill in inputs with appropriate saved tensors.
for i, idx in enumerate(tensor_indices):
inputs[idx] = tensors[i]
# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
@ -94,7 +117,7 @@ class CheckpointFunction(torch.autograd.Function):
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_cuda_in_fwd:
set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
detached_inputs = detach_variable(inputs)
detached_inputs = detach_variable(tuple(inputs))
with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
outputs = ctx.run_function(*detached_inputs)
@ -105,7 +128,7 @@ class CheckpointFunction(torch.autograd.Function):
outputs_with_grad = []
args_with_grad = []
for i in range(len(outputs)):
if outputs[i].requires_grad:
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
@ -113,8 +136,9 @@ class CheckpointFunction(torch.autograd.Function):
"none of output has requires_grad=True,"
" this checkpoint() is not necessary")
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs)
return (None, None) + grads
@ -135,6 +159,12 @@ def checkpoint(function, *args, **kwargs):
:attr:`function` again, now tracking the intermediate activations, and then
the gradients are calculated using these activation values.
The output of :attr:`function` can contain non-Tensor values and gradient
recording is only performed for the Tensor values. Note that if the output
consists of nested structures (ex: custom objects, lists, dicts etc.)
consisting of Tensors, these Tensors nested in custom structures will not
be considered as part of autograd.
.. warning::
Checkpointing currently only supports :func:`torch.autograd.backward`
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`