From 5aac95c7131651d9ea3da08c324e8ade79de18c0 Mon Sep 17 00:00:00 2001
From: Joel Schlosser
Date: Tue, 16 Jan 2024 14:42:57 -0500
Subject: [PATCH] Introduce slice_inverse() op (#117041)

Introduces a new op `slice_inverse()`. This is used in the reverse view_func for slice and several other ops (e.g. `split_with_sizes`, `chunk`). It's implemented behind the scenes by a call to `as_strided()`, but it's easier for subclasses to implement the more limited `slice_inverse()` than the full `as_strided()`.

This PR:
* Introduces the op itself
* Updates all relevant functional inverses to call `slice_inverse()` instead of `as_strided()` directly
* Makes codegen changes to allow `slice_scatter()` to be the copy variant for `slice_inverse()`
  * Need to avoid view_copy codegen for it (this assumes that if a view op's name ends in `_inverse`, we don't need to generate a copy variant, which is possibly a bad assumption)

@albanD / @soulitzer / @bdhirsh: I'm most interested in your thoughts on the codegen changes and whether this is the right way to go.
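To make the semantics concrete, here's a small eager-mode sketch (illustration only; per the note added to native_functions.yaml, the op is really intended for reverse view_funcs during fake-ification, and the `torch.slice_inverse` spelling below is just the function variant registered in torch/overrides.py):

```python
import torch

base = torch.arange(12.).reshape(3, 4)
view = base[:, 1:3]  # a slice() view of base

# slice_inverse(view, base, dim, start, end, step) re-views `view`'s storage with
# `base`'s size/stride/storage offset (via as_strided() under the hood), so it
# undoes the slice as a view op rather than as a copy.
recovered = torch.slice_inverse(view, base, 1, 1, 3)
torch.testing.assert_close(recovered, base)
assert recovered.data_ptr() == base.data_ptr()  # still a view, no copy

# slice_scatter() is the existing copy variant: same values, fresh storage.
scattered = torch.slice_scatter(base, view, 1, 1, 3)
torch.testing.assert_close(scattered, base)
assert scattered.data_ptr() != base.data_ptr()
```

In the functional inverses below, the `AlwaysView` branches now produce the view form (`slice_inverse_symint`), while the other modes keep producing the copy form (`slice_scatter_symint`).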
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117041
Approved by: https://github.com/bdhirsh
---
 aten/src/ATen/FunctionalInverses.cpp           | 60 +++++++++++--------
 aten/src/ATen/native/TensorShape.cpp           | 12 ++++
 aten/src/ATen/native/native_functions.yaml     | 17 +++++-
 ...asDecompTest.test_has_decomposition.expect  |  1 +
 tools/autograd/derivatives.yaml                |  5 ++
 tools/autograd/gen_inplace_or_view_type.py     |  1 +
 tools/autograd/load_derivatives.py             | 10 ++--
 torch/overrides.py                             |  1 +
 .../_internal/common_methods_invocations.py    |  8 ++-
 torchgen/gen_functionalization_type.py         |  6 ++
 torchgen/model.py                              | 18 ++++--
 11 files changed, 103 insertions(+), 36 deletions(-)

diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp
index 953636df5ab..36c271d98c6 100644
--- a/aten/src/ATen/FunctionalInverses.cpp
+++ b/aten/src/ATen/FunctionalInverses.cpp
@@ -224,48 +224,48 @@ Tensor FunctionalInverses::slice_Tensor_inverse(const Tensor& base, const Tensor
   if (inverse_return_mode == InverseReturnMode::AlwaysView) {
     // NB: assumes mutated_view is a narrowed view of base.
     // We should NOT do this for functionalization
-    return mutated_view.as_strided_symint(
-        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+    return mutated_view.slice_inverse_symint(
+        base, dim, std::move(start), std::move(end), std::move(step));
   } else {
     return base.slice_scatter_symint(mutated_view, dim, std::move(start), std::move(end), std::move(step));
   }
 }
 
 Tensor FunctionalInverses::split_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) {
+  // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
+  // For functionalization, we only have one of the tensors from the TensorList outputted by split(), and we want to layer it
+  // on top of the base tensor.
+  // For autograd, we have all of the tensors outputted by split() and we just want to stack them.
+  dim = at::maybe_wrap_dim(dim, base.dim());
+  auto dim_size = base.sym_size(dim);
+  auto start = split_size * mutated_view_idx;
+  auto end = split_size + start;
+  if (end > dim_size) end = dim_size;
+
   if (inverse_return_mode == InverseReturnMode::AlwaysView) {
     // NB: assumes mutated_view is a narrowed view of base.
     // We should NOT do this for functionalization
-    return mutated_view.as_strided_symint(
-        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+    return mutated_view.slice_inverse_symint(base, dim, start, end, 1);
   } else {
-    // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
-    // For functionalization, we have only have one of the tensors from the TensorList outputed by split(), and we want to layer i
-    // on top of the base tensor.
-    // For autograd, we have all of the tensors outputted by split() and we just want to stack them.
-    dim = at::maybe_wrap_dim(dim, base.dim());
-    auto dim_size = base.sym_size(dim);
-    auto start = split_size * mutated_view_idx;
-    auto end = split_size + start;
-    if (end > dim_size) end = dim_size;
     return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
   }
 }
 
 Tensor FunctionalInverses::split_with_sizes_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) {
+  dim = at::maybe_wrap_dim(dim, base.dim());
+  auto dim_size = base.sym_size(dim);
+  c10::SymInt start = 0;
+  for (auto i = 0; i < mutated_view_idx; ++i) {
+    start += split_sizes[i];
+  }
+  auto end = start + split_sizes[mutated_view_idx];
+  if (end > dim_size) end = dim_size;
+
   if (inverse_return_mode == InverseReturnMode::AlwaysView) {
     // NB: assumes mutated_view is a narrowed view of base.
     // We should NOT do this for functionalization
-    return mutated_view.as_strided_symint(
-        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+    return mutated_view.slice_inverse_symint(base, dim, start, end, 1);
   } else {
-    dim = at::maybe_wrap_dim(dim, base.dim());
-    auto dim_size = base.sym_size(dim);
-    c10::SymInt start = 0;
-    for (auto i = 0; i < mutated_view_idx; ++i) {
-      start += split_sizes[i];
-    }
-    auto end = start + split_sizes[mutated_view_idx];
-    if (end > dim_size) end = dim_size;
     return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
   }
 }
 
@@ -428,12 +428,22 @@ Tensor FunctionalInverses::narrow_inverse(const at::Tensor & base, const at::Ten
   if (inverse_return_mode == InverseReturnMode::AlwaysView) {
     // NB: assumes mutated_view is a narrowed view of base.
     // We should NOT do this for functionalization
-    return mutated_view.as_strided_symint(
-        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+    return mutated_view.slice_inverse_symint(base, dim, std::move(start), start + length, 1);
   } else {
     return base.slice_scatter_symint(
         mutated_view, dim, std::move(start), start + length, 1);
   }
 }
 
+Tensor FunctionalInverses::slice_inverse_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor & src, int64_t dim, std::optional<c10::SymInt> start, std::optional<c10::SymInt> end, c10::SymInt step) {
+  // slice_inverse() inverse is just slice()
+  if (inverse_return_mode == InverseReturnMode::NeverView) {
+    return at::slice_copy_symint(
+        mutated_view, dim, std::move(start), std::move(end), std::move(step));
+  } else {
+    return mutated_view.slice_symint(
+        dim, std::move(start), std::move(end), std::move(step));
+  }
+}
+
 } // namespace at::functionalization
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 05bd5f4cafa..a7d6f4df092 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -151,6 +151,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -2559,6 +2560,17 @@ Tensor slice(
   return result;
 }
 
+Tensor slice_inverse_symint(
+    const Tensor& self,
+    const Tensor& base,
+    int64_t /* dim */,
+    c10::optional<SymInt> /* start */,
+    c10::optional<SymInt> /* end */,
+    SymInt /* step */) {
+  // assume self has enough storage to be viewed with base's metadata
+  return self.as_strided_symint(base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+}
+
 Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
   auto grad_input = at::zeros(input_sizes, grad.options());
   grad_input.slice(dim, start, end, step).copy_(grad);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index cb7bff9c54a..3d8f59f61cc 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5379,6 +5379,21 @@
     CompositeExplicitAutograd: slice_backward
     autogen: slice_backward.out
 
+# NB: This op exists to back the implementation of reverse view_funcs for various views (chunk,
+# slice.Tensor, split_with_sizes, et al.). Currently, these are only used during fake-ification
+# of PT2 graph input subclass instances that are views. This means:
+#   * This op shouldn't really show up in eager mode (so e.g. XLA shouldn't have to implement it)
+#   * This op shouldn't show up in a PT2 graph (so a PT2 backend shouldn't have to implement it)
+#   * A subclass will have to implement this to work in PT2 if a subclass view is used as a graph
+#     input AND the view utilizes this op in its inverse. The idea is that slice_inverse() is
+#     easier to implement for a subclass than as_strided()
+- func: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: slice_inverse_symint
+
 - func: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
   variants: function, method
   device_check: NoCheck
@@ -5386,7 +5401,7 @@
   dispatch:
     CompositeExplicitAutogradNonFunctional: slice_scatter
     autogen: slice_scatter.out
-  tags: core
+  tags: [core, view_copy]
 
 - func: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor
   variants: function, method
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 083a39d06ce..e2c3c025170 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1161,6 +1161,7 @@ aten::set_.source_Storage_storage_offset
 aten::set_.source_Tensor
 aten::slice_copy.Tensor
 aten::slice_copy.Tensor_out
+aten::slice_inverse
 aten::slice_scatter
 aten::slice_scatter.out
 aten::slow_conv3d_forward
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 16a384e5702..e692ae9e923 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1503,6 +1503,11 @@
   grad_output: grad.slice_symint(dim, start, end, step)
   result: auto_linear
 
+- name: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+  self: grad.slice_symint(dim, start, end, step)
+  src: slice_scatter_symint(grad, zeros_like(self), dim, start, end, step)
+  result: auto_linear
+
 - name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
   self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step)
   src: grad.slice_symint(dim, start, end, step)
diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py
index 154c0be8648..6e713579445 100644
--- a/tools/autograd/gen_inplace_or_view_type.py
+++ b/tools/autograd/gen_inplace_or_view_type.py
@@ -71,6 +71,7 @@ VIEW_FUNCTIONS = {
     "permute": "self",
     "select": "self",
     "slice": "self",
+    "slice_inverse": "self",
     "split": "self",
     "split_with_sizes": "self",
     "squeeze": "self",
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index 6b336cd6888..f9500fa2182 100644
--- a/tools/autograd/load_derivatives.py
+++ b/tools/autograd/load_derivatives.py
@@ -84,7 +84,8 @@ def add_view_copy_derivatives(
             view_copy_differentiability_infos[dispatch_key] = view_copy_info
         else:
             break
-    if len(view_copy_differentiability_infos) > 0:
+    # prefer manually-defined derivatives if any
+    if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
         assert fn_schema is not None
         view_infos[fn_schema] = view_copy_differentiability_infos
 
@@ -105,11 +106,10 @@ def load_derivatives(
     # From the parsed native functions, separate out the (generated) view_copy functions,
     # so we can generate derivatives for them separately.
     native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
-    native_functions_without_view_copies = concatMap(
-        # We need to pull out the view_inplace ops too, since they might have their own derivative entries.
+    native_functions = concatMap(
         lambda g: [g]
         if isinstance(g, NativeFunction)
-        else list(g.functions(include_copy=False)),
+        else list(g.functions(include_copy=True)),
         native_functions_with_view_groups,
     )
     view_groups = [
@@ -126,7 +126,7 @@
         FunctionSchema, List[NativeFunction]
     ] = defaultdict(list)
     functions_by_schema: Dict[str, NativeFunction] = {}
-    for function in native_functions_without_view_copies:
+    for function in native_functions:
         functions_by_signature[function.func.signature()].append(function)
         assert str(function.func) not in functions_by_schema
         functions_by_schema[str(function.func)] = function
diff --git a/torch/overrides.py b/torch/overrides.py
index 13f5681dd48..7c161510e45 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1028,6 +1028,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
         torch.select: lambda input, dim, index: -1,
         torch.select_scatter: lambda input, src, dim, index: -1,
+        torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
         torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
         torch.selu: lambda input, inplace=False: -1,
         torch.sigmoid: lambda input, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b78122dbc96..1a03115006e 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15592,7 +15592,13 @@ op_db: List[OpInfo] = [
            gradcheck_fast_mode=True,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
-           supports_out=False),
+           skips=(
+               # RuntimeError: Internal error: pybind11::error_already_set called while
+               # Python error indicator not set.
+               # TODO: Investigate this more
+               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out'),
+           ),
+           supports_out=True),
     UnaryUfuncInfo('signbit',
                    ref=np.signbit,
                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 9f4d48b1296..67771c0ad93 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -87,6 +87,12 @@ class GenCompositeViewCopyKernel:
     def __call__(self, g: NativeFunctionsViewGroup) -> Optional[str]:
         if g.view_copy is None:
             return None
+        elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
+            # If the view_copy doesn't match the standard naming scheme of <op>_copy,
+            # assume it already exists and doesn't need to be generated.
+            # Example: slice_inverse() with the copy variant named slice_scatter()
+            # instead of slice_inverse_copy()
+            return None
 
         metadata = self.backend_index.get_kernel(g.view_copy)
         assert metadata is not None
diff --git a/torchgen/model.py b/torchgen/model.py
index 02609a96b2c..f73971e51dd 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -1613,8 +1613,11 @@ class FunctionSchema:
         )
 
         base_name = self.name.name.base
-        if strip_view_copy_name and base_name.endswith("_copy"):
-            base_name = base_name.replace("_copy", "")
+        if strip_view_copy_name:
+            if base_name.endswith("_copy"):
+                base_name = base_name.replace("_copy", "")
+            elif base_name.endswith("_scatter"):
+                base_name = base_name.replace("scatter", "inverse")
 
         # find mutable inputs that are not originally returned, and convert them to returns
         returns_from_mutable_inputs = tuple(
@@ -2603,9 +2606,9 @@ class NativeFunctionsViewGroup:
                 " See Note [view_copy NativeFunctions] for details."
             )
         else:
-            assert self.view_copy.func.name.name.base.endswith("_copy")
+            assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter"))
             assert self.view.func.signature() == self.view_copy.func.signature(
-                strip_view_copy_name=True
+                strip_view_copy_name=True,
             )
             assert "view_copy" in self.view_copy.tags, (
                 f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects"
@@ -2659,6 +2662,13 @@ def gets_generated_view_copy(f: NativeFunction) -> bool:
     # We also don't need to generate copy variants for inplace views.
     if "inplace_view" in f.tags:
         return False
+    # Assume ops ending in _inverse have manually-defined copy variants
+    # (e.g. slice_inverse() has the copy variant slice_scatter()).
+    # We -could- probably generate these as well, but the codegen will be
+    # slightly different, and hand-writing these few kernels keeps codegen
+    # complexity lower.
+    if f.func.name.name.base.endswith("_inverse"):
+        return False
     return True
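As a quick illustration of the name normalization above, here is a rough standalone sketch (with a hypothetical helper name, not the real torchgen code) of what lets `slice_scatter()`'s schema compare equal to `slice_inverse()`'s when the view group is validated:

```python
# Sketch: how the copy-variant's base name is normalized so it can be compared
# against the corresponding view op's name in signature(strip_view_copy_name=True).
def normalize_copy_name(base_name: str) -> str:
    if base_name.endswith("_copy"):
        # standard generated copy variant, e.g. slice_copy -> slice
        return base_name.replace("_copy", "")
    elif base_name.endswith("_scatter"):
        # hand-written copy variant, e.g. slice_scatter -> slice_inverse
        return base_name.replace("scatter", "inverse")
    return base_name

assert normalize_copy_name("slice_copy") == "slice"
assert normalize_copy_name("slice_scatter") == "slice_inverse"
```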