mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[auto_functionalized] Support Tensor(a!)[]? (#145400)
Summary: This is just updating some of the checks to allow the Tensor(a!)[]? type through. Fixes #144072 Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/145400 Approved by: https://github.com/laithsakka
This commit is contained in:
parent
282d185ec1
commit
1bb977a2a4
3 changed files with 53 additions and 33 deletions
|
|
@ -8493,6 +8493,39 @@ class CommonTemplate:
|
|||
f(*args)
|
||||
self.assertEqual(cloned_args, args)
|
||||
|
||||
@skip_if_cpp_wrapper(
|
||||
"Without major redesign, cpp_wrapper will not support custom ops that are "
|
||||
"defined in Python."
|
||||
)
|
||||
@config.patch(implicit_fallbacks=True)
|
||||
def test_fallback_mutable_op_list_tensor(self):
|
||||
@torch.library.custom_op(
|
||||
"mylib::mysin",
|
||||
mutates_args=["out_list"],
|
||||
schema="(Tensor x, Tensor(a!)[]? out_list) -> Tensor",
|
||||
)
|
||||
def mysin(x, out_list) -> torch.Tensor:
|
||||
r = x.sin()
|
||||
if out_list is not None:
|
||||
out_list[0].copy_(r)
|
||||
return r
|
||||
|
||||
@mysin.register_fake
|
||||
def _(x, out_list) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
def fn(x):
|
||||
x = x * 3
|
||||
s = [torch.empty_like(x)]
|
||||
x = mysin(x, s)
|
||||
x = x / 3
|
||||
return x, s[0]
|
||||
|
||||
x = torch.randn(3, requires_grad=False)
|
||||
expected = fn(x)
|
||||
result = torch.compile(fn, fullgraph=True)(x)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@config.patch(implicit_fallbacks=True)
|
||||
def test_fallback_mutable_op_with_return(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as m:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch._library.utils as library_utils
|
||||
import torch.utils._pytree as pytree
|
||||
from torch import Tensor
|
||||
from torch._C import DispatchKey
|
||||
|
|
@ -194,17 +195,17 @@ def write_view_information_to_args(
|
|||
|
||||
for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
|
||||
arg = kwargs[arg_name]
|
||||
if isinstance(arg_type, torch.ListType):
|
||||
if library_utils.is_tensorlist_like_type(arg_type):
|
||||
if arg is None:
|
||||
kwargs[f"_{arg_name}_length"] = None
|
||||
else:
|
||||
kwargs[f"_{arg_name}_length"] = len(arg)
|
||||
for i, elem in enumerate(arg):
|
||||
write_single_view(
|
||||
f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i]
|
||||
)
|
||||
|
||||
kwargs[f"_{arg_name}_length"] = len(arg)
|
||||
for i, elem in enumerate(arg):
|
||||
write_single_view(
|
||||
f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i]
|
||||
)
|
||||
|
||||
elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
|
||||
elif library_utils.is_tensor_like_type(arg_type):
|
||||
write_single_view(
|
||||
f"_{arg_name}",
|
||||
kwargs[arg_name],
|
||||
|
|
@ -257,7 +258,7 @@ def read_view_information_from_args(
|
|||
|
||||
args_view_info: dict[str, Any] = {}
|
||||
for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
|
||||
if isinstance(arg_type, torch.ListType):
|
||||
if library_utils.is_tensorlist_like_type(arg_type):
|
||||
length = get_arg(f"_{arg_name}_length")
|
||||
if length is None:
|
||||
# The whole list is None.
|
||||
|
|
@ -267,7 +268,7 @@ def read_view_information_from_args(
|
|||
read_single_view(f"_{arg_name}_{i}") for i in range(length)
|
||||
]
|
||||
|
||||
elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
|
||||
elif library_utils.is_tensor_like_type(arg_type):
|
||||
args_view_info[arg_name] = read_single_view(f"_{arg_name}")
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported type {arg_type}")
|
||||
|
|
@ -382,20 +383,10 @@ def can_auto_functionalize(op: OperatorBase) -> bool:
|
|||
continue
|
||||
if not arg.alias_info.is_write:
|
||||
continue
|
||||
if type(arg.type) is torch.TensorType:
|
||||
if torch._library.utils.is_tensor_like_type(arg.type):
|
||||
continue
|
||||
if (
|
||||
type(arg.type) is torch.OptionalType
|
||||
and type(arg.type.getElementType()) is torch.TensorType
|
||||
):
|
||||
if torch._library.utils.is_tensorlist_like_type(arg.type):
|
||||
continue
|
||||
if (
|
||||
type(arg.type) is torch.ListType
|
||||
and type(arg.type.getElementType()) is torch.TensorType
|
||||
):
|
||||
continue
|
||||
# Not yet supported: other Tensor types. This includes things like
|
||||
# Tensor?[], Tensor[]?.
|
||||
return False
|
||||
|
||||
if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType):
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import sympy
|
|||
from sympy import Expr, Integer, Symbol
|
||||
|
||||
import torch._export.serde.schema as export_schema
|
||||
import torch._library.utils as library_utils
|
||||
import torch._logging
|
||||
import torch.fx
|
||||
import torch.utils._pytree as pytree
|
||||
|
|
@ -6382,13 +6383,7 @@ class FallbackKernel(ExternKernelAlloc):
|
|||
# Assertions to make sure we didn't mismatch args
|
||||
if isinstance(info.type, torch.ListType):
|
||||
assert isinstance(arg, (list, tuple))
|
||||
is_optional_tensor = isinstance(
|
||||
info.type, torch.OptionalType
|
||||
) and isinstance(info.type.getElementType(), torch.TensorType)
|
||||
is_list_tensor = isinstance(info.type, torch.ListType) and isinstance(
|
||||
info.type.getElementType(), torch.TensorType
|
||||
)
|
||||
if is_optional_tensor or isinstance(info.type, torch.TensorType):
|
||||
if library_utils.is_tensor_like_type(info.type):
|
||||
# PyTorch also accepts None and scalar types for args marked as "Tensor".
|
||||
# We're not going to check all of them here.
|
||||
assert not isinstance(arg, (tuple, list))
|
||||
|
|
@ -6405,11 +6400,12 @@ class FallbackKernel(ExternKernelAlloc):
|
|||
MutationOutput(NoneLayout(device=t.get_device()), t, self)
|
||||
)
|
||||
|
||||
if is_list_tensor:
|
||||
for tensor_arg in arg:
|
||||
add_alias(tensor_arg)
|
||||
if library_utils.is_tensorlist_like_type(info.type):
|
||||
if arg is not None:
|
||||
for optional_tensor_arg in arg:
|
||||
add_alias(optional_tensor_arg)
|
||||
else:
|
||||
assert isinstance(info.type, torch.TensorType) or is_optional_tensor
|
||||
assert library_utils.is_tensor_like_type(info.type)
|
||||
add_alias(arg)
|
||||
|
||||
for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
|
||||
|
|
|
|||
Loading…
Reference in a new issue