add inplace_view tag to resize_() (#82667)

`resize_()` is annoying because it needs special casing for functionalization. It's technically an inplace-view op, but it can't really have a pure view variant, since calling resize_() might bust the old storage. I gave it an `inplace_view` tag so that stuff like `FakeTensor` that relies on tags will pick it up properly, which required  jumping through some codegen hoops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82667
Approved by: https://github.com/eellison
This commit is contained in:
Brian Hirsh 2022-08-02 14:35:12 -07:00 committed by PyTorch MergeBot
parent aa40503954
commit 684ce1b0bc
5 changed files with 14 additions and 5 deletions

View file

@ -2076,6 +2076,7 @@
variants: method
device_check: NoCheck
device_guard: False
tags: inplace_view
dispatch:
CPU, Meta: resize_
CUDA: resize_cuda_

View file

@ -244,7 +244,12 @@ def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
f"Consider adding 'structured=True' to the delegated operator"
)
if "inplace_view" in f.tags:
# See Note [resize_ in Functionalization]
# resize_() is technically an inplace view op (and therefore needs the tag),
# but it would be overkill to add a true "view" variant of resize.
# Instead, resize_() gets special treatment in functionalization,
# and we have a resize() op that is non-aliasing + functional.
if "inplace_view" in f.tags and str(f.func.name) != "resize_":
base_name = f.func.name.name
overload_name = f.func.name.overload_name
assert base_name.inplace, (

View file

@ -53,6 +53,8 @@ MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
# It will be BC-breaking, but we should fix their schemas.
# should be inplace?
"record_stream",
# See Note [resize_ in Functionalization]
"resize_",
]
)

View file

@ -890,7 +890,10 @@ class NativeFunction:
is_non_mutating_view = len(rets) > 0 and any(
r.annotation is not None and not r.annotation.is_write for r in rets
)
is_inplace_view = "inplace_view" in self.tags
# See Note [resize_ in Functionalization] for more dtails
is_inplace_view = (
"inplace_view" in self.tags and str(self.func.name) != "resize_"
)
is_wildcard_view = any(
inp.annotation is not None and inp.annotation.alias_set_after != ""
for inp in self.func.schema_order_arguments()

View file

@ -304,9 +304,7 @@ def add_generated_native_functions(
# Don't bother generating functions trio's for native functions that bypass the dispatcher.
are_manual = all(f.manual_cpp_binding for f in d.values())
# Don't bother generating functional + out= variants for view operators
has_view_ops = (
has_inplace and "inplace_view" in d[SchemaKind.inplace].tags
) or any(f.is_view_op for f in d.values())
has_view_ops = any(f.is_view_op for f in d.values())
# Don't generate the other variants for CompositeImplicitAutograd operators.
# We could probably do this, but the main benefit of generating the function triplets
# is for transforms that need them, and transforms don't need to act directly