Introduce slice_inverse() op (#117041)

Introduces a new op, `slice_inverse()`, used in the reverse view_func for `slice` and several other ops (e.g. `split_with_sizes`, `chunk`). Behind the scenes it is implemented by a call to `as_strided()`, but it is easier for subclasses to implement the more limited `slice_inverse()` than the full `as_strided()` (see the usage sketch after the list below). This PR:
* Introduces the op itself
* Updates all relevant functional inverses to call `slice_inverse()` instead of `as_strided()` directly
* Makes codegen changes to allow `slice_scatter()` to be the copy variant for `slice_inverse()`
    * This requires suppressing view_copy codegen for such ops (the codegen assumes that if a view op's name ends in `_inverse`, no copy variant needs to be generated, which is possibly a fragile assumption)
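For intuition, here is a minimal usage sketch of the intended semantics, assuming the Python binding `torch.slice_inverse(input, src, dim=0, start=None, end=None, step=1)` added in this PR: given a view produced by `slice()`, `slice_inverse()` re-expands it to its base's sizes, strides, and storage offset.

```python
import torch

base = torch.arange(8.)
view = base[2:6]  # dispatches to aten::slice along dim 0

# slice_inverse(view, src, dim, start, end, step) views `view` with
# src's metadata, undoing the slice. Since `view` aliases base's
# storage, the result is a full view of base again.
back = torch.slice_inverse(view, base, 0, 2, 6, 1)

assert back.shape == base.shape
assert torch.equal(back, base)
assert back.data_ptr() == base.data_ptr()  # a view, not a copy
```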

@albanD / @soulitzer / @bdhirsh: I'm most interested in your thoughts on the codegen changes and whether this is the right way to go.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117041
Approved by: https://github.com/bdhirsh
Joel Schlosser 2024-01-16 14:42:57 -05:00 committed by PyTorch MergeBot
parent f6767244cf
commit 5aac95c713
11 changed files with 103 additions and 36 deletions


@@ -224,48 +224,48 @@ Tensor FunctionalInverses::slice_Tensor_inverse(const Tensor& base, const Tensor
if (inverse_return_mode == InverseReturnMode::AlwaysView) {
// NB: assumes mutated_view is a narrowed view of base.
// We should NOT do this for functionalization
-      return mutated_view.as_strided_symint(
-          base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+      return mutated_view.slice_inverse_symint(
+          base, dim, std::move(start), std::move(end), std::move(step));
} else {
return base.slice_scatter_symint(mutated_view, dim, std::move(start), std::move(end), std::move(step));
}
}
Tensor FunctionalInverses::split_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) {
+    // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
+    // For functionalization, we only have one of the tensors from the TensorList outputted by split(), and we want to layer it
+    // on top of the base tensor.
+    // For autograd, we have all of the tensors outputted by split() and we just want to stack them.
+    dim = at::maybe_wrap_dim(dim, base.dim());
+    auto dim_size = base.sym_size(dim);
+    auto start = split_size * mutated_view_idx;
+    auto end = split_size + start;
+    if (end > dim_size) end = dim_size;
if (inverse_return_mode == InverseReturnMode::AlwaysView) {
// NB: assumes mutated_view is a narrowed view of base.
// We should NOT do this for functionalization
-      return mutated_view.as_strided_symint(
-          base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+      return mutated_view.slice_inverse_symint(base, dim, start, end, 1);
} else {
-      // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
-      // For functionalization, we only have one of the tensors from the TensorList outputted by split(), and we want to layer it
-      // on top of the base tensor.
-      // For autograd, we have all of the tensors outputted by split() and we just want to stack them.
-      dim = at::maybe_wrap_dim(dim, base.dim());
-      auto dim_size = base.sym_size(dim);
-      auto start = split_size * mutated_view_idx;
-      auto end = split_size + start;
-      if (end > dim_size) end = dim_size;
return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
}
}
Tensor FunctionalInverses::split_with_sizes_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) {
+    dim = at::maybe_wrap_dim(dim, base.dim());
+    auto dim_size = base.sym_size(dim);
+    c10::SymInt start = 0;
+    for (auto i = 0; i < mutated_view_idx; ++i) {
+        start += split_sizes[i];
+    }
+    auto end = start + split_sizes[mutated_view_idx];
+    if (end > dim_size) end = dim_size;
if (inverse_return_mode == InverseReturnMode::AlwaysView) {
// NB: assumes mutated_view is a narrowed view of base.
// We should NOT do this for functionalization
-      return mutated_view.as_strided_symint(
-          base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+      return mutated_view.slice_inverse_symint(base, dim, start, end, 1);
} else {
-      dim = at::maybe_wrap_dim(dim, base.dim());
-      auto dim_size = base.sym_size(dim);
-      c10::SymInt start = 0;
-      for (auto i = 0; i < mutated_view_idx; ++i) {
-          start += split_sizes[i];
-      }
-      auto end = start + split_sizes[mutated_view_idx];
-      if (end > dim_size) end = dim_size;
return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
}
}
@@ -428,12 +428,22 @@ Tensor FunctionalInverses::narrow_inverse(const at::Tensor & base, const at::Ten
if (inverse_return_mode == InverseReturnMode::AlwaysView) {
// NB: assumes mutated_view is a narrowed view of base.
// We should NOT do this for functionalization
-      return mutated_view.as_strided_symint(
-          base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+      return mutated_view.slice_inverse_symint(base, dim, std::move(start), start + length, 1);
} else {
return base.slice_scatter_symint(
mutated_view, dim, std::move(start), start + length, 1);
}
}
+Tensor FunctionalInverses::slice_inverse_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor & src, int64_t dim, std::optional<c10::SymInt> start, std::optional<c10::SymInt> end, c10::SymInt step) {
+    // the inverse of slice_inverse() is just slice()
+    if (inverse_return_mode == InverseReturnMode::NeverView) {
+      return at::slice_copy_symint(
+          mutated_view, dim, std::move(start), std::move(end), std::move(step));
+    } else {
+      return mutated_view.slice_symint(
+          dim, std::move(start), std::move(end), std::move(step));
+    }
+}
} // namespace at::functionalization


@@ -151,6 +151,7 @@
#include <ATen/ops/slice.h>
#include <ATen/ops/slice_backward_native.h>
#include <ATen/ops/slice_copy_native.h>
+#include <ATen/ops/slice_inverse_native.h>
#include <ATen/ops/slice_native.h>
#include <ATen/ops/slice_scatter_native.h>
#include <ATen/ops/sparse_coo_tensor.h>
@@ -2559,6 +2560,17 @@ Tensor slice(
return result;
}
+Tensor slice_inverse_symint(
+    const Tensor& self,
+    const Tensor& base,
+    int64_t /* dim */,
+    c10::optional<SymInt> /* start */,
+    c10::optional<SymInt> /* end */,
+    SymInt /* step */) {
+  // assume self has enough storage to be viewed with base's metadata
+  return self.as_strided_symint(base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
+}
Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
auto grad_input = at::zeros(input_sizes, grad.options());
grad_input.slice(dim, start, end, step).copy_(grad);
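In Python terms, the composite kernel above amounts to the following reference sketch (the helper name is hypothetical; it assumes `self`'s storage is large enough to be viewed with `base`'s metadata, which holds whenever `self` is a genuine narrowed view of `base`):

```python
import torch

def slice_inverse_reference(self, base):
    # Mirror of slice_inverse_symint(): ignore dim/start/end/step and
    # simply re-view `self` using base's sizes/strides/storage offset.
    return self.as_strided(base.size(), base.stride(), base.storage_offset())

base = torch.arange(8.)
view = base[2:6]
assert torch.equal(slice_inverse_reference(view, base), base)
```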


@@ -5379,6 +5379,21 @@
CompositeExplicitAutograd: slice_backward
autogen: slice_backward.out
+# NB: This op exists to back the implementation of reverse view_funcs for various views (chunk,
+# slice.Tensor, split_with_sizes, et al.). Currently, these are only used during fake-ification
+# of PT2 graph input subclass instances that are views. This means:
+# * This op shouldn't really show up in eager mode (so e.g. XLA shouldn't have to implement it)
+# * This op shouldn't show up in a PT2 graph (so a PT2 backend shouldn't have to implement it)
+# * A subclass will have to implement this to work in PT2 if a subclass view is used as a graph
+#   input AND the view utilizes this op in its inverse. The idea is that slice_inverse() is
+#   easier to implement for a subclass than as_strided().
+- func: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: slice_inverse_symint
- func: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
variants: function, method
device_check: NoCheck
@@ -5386,7 +5401,7 @@
dispatch:
CompositeExplicitAutogradNonFunctional: slice_scatter
autogen: slice_scatter.out
-  tags: core
+  tags: [core, view_copy]
- func: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor
variants: function, method
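The `view_copy` tag added to `slice_scatter` above encodes the pairing: `slice_scatter` serves as the copy variant of `slice_inverse`. A small sketch of why this is sound (assuming the Python bindings registered in this PR): when `view` really is a narrowed view of `base`, the view path and the copy path produce the same values, differing only in aliasing.

```python
import torch

base = torch.arange(6.)
view = base[1:4]

# view path: re-expand the narrowed view over base's storage
v = torch.slice_inverse(view, base, 0, 1, 4, 1)
# copy path: scatter the view's contents into a fresh copy of base
c = torch.slice_scatter(base, view, 0, 1, 4, 1)

assert torch.equal(v, c)                 # same values
assert v.data_ptr() == base.data_ptr()   # v aliases base
assert c.data_ptr() != base.data_ptr()   # c owns new memory
```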


@@ -1161,6 +1161,7 @@ aten::set_.source_Storage_storage_offset
aten::set_.source_Tensor
aten::slice_copy.Tensor
aten::slice_copy.Tensor_out
+aten::slice_inverse
aten::slice_scatter
aten::slice_scatter.out
aten::slow_conv3d_forward


@@ -1503,6 +1503,11 @@
grad_output: grad.slice_symint(dim, start, end, step)
result: auto_linear
+- name: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+  self: grad.slice_symint(dim, start, end, step)
+  src: slice_scatter_symint(grad, zeros_like(self), dim, start, end, step)
+  result: auto_linear
- name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step)
src: grad.slice_symint(dim, start, end, step)
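One way to read the `slice_inverse` entries above (a hand-worked sketch, not the autograd engine's actual code path): under the assumption that `self` is a narrowed view of `src`, the output takes its sliced region from `self` and everything else from `src`, so the incoming gradient splits accordingly.

```python
import torch

dim, start, end, step = 0, 1, 4, 1
grad = torch.arange(6.)  # incoming gradient w.r.t. the output

# self: grad.slice(dim, start, end, step) -> the in-slice portion
grad_self = grad[start:end:step]

# src: slice_scatter(grad, zeros_like(self), ...) -> the remainder,
# i.e. grad with the sliced region zeroed out
grad_src = torch.slice_scatter(grad, torch.zeros_like(grad_self), dim, start, end, step)

assert torch.equal(grad_self, torch.tensor([1., 2., 3.]))
assert torch.equal(grad_src, torch.tensor([0., 0., 0., 0., 4., 5.]))
```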


@@ -71,6 +71,7 @@ VIEW_FUNCTIONS = {
"permute": "self",
"select": "self",
"slice": "self",
"slice_inverse": "self",
"split": "self",
"split_with_sizes": "self",
"squeeze": "self",


@@ -84,7 +84,8 @@ def add_view_copy_derivatives(
view_copy_differentiability_infos[dispatch_key] = view_copy_info
else:
break
-if len(view_copy_differentiability_infos) > 0:
+# prefer manually-defined derivatives if any
+if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
assert fn_schema is not None
view_infos[fn_schema] = view_copy_differentiability_infos
@@ -105,11 +106,10 @@ def load_derivatives(
# From the parsed native functions, separate out the (generated) view_copy functions,
# so we can generate derivatives for them separately.
native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
-native_functions_without_view_copies = concatMap(
+# We need to pull out the view_inplace ops too, since they might have their own derivative entries.
+native_functions = concatMap(
lambda g: [g]
if isinstance(g, NativeFunction)
-    else list(g.functions(include_copy=False)),
+    else list(g.functions(include_copy=True)),
native_functions_with_view_groups,
)
view_groups = [
@@ -126,7 +126,7 @@
FunctionSchema, List[NativeFunction]
] = defaultdict(list)
functions_by_schema: Dict[str, NativeFunction] = {}
-for function in native_functions_without_view_copies:
+for function in native_functions:
functions_by_signature[function.func.signature()].append(function)
assert str(function.func) not in functions_by_schema
functions_by_schema[str(function.func)] = function


@@ -1028,6 +1028,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
torch.select: lambda input, dim, index: -1,
torch.select_scatter: lambda input, src, dim, index: -1,
+torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
torch.selu: lambda input, inplace=False: -1,
torch.sigmoid: lambda input, out=None: -1,


@@ -15592,7 +15592,13 @@ op_db: List[OpInfo] = [
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
-    supports_out=False),
+    skips=(
+        # RuntimeError: Internal error: pybind11::error_already_set called while
+        # Python error indicator not set.
+        # TODO: Investigate this more
+        DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out'),
+    ),
+    supports_out=True),
UnaryUfuncInfo('signbit',
ref=np.signbit,
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),


@@ -87,6 +87,12 @@ class GenCompositeViewCopyKernel:
def __call__(self, g: NativeFunctionsViewGroup) -> Optional[str]:
if g.view_copy is None:
return None
+elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
+    # If the view_copy doesn't match the standard naming scheme of <op>_copy,
+    # assume it already exists and doesn't need to be generated.
+    # Example: slice_inverse() with the copy variant named slice_scatter()
+    # instead of slice_inverse_copy()
+    return None
metadata = self.backend_index.get_kernel(g.view_copy)
assert metadata is not None


@@ -1613,8 +1613,11 @@ class FunctionSchema:
)
base_name = self.name.name.base
-if strip_view_copy_name and base_name.endswith("_copy"):
-    base_name = base_name.replace("_copy", "")
+if strip_view_copy_name:
+    if base_name.endswith("_copy"):
+        base_name = base_name.replace("_copy", "")
+    elif base_name.endswith("_scatter"):
+        base_name = base_name.replace("scatter", "inverse")
# find mutable inputs that are not originally returned, and convert them to returns
returns_from_mutable_inputs = tuple(
@@ -2603,9 +2606,9 @@ class NativeFunctionsViewGroup:
" See Note [view_copy NativeFunctions] for details."
)
else:
assert self.view_copy.func.name.name.base.endswith("_copy")
assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter"))
assert self.view.func.signature() == self.view_copy.func.signature(
-    strip_view_copy_name=True
+    strip_view_copy_name=True,
)
assert "view_copy" in self.view_copy.tags, (
f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects"
@@ -2659,6 +2662,13 @@ def gets_generated_view_copy(f: NativeFunction) -> bool:
# We also don't need to generate copy variants for inplace views.
if "inplace_view" in f.tags:
return False
+# Assume ops ending in _inverse have manually-defined copy variants
+# (e.g. slice_inverse() has the copy variant slice_scatter()).
+# We -could- probably generate these as well, but the codegen will be
+# slightly different, and hand-writing these few kernels keeps codegen
+# complexity lower.
+if f.func.name.name.base.endswith("_inverse"):
+    return False
return True
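To summarize the naming convention the codegen changes above rely on, here is a toy sketch (not the real torchgen code) of the matching rule: a copy variant corresponds to a view op if stripping `_copy`, or rewriting `scatter` to `inverse`, recovers the view op's name.

```python
def copy_variant_matches_view(view_name: str, copy_name: str) -> bool:
    # Standard scheme: <op>_copy is the copy variant of <op>.
    if copy_name.endswith("_copy"):
        return copy_name.replace("_copy", "") == view_name
    # New scheme from this PR: <op>_scatter can back <op>_inverse,
    # e.g. slice_scatter is the copy variant of slice_inverse.
    if copy_name.endswith("_scatter"):
        return copy_name.replace("scatter", "inverse") == view_name
    return False

assert copy_variant_matches_view("slice", "slice_copy")
assert copy_variant_matches_view("slice_inverse", "slice_scatter")
assert not copy_variant_matches_view("slice", "select_copy")
```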