From 4e73eee93f411596fcabb32cc8e7686890d1c7fb Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Wed, 4 Oct 2023 12:43:03 -0400
Subject: [PATCH] Update custom Function preserve torch function when inputs
 returned as-is (#109825)

Fixes https://github.com/pytorch/pytorch/issues/109805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109825
Approved by: https://github.com/albanD
---
 test/distributed/test_data_parallel.py  |  2 +-
 test/test_autograd.py                   | 28 +++++++++++++++++++++++++
 torch/csrc/autograd/custom_function.cpp | 22 ++++++++++++-------
 torch/csrc/autograd/custom_function.h   | 11 ++++++++--
 torch/csrc/autograd/python_function.cpp | 18 +++++++++++++++++-
 5 files changed, 70 insertions(+), 11 deletions(-)

diff --git a/test/distributed/test_data_parallel.py b/test/distributed/test_data_parallel.py
index 53490197919..3d88fc38515 100644
--- a/test/distributed/test_data_parallel.py
+++ b/test/distributed/test_data_parallel.py
@@ -87,7 +87,7 @@ class TestDataParallel(TestCase):
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_lazy_linear(self):
-        with self.assertRaisesRegex(RuntimeError, 'Modules with uninitialized parameters'):
+        with self.assertRaisesRegex(ValueError, 'Attempted to use an uninitialized parameter'):
             model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0))
             model_dp(torch.rand(10, 10).to(0))
 
diff --git a/test/test_autograd.py b/test/test_autograd.py
index ad98d7031c8..17337157dda 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -7169,6 +7169,34 @@ for shape in [(1,), ()]:
             with self.assertRaisesRegex(RuntimeError, leaf_grad_err):
                 output.zero_()
 
+    def test_custom_function_preserve_torch_function_when_return_as_is(self):
+        class Custom(torch.Tensor):
+            def __init__(self, data):
+                super().__init__()
+                self._data = data
+
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                kwargs = {} if kwargs is None else kwargs
+                args = tuple(a._data if isinstance(a, cls) else a for a in args)
+                out = func(*args, **kwargs)
+                if isinstance(out, torch.Tensor):
+                    out = cls(out)
+                return out
+
+        class Fn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, input):
+                return input
+
+            @staticmethod
+            def backward(ctx):
+                pass
+
+        x = Custom(torch.randn(2, 3))
+        y = Fn.apply(x)
+        self.assertTrue(isinstance(y, Custom))
+
     def test_grad_mode_restored_reentrant(self):
         class MyFunction(Function):
             @staticmethod
diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp
index 72303902475..8c61ea52a17 100644
--- a/torch/csrc/autograd/custom_function.cpp
+++ b/torch/csrc/autograd/custom_function.cpp
@@ -247,7 +247,9 @@ static void _process_forward_mode_AD(
   }
 }
 
-static at::Tensor _view_as_self_with_no_grad(const at::Tensor& self) {
+static at::Tensor _view_as_self_with_no_grad(
+    const at::Tensor& self,
+    const _view_as_self_fn_t& view_as_self_fn) {
   // This is called below in _process_backward_mode_ad in two places:
   //
   // (1) An input has been returned, but it wasn't modified. Return it as a view
@@ -265,7 +267,10 @@
   // ignored.
   at::AutoFwGradMode fw_grad_mode(false);
   AutoGradMode grad_mode(false);
-  return self.view_as(self);
+  // We thread through this view_as_self_fn lambda so that in the case we are a
+  // Python custom function (rather than a cpp one), we can properly call the
+  // view_as from python so that torch function logic can still trigger.
+  return view_as_self_fn(self);
 }
 
 static optional_variable_list _process_backward_mode_ad(
@@ -274,7 +279,8 @@
     const std::unordered_set<at::TensorImpl*>& dirty_inputs,
     const at::ArrayRef<c10::optional<Variable>> raw_outputs,
     const std::shared_ptr<Node>& cdata,
-    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
+    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
+    const _view_as_self_fn_t& view_as_self_fn) {
   auto num_outputs = raw_outputs.size();
 
   const char* error_msg_input_returned_as_is =
@@ -295,7 +301,7 @@
       if (is_input && !is_modified) {
         TORCH_CHECK(
             !is_saved_and_setup_context, error_msg_input_returned_as_is)
-        var = _view_as_self_with_no_grad(var);
+        var = _view_as_self_with_no_grad(var, view_as_self_fn);
       }
       return;
     }
@@ -350,7 +356,7 @@
       }
     } else if (is_input) {
       TORCH_CHECK(!is_saved_and_setup_context, error_msg_input_returned_as_is)
-      var = _view_as_self_with_no_grad(var);
+      var = _view_as_self_with_no_grad(var, view_as_self_fn);
       impl::set_gradient_edge(var, {cdata, output_nr});
     } else if (cdata) {
       impl::set_gradient_edge(var, {cdata, output_nr});
@@ -453,7 +459,8 @@ optional_variable_list _wrap_outputs(
     const at::ArrayRef<c10::optional<Variable>> raw_outputs,
     const std::shared_ptr<Node>& cdata,
     const _jvp_fn_t& jvp_user_function,
-    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
+    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
+    const _view_as_self_fn_t& view_as_self_fn) {
   std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
   inputs_mapping.reserve(input_vars.size());
   for (const auto i : c10::irange(input_vars.size())) {
@@ -466,7 +473,8 @@
       dirty_inputs,
       raw_outputs,
       cdata,
-      to_save_if_setup_context);
+      to_save_if_setup_context,
+      view_as_self_fn);
 
   // This must happen after the backward processing as we expect the
   // computations happening here to track backward mode gradients.
diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h
index 0777475f8e1..a60102bbec8 100644
--- a/torch/csrc/autograd/custom_function.h
+++ b/torch/csrc/autograd/custom_function.h
@@ -13,6 +13,7 @@ namespace autograd {
 
 using optional_variable_list = std::vector<c10::optional<Variable>>;
 using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;
+using _view_as_self_fn_t = std::function<at::Tensor(const at::Tensor&)>;
 
 TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
     const variable_list& input_vars,
@@ -21,7 +22,8 @@
     const at::ArrayRef<c10::optional<Variable>> raw_outputs,
     const std::shared_ptr<Node>& cdata,
     const _jvp_fn_t& jvp_user_function,
-    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context);
+    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
+    const _view_as_self_fn_t& view_as_self_fn);
 
 TORCH_API void check_variable_result(
     const at::TensorBase& original,
@@ -311,6 +313,10 @@ auto Function<T>::apply(Args&&... args)
         "Please open a feature request on GitHub if you need this.");
   };
 
+  auto view_as_self_fn = [](const at::Tensor& x) -> at::Tensor {
+    return x.view_as(x);
+  };
+
   auto wrapped_outputs = _wrap_outputs(
       input_vars,
       node->ctx_.get_non_differentiable(),
@@ -318,7 +324,8 @@
       to_optional(outputs),
       is_executable ? node : nullptr,
       jvp_fn,
-      {});
+      {},
+      view_as_self_fn);
 
   node->output_info_.reserve(wrapped_outputs.size());
   for (auto& output : wrapped_outputs) {
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index c40fb7f53ef..de7cb12f86b 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -507,6 +507,21 @@ static void _wrap_outputs(
     return results;
   };
 
+  auto view_as_self_fn = [](const at::Tensor& x) -> at::Tensor {
+    pybind11::gil_scoped_acquire gil;
+    THPObjectPtr py_x(THPVariable_Wrap(x));
+    THPObjectPtr py_view_as_method(PyObject_GetAttrString(py_x, "view_as"));
+    if (!py_view_as_method)
+      throw python_error();
+    THPObjectPtr args(PyTuple_Pack(1, py_x.get()));
+    if (!args)
+      throw python_error();
+    THPObjectPtr result(PyObject_CallObject(py_view_as_method, args));
+    if (!result)
+      throw python_error();
+    return THPVariable_Unpack(result);
+  };
+
   // Wrap only the tensor outputs.
   auto wrapped_outputs = _wrap_outputs(
       input_vars,
@@ -515,7 +530,8 @@
       raw_output_vars,
       cdata_if_executable,
       jvp_user_function,
-      to_save_if_setup_context);
+      to_save_if_setup_context,
+      view_as_self_fn);
 
   for (const auto i : c10::irange(num_outputs)) {
     PyObject* obj = PyTuple_GetItem(raw_output, i);
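
Illustrative repro (not part of the patch above): the short script below is a sketch of the user-visible behavior this change enables, assuming a PyTorch build that includes it; the names Wrapper and Identity are made up for the example, mirroring the new test_custom_function_preserve_torch_function_when_return_as_is test. Because _view_as_self_with_no_grad now routes the view_as call through Python for Python custom Functions, a tensor subclass's __torch_function__ fires even when forward returns its input unchanged, so the subclass type survives Function.apply.

    # Hedged sketch: Wrapper and Identity are illustrative names, not PyTorch APIs.
    import torch

    class Wrapper(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = {} if kwargs is None else kwargs
            # Defer to the default implementation, which re-wraps results in cls.
            return super().__torch_function__(func, types, args, kwargs)

    class Identity(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x  # input returned as-is; this is the view_as path the patch changes

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    x = torch.randn(2, 3).as_subclass(Wrapper)
    y = Identity.apply(x)
    print(type(y).__name__)  # without this patch: Tensor; with it: Wrapper

The C++ custom Function path keeps the old behavior via the plain x.view_as(x) lambda in custom_function.h; only the Python path pays the cost of the extra Python call.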