Update custom Function preserve torch function when inputs returned as-is (#109825)

Fixes https://github.com/pytorch/pytorch/issues/109805
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109825
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer 2023-10-04 12:43:03 -04:00 committed by PyTorch MergeBot
parent 21d77bcf80
commit 4e73eee93f
5 changed files with 70 additions and 11 deletions

View file

@@ -87,7 +87,7 @@ class TestDataParallel(TestCase):
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_lazy_linear(self):
with self.assertRaisesRegex(RuntimeError, 'Modules with uninitialized parameters'):
with self.assertRaisesRegex(ValueError, 'Attempted to use an uninitialized parameter'):
model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0))
model_dp(torch.rand(10, 10).to(0))

View file

@@ -7169,6 +7169,34 @@ for shape in [(1,), ()]:
with self.assertRaisesRegex(RuntimeError, leaf_grad_err):
output.zero_()
# Regression test for https://github.com/pytorch/pytorch/issues/109805:
# a custom autograd.Function whose forward returns its input as-is must
# still produce an output of the input's __torch_function__ subclass.
# (NOTE: indentation was flattened by the diff rendering of this commit.)
def test_custom_function_preserve_torch_function_when_return_as_is(self):
# Tensor subclass: unwraps itself before dispatch and re-wraps any
# Tensor result, so outputs of torch ops remain Custom instances.
class Custom(torch.Tensor):
def __init__(self, data):
super().__init__()
self._data = data
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
# Unwrap Custom arguments to their plain-tensor payloads.
args = tuple(a._data if isinstance(a, cls) else a for a in args)
out = func(*args, **kwargs)
if isinstance(out, torch.Tensor):
# Re-wrap so the subclass propagates through the op.
out = cls(out)
return out
# Identity Function: forward returns the input unchanged. Per the C++
# change in this commit, autograd re-views such outputs via a Python
# view_as call so __torch_function__ still fires.
class Fn(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return input
@staticmethod
def backward(ctx):
pass
x = Custom(torch.randn(2, 3))
y = Fn.apply(x)
# The wrapped output must keep the Custom subclass, not decay to Tensor.
self.assertTrue(isinstance(y, Custom))
def test_grad_mode_restored_reentrant(self):
class MyFunction(Function):
@staticmethod

View file

@@ -247,7 +247,9 @@ static void _process_forward_mode_AD(
}
}
static at::Tensor _view_as_self_with_no_grad(const at::Tensor& self) {
static at::Tensor _view_as_self_with_no_grad(
const at::Tensor& self,
const _view_as_self_fn_t& view_as_self_fn) {
// This is called below in _process_backward_mode_ad in two places:
//
// (1) An input has been returned, but it wasn't modified. Return it as a view
@@ -265,7 +267,10 @@ static at::Tensor _view_as_self_with_no_grad(const at::Tensor& self) {
// ignored.
at::AutoFwGradMode fw_grad_mode(false);
AutoGradMode grad_mode(false);
return self.view_as(self);
// We thread through this view_as_self_fn lambda so that in the case we are a
// Python custom function (rather than a cpp one), we can properly call the
// view_as from python so that torch function logic can still trigger.
return view_as_self_fn(self);
}
static optional_variable_list _process_backward_mode_ad(
@@ -274,7 +279,8 @@ static optional_variable_list _process_backward_mode_ad(
const std::unordered_set<at::TensorImpl*>& dirty_inputs,
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
const std::shared_ptr<Node>& cdata,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
const _view_as_self_fn_t& view_as_self_fn) {
auto num_outputs = raw_outputs.size();
const char* error_msg_input_returned_as_is =
@@ -295,7 +301,7 @@ static optional_variable_list _process_backward_mode_ad(
if (is_input && !is_modified) {
TORCH_CHECK(
!is_saved_and_setup_context, error_msg_input_returned_as_is)
var = _view_as_self_with_no_grad(var);
var = _view_as_self_with_no_grad(var, view_as_self_fn);
}
return;
}
@@ -350,7 +356,7 @@ static optional_variable_list _process_backward_mode_ad(
}
} else if (is_input) {
TORCH_CHECK(!is_saved_and_setup_context, error_msg_input_returned_as_is)
var = _view_as_self_with_no_grad(var);
var = _view_as_self_with_no_grad(var, view_as_self_fn);
impl::set_gradient_edge(var, {cdata, output_nr});
} else if (cdata) {
impl::set_gradient_edge(var, {cdata, output_nr});
@@ -453,7 +459,8 @@ optional_variable_list _wrap_outputs(
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
const std::shared_ptr<Node>& cdata,
const _jvp_fn_t& jvp_user_function,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
const _view_as_self_fn_t& view_as_self_fn) {
std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
inputs_mapping.reserve(input_vars.size());
for (const auto i : c10::irange(input_vars.size())) {
@@ -466,7 +473,8 @@ optional_variable_list _wrap_outputs(
dirty_inputs,
raw_outputs,
cdata,
to_save_if_setup_context);
to_save_if_setup_context,
view_as_self_fn);
// This must happen after the backward processing as we expect the
// computations happening here to track backward mode gradients.

View file

@@ -13,6 +13,7 @@ namespace autograd {
using optional_variable_list = std::vector<c10::optional<Variable>>;
using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;
using _view_as_self_fn_t = std::function<at::Tensor(at::Tensor)>;
TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
const variable_list& input_vars,
@@ -21,7 +22,8 @@ TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
const std::shared_ptr<Node>& cdata,
const _jvp_fn_t& jvp_user_function,
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context);
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
const _view_as_self_fn_t& view_as_self_fn);
TORCH_API void check_variable_result(
const at::TensorBase& original,
@@ -311,6 +313,10 @@ auto Function<T>::apply(Args&&... args)
"Please open a feature request on GitHub if you need this.");
};
auto view_as_self_fn = [](const at::Tensor& x) -> at::Tensor {
return x.view_as(x);
};
auto wrapped_outputs = _wrap_outputs(
input_vars,
node->ctx_.get_non_differentiable(),
@@ -318,7 +324,8 @@ auto Function<T>::apply(Args&&... args)
to_optional(outputs),
is_executable ? node : nullptr,
jvp_fn,
{});
{},
view_as_self_fn);
node->output_info_.reserve(wrapped_outputs.size());
for (auto& output : wrapped_outputs) {

View file

@@ -507,6 +507,21 @@ static void _wrap_outputs(
return results;
};
auto view_as_self_fn = [](const at::Tensor& x) -> at::Tensor {
pybind11::gil_scoped_acquire gil;
THPObjectPtr py_x(THPVariable_Wrap(x));
THPObjectPtr py_view_as_method(PyObject_GetAttrString(py_x, "view_as"));
if (!py_view_as_method)
throw python_error();
THPObjectPtr args(PyTuple_Pack(1, py_x.get()));
if (!args)
throw python_error();
THPObjectPtr result(PyObject_CallObject(py_view_as_method, args));
if (!result)
throw python_error();
return THPVariable_Unpack(result);
};
// Wrap only the tensor outputs.
auto wrapped_outputs = _wrap_outputs(
input_vars,
@@ -515,7 +530,8 @@ static void _wrap_outputs(
raw_output_vars,
cdata_if_executable,
jvp_user_function,
to_save_if_setup_context);
to_save_if_setup_context,
view_as_self_fn);
for (const auto i : c10::irange(num_outputs)) {
PyObject* obj = PyTuple_GetItem(raw_output, i);