[auto_functionalized] Support Tensor(a!)[]? (#145400)

Summary:
This updates the auto_functionalized type checks to let the Tensor(a!)[]?
type (an optional list of tensors that the op mutates in place) through.

Fixes #144072
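
For context, Tensor(a!)[]? declares an optional list of tensors, each of
which the op may mutate in place. A minimal sketch of an op using this
schema, mirroring the new test (the name mylib::mysin is illustrative):

    import torch

    # out_list has type Tensor(a!)[]?: an optional list of tensors,
    # each of which the op is allowed to mutate in place.
    @torch.library.custom_op(
        "mylib::mysin",
        mutates_args=["out_list"],
        schema="(Tensor x, Tensor(a!)[]? out_list) -> Tensor",
    )
    def mysin(x, out_list) -> torch.Tensor:
        r = x.sin()
        if out_list is not None:
            out_list[0].copy_(r)  # mutation is visible through the list
        return r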

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145400
Approved by: https://github.com/laithsakka
rzou 2025-02-04 10:58:48 -08:00 committed by PyTorch MergeBot
parent 282d185ec1
commit 1bb977a2a4
3 changed files with 53 additions and 33 deletions

test/inductor/test_torchinductor.py

@@ -8493,6 +8493,39 @@ class CommonTemplate:
             f(*args)
             self.assertEqual(cloned_args, args)
 
+    @skip_if_cpp_wrapper(
+        "Without major redesign, cpp_wrapper will not support custom ops that are "
+        "defined in Python."
+    )
+    @config.patch(implicit_fallbacks=True)
+    def test_fallback_mutable_op_list_tensor(self):
+        @torch.library.custom_op(
+            "mylib::mysin",
+            mutates_args=["out_list"],
+            schema="(Tensor x, Tensor(a!)[]? out_list) -> Tensor",
+        )
+        def mysin(x, out_list) -> torch.Tensor:
+            r = x.sin()
+            if out_list is not None:
+                out_list[0].copy_(r)
+            return r
+
+        @mysin.register_fake
+        def _(x, out_list) -> torch.Tensor:
+            return torch.empty_like(x)
+
+        def fn(x):
+            x = x * 3
+            s = [torch.empty_like(x)]
+            x = mysin(x, s)
+            x = x / 3
+            return x, s[0]
+
+        x = torch.randn(3, requires_grad=False)
+        expected = fn(x)
+        result = torch.compile(fn, fullgraph=True)(x)
+        self.assertEqual(result, expected)
+
     @config.patch(implicit_fallbacks=True)
     def test_fallback_mutable_op_with_return(self):
         with torch.library._scoped_library("mylib", "FRAGMENT") as m:
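
The new test can be run on its own; assuming it lands in
test/inductor/test_torchinductor.py (where the surrounding CommonTemplate
methods live), e.g.:

    python -m pytest test/inductor/test_torchinductor.py -k test_fallback_mutable_op_list_tensor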

torch/_higher_order_ops/auto_functionalize.py

@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from typing import Any, Optional, Union
 
 import torch
+import torch._library.utils as library_utils
 import torch.utils._pytree as pytree
 from torch import Tensor
 from torch._C import DispatchKey
@@ -194,17 +195,17 @@ def write_view_information_to_args(
     for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
         arg = kwargs[arg_name]
-        if isinstance(arg_type, torch.ListType):
+        if library_utils.is_tensorlist_like_type(arg_type):
+            if arg is None:
+                kwargs[f"_{arg_name}_length"] = None
+            else:
+                kwargs[f"_{arg_name}_length"] = len(arg)
+                for i, elem in enumerate(arg):
+                    write_single_view(
+                        f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i]
+                    )
-            kwargs[f"_{arg_name}_length"] = len(arg)
-            for i, elem in enumerate(arg):
-                write_single_view(
-                    f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i]
-                )
-        elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
+        elif library_utils.is_tensor_like_type(arg_type):
             write_single_view(
                 f"_{arg_name}",
                 kwargs[arg_name],
@@ -257,7 +258,7 @@ def read_view_information_from_args(
     args_view_info: dict[str, Any] = {}
     for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
-        if isinstance(arg_type, torch.ListType):
+        if library_utils.is_tensorlist_like_type(arg_type):
             length = get_arg(f"_{arg_name}_length")
             if length is None:
                 # The whole list is None.
@@ -267,7 +268,7 @@ def read_view_information_from_args(
                 read_single_view(f"_{arg_name}_{i}") for i in range(length)
             ]
-        elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
+        elif library_utils.is_tensor_like_type(arg_type):
             args_view_info[arg_name] = read_single_view(f"_{arg_name}")
         else:
             raise RuntimeError(f"Unsupported type {arg_type}")
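
For intuition, the serialization above flattens a Tensor(a!)[]? argument
into scalar kwargs; roughly, for an argument named out_list (names
illustrative):

    # out_list = [t0, t1] -> kwargs["_out_list_length"] = 2, plus one
    #                        "_out_list_{i}" view entry per element
    # out_list = None     -> kwargs["_out_list_length"] = None
    #                        (the new case: the whole list is absent)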
@@ -382,20 +383,10 @@ def can_auto_functionalize(op: OperatorBase) -> bool:
             continue
         if not arg.alias_info.is_write:
             continue
-        if type(arg.type) is torch.TensorType:
+        if torch._library.utils.is_tensor_like_type(arg.type):
             continue
-        if (
-            type(arg.type) is torch.OptionalType
-            and type(arg.type.getElementType()) is torch.TensorType
-        ):
+        if torch._library.utils.is_tensorlist_like_type(arg.type):
             continue
-        if (
-            type(arg.type) is torch.ListType
-            and type(arg.type.getElementType()) is torch.TensorType
-        ):
-            continue
-        # Not yet supported: other Tensor types. This includes things like
-        # Tensor?[], Tensor[]?.
         return False
     if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType):
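
The two helpers carry the actual change: they accept the optional variants
that the old exact-type checks rejected. A rough sketch of the semantics
they are expected to have (not the actual torch._library.utils code):

    import torch

    def is_tensor_like_type(typ):
        # Tensor or Tensor?
        if isinstance(typ, torch.OptionalType):
            typ = typ.getElementType()
        return isinstance(typ, torch.TensorType)

    def is_tensorlist_like_type(typ):
        # Tensor[], Tensor?[], and optional lists such as Tensor(a!)[]?
        if isinstance(typ, torch.OptionalType):
            typ = typ.getElementType()
        if not isinstance(typ, torch.ListType):
            return False
        elem = typ.getElementType()
        if isinstance(elem, torch.OptionalType):
            elem = elem.getElementType()
        return isinstance(elem, torch.TensorType)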

torch/_inductor/ir.py

@@ -31,6 +31,7 @@ import sympy
 from sympy import Expr, Integer, Symbol
 
 import torch._export.serde.schema as export_schema
+import torch._library.utils as library_utils
 import torch._logging
 import torch.fx
 import torch.utils._pytree as pytree
@@ -6382,13 +6383,7 @@ class FallbackKernel(ExternKernelAlloc):
             # Assertions to make sure we didn't mismatch args
             if isinstance(info.type, torch.ListType):
                 assert isinstance(arg, (list, tuple))
-            is_optional_tensor = isinstance(
-                info.type, torch.OptionalType
-            ) and isinstance(info.type.getElementType(), torch.TensorType)
-            is_list_tensor = isinstance(info.type, torch.ListType) and isinstance(
-                info.type.getElementType(), torch.TensorType
-            )
-            if is_optional_tensor or isinstance(info.type, torch.TensorType):
+            if library_utils.is_tensor_like_type(info.type):
                 # PyTorch also accepts None and scalar types for args marked as "Tensor".
                 # We're not going to check all of them here.
                 assert not isinstance(arg, (tuple, list))
@@ -6405,11 +6400,12 @@ class FallbackKernel(ExternKernelAlloc):
                     MutationOutput(NoneLayout(device=t.get_device()), t, self)
                 )
 
-            if is_list_tensor:
-                for tensor_arg in arg:
-                    add_alias(tensor_arg)
+            if library_utils.is_tensorlist_like_type(info.type):
+                if arg is not None:
+                    for optional_tensor_arg in arg:
+                        add_alias(optional_tensor_arg)
             else:
-                assert isinstance(info.type, torch.TensorType) or is_optional_tensor
+                assert library_utils.is_tensor_like_type(info.type)
                 add_alias(arg)
 
         for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
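
As a usage sketch, the arg is not None guard in the new tensorlist branch
covers calls that omit the optional list entirely (hypothetical snippet,
reusing the mylib::mysin op registered in the test above):

    import torch

    @torch.compile(fullgraph=True)
    def g(x):
        # None is a valid value for a Tensor(a!)[]? argument; the fallback
        # must not iterate the list in that case, hence the None check.
        return torch.ops.mylib.mysin(x, None)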