diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 5e92eaeb4e6..94ec64b1356 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -8493,6 +8493,39 @@ class CommonTemplate: f(*args) self.assertEqual(cloned_args, args) + @skip_if_cpp_wrapper( + "Without major redesign, cpp_wrapper will not support custom ops that are " + "defined in Python." + ) + @config.patch(implicit_fallbacks=True) + def test_fallback_mutable_op_list_tensor(self): + @torch.library.custom_op( + "mylib::mysin", + mutates_args=["out_list"], + schema="(Tensor x, Tensor(a!)[]? out_list) -> Tensor", + ) + def mysin(x, out_list) -> torch.Tensor: + r = x.sin() + if out_list is not None: + out_list[0].copy_(r) + return r + + @mysin.register_fake + def _(x, out_list) -> torch.Tensor: + return torch.empty_like(x) + + def fn(x): + x = x * 3 + s = [torch.empty_like(x)] + x = mysin(x, s) + x = x / 3 + return x, s[0] + + x = torch.randn(3, requires_grad=False) + expected = fn(x) + result = torch.compile(fn, fullgraph=True)(x) + self.assertEqual(result, expected) + @config.patch(implicit_fallbacks=True) def test_fallback_mutable_op_with_return(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index 0943f8030e2..69beab06665 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Any, Optional, Union import torch +import torch._library.utils as library_utils import torch.utils._pytree as pytree from torch import Tensor from torch._C import DispatchKey @@ -194,17 +195,17 @@ def write_view_information_to_args( for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): arg = kwargs[arg_name] - if isinstance(arg_type, torch.ListType): + if library_utils.is_tensorlist_like_type(arg_type): if arg is None: kwargs[f"_{arg_name}_length"] = None + else: + kwargs[f"_{arg_name}_length"] = len(arg) + for i, elem in enumerate(arg): + write_single_view( + f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i] + ) - kwargs[f"_{arg_name}_length"] = len(arg) - for i, elem in enumerate(arg): - write_single_view( - f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i] - ) - - elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)): + elif library_utils.is_tensor_like_type(arg_type): write_single_view( f"_{arg_name}", kwargs[arg_name], @@ -257,7 +258,7 @@ def read_view_information_from_args( args_view_info: dict[str, Any] = {} for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): - if isinstance(arg_type, torch.ListType): + if library_utils.is_tensorlist_like_type(arg_type): length = get_arg(f"_{arg_name}_length") if length is None: # The whole list is None. @@ -267,7 +268,7 @@ def read_view_information_from_args( read_single_view(f"_{arg_name}_{i}") for i in range(length) ] - elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)): + elif library_utils.is_tensor_like_type(arg_type): args_view_info[arg_name] = read_single_view(f"_{arg_name}") else: raise RuntimeError(f"Unsupported type {arg_type}") @@ -382,20 +383,10 @@ def can_auto_functionalize(op: OperatorBase) -> bool: continue if not arg.alias_info.is_write: continue - if type(arg.type) is torch.TensorType: + if torch._library.utils.is_tensor_like_type(arg.type): continue - if ( - type(arg.type) is torch.OptionalType - and type(arg.type.getElementType()) is torch.TensorType - ): + if torch._library.utils.is_tensorlist_like_type(arg.type): continue - if ( - type(arg.type) is torch.ListType - and type(arg.type.getElementType()) is torch.TensorType - ): - continue - # Not yet supported: other Tensor types. This includes things like - # Tensor?[], Tensor[]?. return False if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 1114ed5bc86..6a38f92e46e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -31,6 +31,7 @@ import sympy from sympy import Expr, Integer, Symbol import torch._export.serde.schema as export_schema +import torch._library.utils as library_utils import torch._logging import torch.fx import torch.utils._pytree as pytree @@ -6382,13 +6383,7 @@ class FallbackKernel(ExternKernelAlloc): # Assertions to make sure we didn't mismatch args if isinstance(info.type, torch.ListType): assert isinstance(arg, (list, tuple)) - is_optional_tensor = isinstance( - info.type, torch.OptionalType - ) and isinstance(info.type.getElementType(), torch.TensorType) - is_list_tensor = isinstance(info.type, torch.ListType) and isinstance( - info.type.getElementType(), torch.TensorType - ) - if is_optional_tensor or isinstance(info.type, torch.TensorType): + if library_utils.is_tensor_like_type(info.type): # PyTorch also accepts None and scalar types for args marked as "Tensor". # We're not going to check all of them here. assert not isinstance(arg, (tuple, list)) @@ -6405,11 +6400,12 @@ class FallbackKernel(ExternKernelAlloc): MutationOutput(NoneLayout(device=t.get_device()), t, self) ) - if is_list_tensor: - for tensor_arg in arg: - add_alias(tensor_arg) + if library_utils.is_tensorlist_like_type(info.type): + if arg is not None: + for optional_tensor_arg in arg: + add_alias(optional_tensor_arg) else: - assert isinstance(info.type, torch.TensorType) or is_optional_tensor + assert library_utils.is_tensor_like_type(info.type) add_alias(arg) for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):