mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Update custom Function to preserve torch function when inputs are returned as-is (#109825)
Fixes https://github.com/pytorch/pytorch/issues/109805
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109825
Approved by: https://github.com/albanD
This commit is contained in:
parent
21d77bcf80
commit
4e73eee93f
5 changed files with 70 additions and 11 deletions
|
|
@ -87,7 +87,7 @@ class TestDataParallel(TestCase):
|
|||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
|
||||
def test_data_parallel_lazy_linear(self):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Modules with uninitialized parameters'):
|
||||
with self.assertRaisesRegex(ValueError, 'Attempted to use an uninitialized parameter'):
|
||||
model_dp = torch.nn.DataParallel(torch.nn.LazyLinear(10).to(0))
|
||||
model_dp(torch.rand(10, 10).to(0))
|
||||
|
||||
|
|
|
|||
|
|
@ -7169,6 +7169,34 @@ for shape in [(1,), ()]:
|
|||
with self.assertRaisesRegex(RuntimeError, leaf_grad_err):
|
||||
output.zero_()
|
||||
|
||||
def test_custom_function_preserve_torch_function_when_return_as_is(self):
|
||||
class Custom(torch.Tensor):
|
||||
def __init__(self, data):
|
||||
super().__init__()
|
||||
self._data = data
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
kwargs = {} if kwargs is None else kwargs
|
||||
args = tuple(a._data if isinstance(a, cls) else a for a in args)
|
||||
out = func(*args, **kwargs)
|
||||
if isinstance(out, torch.Tensor):
|
||||
out = cls(out)
|
||||
return out
|
||||
|
||||
class Fn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx):
|
||||
pass
|
||||
|
||||
x = Custom(torch.randn(2, 3))
|
||||
y = Fn.apply(x)
|
||||
self.assertTrue(isinstance(y, Custom))
|
||||
|
||||
def test_grad_mode_restored_reentrant(self):
|
||||
class MyFunction(Function):
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -247,7 +247,9 @@ static void _process_forward_mode_AD(
|
|||
}
|
||||
}
|
||||
|
||||
static at::Tensor _view_as_self_with_no_grad(const at::Tensor& self) {
|
||||
static at::Tensor _view_as_self_with_no_grad(
|
||||
const at::Tensor& self,
|
||||
const _view_as_self_fn_t& view_as_self_fn) {
|
||||
// This is called below in _process_backward_mode_ad in two places:
|
||||
//
|
||||
// (1) An input has been returned, but it wasn't modified. Return it as a view
|
||||
|
|
@ -265,7 +267,10 @@ static at::Tensor _view_as_self_with_no_grad(const at::Tensor& self) {
|
|||
// ignored.
|
||||
at::AutoFwGradMode fw_grad_mode(false);
|
||||
AutoGradMode grad_mode(false);
|
||||
return self.view_as(self);
|
||||
// We thread through this view_as_self_fn lambda so that in the case we are a
|
||||
// Python custom function (rather than a cpp one), we can properly call the
|
||||
// view_as from python so that torch function logic can still trigger.
|
||||
return view_as_self_fn(self);
|
||||
}
|
||||
|
||||
static optional_variable_list _process_backward_mode_ad(
|
||||
|
|
@ -274,7 +279,8 @@ static optional_variable_list _process_backward_mode_ad(
|
|||
const std::unordered_set<at::TensorImpl*>& dirty_inputs,
|
||||
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
|
||||
const std::shared_ptr<Node>& cdata,
|
||||
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
|
||||
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
|
||||
const _view_as_self_fn_t& view_as_self_fn) {
|
||||
auto num_outputs = raw_outputs.size();
|
||||
|
||||
const char* error_msg_input_returned_as_is =
|
||||
|
|
@ -295,7 +301,7 @@ static optional_variable_list _process_backward_mode_ad(
|
|||
if (is_input && !is_modified) {
|
||||
TORCH_CHECK(
|
||||
!is_saved_and_setup_context, error_msg_input_returned_as_is)
|
||||
var = _view_as_self_with_no_grad(var);
|
||||
var = _view_as_self_with_no_grad(var, view_as_self_fn);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
@ -350,7 +356,7 @@ static optional_variable_list _process_backward_mode_ad(
|
|||
}
|
||||
} else if (is_input) {
|
||||
TORCH_CHECK(!is_saved_and_setup_context, error_msg_input_returned_as_is)
|
||||
var = _view_as_self_with_no_grad(var);
|
||||
var = _view_as_self_with_no_grad(var, view_as_self_fn);
|
||||
impl::set_gradient_edge(var, {cdata, output_nr});
|
||||
} else if (cdata) {
|
||||
impl::set_gradient_edge(var, {cdata, output_nr});
|
||||
|
|
@ -453,7 +459,8 @@ optional_variable_list _wrap_outputs(
|
|||
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
|
||||
const std::shared_ptr<Node>& cdata,
|
||||
const _jvp_fn_t& jvp_user_function,
|
||||
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context) {
|
||||
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
|
||||
const _view_as_self_fn_t& view_as_self_fn) {
|
||||
std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
|
||||
inputs_mapping.reserve(input_vars.size());
|
||||
for (const auto i : c10::irange(input_vars.size())) {
|
||||
|
|
@ -466,7 +473,8 @@ optional_variable_list _wrap_outputs(
|
|||
dirty_inputs,
|
||||
raw_outputs,
|
||||
cdata,
|
||||
to_save_if_setup_context);
|
||||
to_save_if_setup_context,
|
||||
view_as_self_fn);
|
||||
|
||||
// This must happen after the backward processing as we expect the
|
||||
// computations happening here to track backward mode gradients.
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ namespace autograd {
|
|||
|
||||
using optional_variable_list = std::vector<c10::optional<Variable>>;
|
||||
using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;
|
||||
// Callable that takes a tensor and returns a tensor. It is the type of the
// `view_as_self_fn` argument threaded through _wrap_outputs, which lets the
// caller decide how the "view as self" of a returned input is produced.
using _view_as_self_fn_t = std::function<at::Tensor(at::Tensor)>;
|
||||
|
||||
TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
|
||||
const variable_list& input_vars,
|
||||
|
|
@ -21,7 +22,8 @@ TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
|
|||
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
|
||||
const std::shared_ptr<Node>& cdata,
|
||||
const _jvp_fn_t& jvp_user_function,
|
||||
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context);
|
||||
const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
|
||||
const _view_as_self_fn_t& view_as_self_fn);
|
||||
|
||||
TORCH_API void check_variable_result(
|
||||
const at::TensorBase& original,
|
||||
|
|
@ -311,6 +313,10 @@ auto Function<T>::apply(Args&&... args)
|
|||
"Please open a feature request on GitHub if you need this.");
|
||||
};
|
||||
|
||||
// Produces a view of `tensor` shaped like itself by calling view_as directly
// on the tensor.
auto view_as_self_fn = [](const at::Tensor& tensor) -> at::Tensor {
  at::Tensor self_view = tensor.view_as(tensor);
  return self_view;
};
|
||||
|
||||
auto wrapped_outputs = _wrap_outputs(
|
||||
input_vars,
|
||||
node->ctx_.get_non_differentiable(),
|
||||
|
|
@ -318,7 +324,8 @@ auto Function<T>::apply(Args&&... args)
|
|||
to_optional(outputs),
|
||||
is_executable ? node : nullptr,
|
||||
jvp_fn,
|
||||
{});
|
||||
{},
|
||||
view_as_self_fn);
|
||||
|
||||
node->output_info_.reserve(wrapped_outputs.size());
|
||||
for (auto& output : wrapped_outputs) {
|
||||
|
|
|
|||
|
|
@ -507,6 +507,21 @@ static void _wrap_outputs(
|
|||
return results;
|
||||
};
|
||||
|
||||
// Computes `x.view_as(x)` by going through the Python object for `x`: the
// `view_as` attribute is looked up on the wrapped tensor and called with the
// wrapped tensor as its only argument. Routing the call through Python
// (instead of calling view_as in C++) is what allows torch function logic on
// the Python side to trigger.
//
// Throws python_error when any CPython step yields a null pointer.
auto view_as_self_fn = [](const at::Tensor& x) -> at::Tensor {
  // The CPython calls below are made with the GIL held.
  pybind11::gil_scoped_acquire gil;
  THPObjectPtr py_x(THPVariable_Wrap(x));
  // Check the wrapped object like every other CPython result in this lambda,
  // so a null pointer is never passed on to PyObject_GetAttrString.
  if (!py_x)
    throw python_error();
  THPObjectPtr py_view_as_method(PyObject_GetAttrString(py_x, "view_as"));
  if (!py_view_as_method)
    throw python_error();
  THPObjectPtr args(PyTuple_Pack(1, py_x.get()));
  if (!args)
    throw python_error();
  THPObjectPtr result(PyObject_CallObject(py_view_as_method, args));
  if (!result)
    throw python_error();
  return THPVariable_Unpack(result);
};
|
||||
|
||||
// Wrap only the tensor outputs.
|
||||
auto wrapped_outputs = _wrap_outputs(
|
||||
input_vars,
|
||||
|
|
@ -515,7 +530,8 @@ static void _wrap_outputs(
|
|||
raw_output_vars,
|
||||
cdata_if_executable,
|
||||
jvp_user_function,
|
||||
to_save_if_setup_context);
|
||||
to_save_if_setup_context,
|
||||
view_as_self_fn);
|
||||
|
||||
for (const auto i : c10::irange(num_outputs)) {
|
||||
PyObject* obj = PyTuple_GetItem(raw_output, i);
|
||||
|
|
|
|||
Loading…
Reference in a new issue