mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Introduce slice_inverse() op (#117041)
Introduces a new op `slice_inverse()`. This is used in the reverse view_func for slice and several other ops (e.g. `split_with_sizes`, `chunk`). It's implemented behind the scenes by a call to `as_strided()`, but it's easier for subclasses to implement the more limited `slice_inverse()` than the full `as_strided()`. This PR:
* Introduces the op itself
* Updates all relevant functional inverses to call `slice_inverse()` instead of `as_strided()` directly
* Makes codegen changes to allow `slice_scatter()` to be the copy variant for `slice_inverse()`
* Need to avoid view_copy codegen (assumes if view name ends in inverse, we don't need to gen one, which is possibly a bad assumption)
@albanD / @soulitzer / @bdhirsh: I'm most interested in your thoughts on the codegen changes and whether this is the right way to go.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117041
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
f6767244cf
commit
5aac95c713
11 changed files with 103 additions and 36 deletions
|
|
@ -224,48 +224,48 @@ Tensor FunctionalInverses::slice_Tensor_inverse(const Tensor& base, const Tensor
|
|||
if (inverse_return_mode == InverseReturnMode::AlwaysView) {
|
||||
// NB: assumes mutated_view is a narrowed view of base.
|
||||
// We should NOT do this for functionalization
|
||||
return mutated_view.as_strided_symint(
|
||||
base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
|
||||
return mutated_view.slice_inverse_symint(
|
||||
base, dim, std::move(start), std::move(end), std::move(step));
|
||||
} else {
|
||||
return base.slice_scatter_symint(mutated_view, dim, std::move(start), std::move(end), std::move(step));
|
||||
}
|
||||
}
|
||||
|
||||
Tensor FunctionalInverses::split_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) {
|
||||
// It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
|
||||
// For functionalization, we only have one of the tensors from the TensorList outputted by split(), and we want to layer it
|
||||
// on top of the base tensor.
|
||||
// For autograd, we have all of the tensors outputted by split() and we just want to stack them.
|
||||
dim = at::maybe_wrap_dim(dim, base.dim());
|
||||
auto dim_size = base.sym_size(dim);
|
||||
auto start = split_size * mutated_view_idx;
|
||||
auto end = split_size + start;
|
||||
if (end > dim_size) end = dim_size;
|
||||
|
||||
if (inverse_return_mode == InverseReturnMode::AlwaysView) {
|
||||
// NB: assumes mutated_view is a narrowed view of base.
|
||||
// We should NOT do this for functionalization
|
||||
return mutated_view.as_strided_symint(
|
||||
base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
|
||||
return mutated_view.slice_inverse_symint(base, dim, start, end, 1);
|
||||
} else {
|
||||
// It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
|
||||
// For functionalization, we only have one of the tensors from the TensorList outputted by split(), and we want to layer it
|
||||
// on top of the base tensor.
|
||||
// For autograd, we have all of the tensors outputted by split() and we just want to stack them.
|
||||
dim = at::maybe_wrap_dim(dim, base.dim());
|
||||
auto dim_size = base.sym_size(dim);
|
||||
auto start = split_size * mutated_view_idx;
|
||||
auto end = split_size + start;
|
||||
if (end > dim_size) end = dim_size;
|
||||
return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor FunctionalInverses::split_with_sizes_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) {
|
||||
dim = at::maybe_wrap_dim(dim, base.dim());
|
||||
auto dim_size = base.sym_size(dim);
|
||||
c10::SymInt start = 0;
|
||||
for (auto i = 0; i < mutated_view_idx; ++i) {
|
||||
start += split_sizes[i];
|
||||
}
|
||||
auto end = start + split_sizes[mutated_view_idx];
|
||||
if (end > dim_size) end = dim_size;
|
||||
|
||||
if (inverse_return_mode == InverseReturnMode::AlwaysView) {
|
||||
// NB: assumes mutated_view is a narrowed view of base.
|
||||
// We should NOT do this for functionalization
|
||||
return mutated_view.as_strided_symint(
|
||||
base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
|
||||
return mutated_view.slice_inverse_symint(base, dim, start, end, 1);
|
||||
} else {
|
||||
dim = at::maybe_wrap_dim(dim, base.dim());
|
||||
auto dim_size = base.sym_size(dim);
|
||||
c10::SymInt start = 0;
|
||||
for (auto i = 0; i < mutated_view_idx; ++i) {
|
||||
start += split_sizes[i];
|
||||
}
|
||||
auto end = start + split_sizes[mutated_view_idx];
|
||||
if (end > dim_size) end = dim_size;
|
||||
return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
|
||||
}
|
||||
}
|
||||
|
|
@ -428,12 +428,22 @@ Tensor FunctionalInverses::narrow_inverse(const at::Tensor & base, const at::Ten
|
|||
if (inverse_return_mode == InverseReturnMode::AlwaysView) {
|
||||
// NB: assumes mutated_view is a narrowed view of base.
|
||||
// We should NOT do this for functionalization
|
||||
return mutated_view.as_strided_symint(
|
||||
base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
|
||||
return mutated_view.slice_inverse_symint(base, dim, std::move(start), start + length, 1);
|
||||
} else {
|
||||
return base.slice_scatter_symint(
|
||||
mutated_view, dim, std::move(start), start + length, 1);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor FunctionalInverses::slice_inverse_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor & src, int64_t dim, std::optional<c10::SymInt> start, std::optional<c10::SymInt> end, c10::SymInt step) {
|
||||
// slice_inverse() inverse is just slice()
|
||||
if (inverse_return_mode == InverseReturnMode::NeverView) {
|
||||
return at::slice_copy_symint(
|
||||
mutated_view, dim, std::move(start), std::move(end), std::move(step));
|
||||
} else {
|
||||
return mutated_view.slice_symint(
|
||||
dim, std::move(start), std::move(end), std::move(step));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::functionalization
|
||||
|
|
|
|||
|
|
@ -151,6 +151,7 @@
|
|||
#include <ATen/ops/slice.h>
|
||||
#include <ATen/ops/slice_backward_native.h>
|
||||
#include <ATen/ops/slice_copy_native.h>
|
||||
#include <ATen/ops/slice_inverse_native.h>
|
||||
#include <ATen/ops/slice_native.h>
|
||||
#include <ATen/ops/slice_scatter_native.h>
|
||||
#include <ATen/ops/sparse_coo_tensor.h>
|
||||
|
|
@ -2559,6 +2560,17 @@ Tensor slice(
|
|||
return result;
|
||||
}
|
||||
|
||||
Tensor slice_inverse_symint(
|
||||
const Tensor& self,
|
||||
const Tensor& base,
|
||||
int64_t /* dim */,
|
||||
c10::optional<SymInt> /* start */,
|
||||
c10::optional<SymInt> /* end */,
|
||||
SymInt /* step */) {
|
||||
// assume self has enough storage to be viewed with base's metadata
|
||||
return self.as_strided_symint(base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
|
||||
}
|
||||
|
||||
Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
|
||||
auto grad_input = at::zeros(input_sizes, grad.options());
|
||||
grad_input.slice(dim, start, end, step).copy_(grad);
|
||||
|
|
|
|||
|
|
@ -5379,6 +5379,21 @@
|
|||
CompositeExplicitAutograd: slice_backward
|
||||
autogen: slice_backward.out
|
||||
|
||||
# NB: This op exists to back the implementation of reverse view_funcs for various views (chunk,
|
||||
# slice.Tensor, split_with_sizes, et al.). Currently, these are only used during fake-ification
|
||||
# of PT2 graph input subclass instances that are views. This means:
|
||||
# * This op shouldn't really show up in eager mode (so e.g. XLA shouldn't have to implement it)
|
||||
# * This op shouldn't show up in a PT2 graph (so a PT2 backend shouldn't have to implement it)
|
||||
# * A subclass will have to implement this to work in PT2 if a subclass view is used as a graph
|
||||
# input AND the view utilizes this op in its inverse. The idea is that slice_inverse() is
|
||||
# easier to implement for a subclass than as_strided()
|
||||
- func: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
|
||||
variants: function, method
|
||||
device_check: NoCheck
|
||||
device_guard: False
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: slice_inverse_symint
|
||||
|
||||
- func: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
|
||||
variants: function, method
|
||||
device_check: NoCheck
|
||||
|
|
@ -5386,7 +5401,7 @@
|
|||
dispatch:
|
||||
CompositeExplicitAutogradNonFunctional: slice_scatter
|
||||
autogen: slice_scatter.out
|
||||
tags: core
|
||||
tags: [core, view_copy]
|
||||
|
||||
- func: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor
|
||||
variants: function, method
|
||||
|
|
|
|||
|
|
@ -1161,6 +1161,7 @@ aten::set_.source_Storage_storage_offset
|
|||
aten::set_.source_Tensor
|
||||
aten::slice_copy.Tensor
|
||||
aten::slice_copy.Tensor_out
|
||||
aten::slice_inverse
|
||||
aten::slice_scatter
|
||||
aten::slice_scatter.out
|
||||
aten::slow_conv3d_forward
|
||||
|
|
|
|||
|
|
@ -1503,6 +1503,11 @@
|
|||
grad_output: grad.slice_symint(dim, start, end, step)
|
||||
result: auto_linear
|
||||
|
||||
- name: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
|
||||
self: grad.slice_symint(dim, start, end, step)
|
||||
src: slice_scatter_symint(grad, zeros_like(self), dim, start, end, step)
|
||||
result: auto_linear
|
||||
|
||||
- name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
|
||||
self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step)
|
||||
src: grad.slice_symint(dim, start, end, step)
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ VIEW_FUNCTIONS = {
|
|||
"permute": "self",
|
||||
"select": "self",
|
||||
"slice": "self",
|
||||
"slice_inverse": "self",
|
||||
"split": "self",
|
||||
"split_with_sizes": "self",
|
||||
"squeeze": "self",
|
||||
|
|
|
|||
|
|
@ -84,7 +84,8 @@ def add_view_copy_derivatives(
|
|||
view_copy_differentiability_infos[dispatch_key] = view_copy_info
|
||||
else:
|
||||
break
|
||||
if len(view_copy_differentiability_infos) > 0:
|
||||
# prefer manually-defined derivatives if any
|
||||
if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
|
||||
assert fn_schema is not None
|
||||
view_infos[fn_schema] = view_copy_differentiability_infos
|
||||
|
||||
|
|
@ -105,11 +106,10 @@ def load_derivatives(
|
|||
# From the parsed native functions, separate out the (generated) view_copy functions,
|
||||
# so we can generate derivatives for them separately.
|
||||
native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
|
||||
native_functions_without_view_copies = concatMap(
|
||||
# We need to pull out the view_inplace ops too, since they might have their own derivative entries.
|
||||
native_functions = concatMap(
|
||||
lambda g: [g]
|
||||
if isinstance(g, NativeFunction)
|
||||
else list(g.functions(include_copy=False)),
|
||||
else list(g.functions(include_copy=True)),
|
||||
native_functions_with_view_groups,
|
||||
)
|
||||
view_groups = [
|
||||
|
|
@ -126,7 +126,7 @@ def load_derivatives(
|
|||
FunctionSchema, List[NativeFunction]
|
||||
] = defaultdict(list)
|
||||
functions_by_schema: Dict[str, NativeFunction] = {}
|
||||
for function in native_functions_without_view_copies:
|
||||
for function in native_functions:
|
||||
functions_by_signature[function.func.signature()].append(function)
|
||||
assert str(function.func) not in functions_by_schema
|
||||
functions_by_schema[str(function.func)] = function
|
||||
|
|
|
|||
|
|
@ -1028,6 +1028,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
|
||||
torch.select: lambda input, dim, index: -1,
|
||||
torch.select_scatter: lambda input, src, dim, index: -1,
|
||||
torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
|
||||
torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
|
||||
torch.selu: lambda input, inplace=False: -1,
|
||||
torch.sigmoid: lambda input, out=None: -1,
|
||||
|
|
|
|||
|
|
@ -15592,7 +15592,13 @@ op_db: List[OpInfo] = [
|
|||
gradcheck_fast_mode=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
supports_out=False),
|
||||
skips=(
|
||||
# RuntimeError: Internal error: pybind11::error_already_set called while
|
||||
# Python error indicator not set.
|
||||
# TODO: Investigate this more
|
||||
DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_out'),
|
||||
),
|
||||
supports_out=True),
|
||||
UnaryUfuncInfo('signbit',
|
||||
ref=np.signbit,
|
||||
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
|
||||
|
|
|
|||
|
|
@ -87,6 +87,12 @@ class GenCompositeViewCopyKernel:
|
|||
def __call__(self, g: NativeFunctionsViewGroup) -> Optional[str]:
|
||||
if g.view_copy is None:
|
||||
return None
|
||||
elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
|
||||
# If the view_copy doesn't match the standard naming scheme of <op>_copy,
|
||||
# assume it already exists and doesn't need to be generated.
|
||||
# Example: slice_inverse() with the copy variant named slice_scatter()
|
||||
# instead of slice_inverse_copy()
|
||||
return None
|
||||
|
||||
metadata = self.backend_index.get_kernel(g.view_copy)
|
||||
assert metadata is not None
|
||||
|
|
|
|||
|
|
@ -1613,8 +1613,11 @@ class FunctionSchema:
|
|||
)
|
||||
|
||||
base_name = self.name.name.base
|
||||
if strip_view_copy_name and base_name.endswith("_copy"):
|
||||
base_name = base_name.replace("_copy", "")
|
||||
if strip_view_copy_name:
|
||||
if base_name.endswith("_copy"):
|
||||
base_name = base_name.replace("_copy", "")
|
||||
elif base_name.endswith("_scatter"):
|
||||
base_name = base_name.replace("scatter", "inverse")
|
||||
|
||||
# find mutable inputs that are not originally returned, and convert them to returns
|
||||
returns_from_mutable_inputs = tuple(
|
||||
|
|
@ -2603,9 +2606,9 @@ class NativeFunctionsViewGroup:
|
|||
" See Note [view_copy NativeFunctions] for details."
|
||||
)
|
||||
else:
|
||||
assert self.view_copy.func.name.name.base.endswith("_copy")
|
||||
assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter"))
|
||||
assert self.view.func.signature() == self.view_copy.func.signature(
|
||||
strip_view_copy_name=True
|
||||
strip_view_copy_name=True,
|
||||
)
|
||||
assert "view_copy" in self.view_copy.tags, (
|
||||
f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects"
|
||||
|
|
@ -2659,6 +2662,13 @@ def gets_generated_view_copy(f: NativeFunction) -> bool:
|
|||
# We also don't need to generate copy variants for inplace views.
|
||||
if "inplace_view" in f.tags:
|
||||
return False
|
||||
# Assume ops ending in _inverse have manually-defined copy variants
|
||||
# (e.g. slice_inverse() has the copy variant slice_scatter()).
|
||||
# We -could- probably generate these as well, but the codegen will be
|
||||
# slightly different, and hand-writing these few kernels keeps codegen
|
||||
# complexity lower.
|
||||
if f.func.name.name.base.endswith("_inverse"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue