return_and_correct_aliasing: massage some schemas to work with torchgen (#108897)

The issue is that `str(torch.ops.aten.conv2d.default._schema)` does not return the same schema that is in native_functions.yaml ([link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L1654)).

Torchscript appears to change the default arg string `int[2] strides=1` to `int[2] strides=[1, 1]`. If you try to parse that with torchgen, torchgen is unhappy (it tries to split arguments on commas, but now there is a comma inside the default argument).

Fixing the issue directly in torchgen was a bit more painful, so I opted just to undo the transformation that torchscript made: convert `=[1, 1]` back into `=1`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108897
Approved by: https://github.com/ezyang
ghstack dependencies: #106404, #107917
This commit is contained in:
Brian Hirsh 2023-09-15 09:58:21 -07:00 committed by PyTorch MergeBot
parent 0ad595954a
commit 71b4b32014
2 changed files with 33 additions and 21 deletions

View file

@ -2249,6 +2249,16 @@ class TestWrapperSubclassAliasing(TestCase):
kwargs = sample.kwargs
self._test_wrapper_subclass_aliasing(op, args, kwargs)
def test_wrapper_subclass_aliasing_conv2d(self, device):
args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4))
kwargs = {}
# conv2d has a default arg 'int[2] strides=0',
# which torchscript expands into 'int[2] strides=[0, 0]'
# Make sure that _return_and_correct_aliasing can handle this case
# (I'm using inference_mode to make sure conv2d doesn't decompose and goes to torch_dispatch)
with torch.inference_mode():
self._test_wrapper_subclass_aliasing(torch.ops.aten.conv2d.default, args, kwargs)
instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())
if __name__ == '__main__':

View file

@ -192,7 +192,9 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
# plain tensors, we could remove the assert and just not perform the aliasing,
# but it seems safer to learn more about this case first.
if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret):
assert type(arg) == type(ret), f"""Called {str(func)} with input of type {type(arg)}
ret_list = ret if isinstance(ret, list) else [ret]
for r in ret_list:
assert type(arg) == type(r), f"""Called {str(func)} with input of type {type(arg)}
and output of type {type(ret)}. But expected types to match."""
# Need to run under no_dispatch, because we explicitly do **not**
# want our subclass to intercept the set_() call.
@ -211,7 +213,12 @@ and output of type {type(ret)}. But expected types to match."""
# Example: out = inp.expand(inp.shape[0], inp.shape[0])
# This requires swapping the storage of out to be the same as inp,
# but we do *not* want it to change the sizes/strides that were compute for out.
torch.ops.aten.set_.source_Storage_storage_offset(ret, arg.untyped_storage(), ret.storage_offset(), ret.shape)
if isinstance(ret, list):
for r in ret:
torch.ops.aten.set_.source_Storage_storage_offset(r, arg.untyped_storage(), r.storage_offset(), r.shape)
else:
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
torch.ops.aten.set_.source_Storage_storage_offset(ret, arg.untyped_storage(), ret.storage_offset(), ret.shape)
finally:
torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
@ -226,23 +233,6 @@ and output of type {type(ret)}. But expected types to match."""
if is_read_only_alias_match(schema_info.args[arg_idx], schema_info.outs[return_idx]):
alias_non_inplace_storage(args[arg_idx], outs[return_idx])
# Sigh... the torchscript parser has a bug where alias annotations for Tensor[](a) don't show up properly
# See https://github.com/pytorch/pytorch/issues/106173
if func.overloadpacket in [
torch.ops.aten.chunk,
torch.ops.aten.tensor_split,
torch.ops.aten.split,
torch.ops.aten.split_with_sizes,
torch.ops.aten.hsplit,
torch.ops.aten.vsplit,
torch.ops.aten.dsplit,
torch.ops.aten.unbind,
]:
assert isinstance(outs, list) and all(isinstance(x, torch.Tensor) for x in outs)
for o in outs:
# For lists of outputs, need to alias every individual tensor to the input
alias_non_inplace_storage(args[0], o)
# This abstracts over the fact that in return_and_correct_aliasing,
# we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy),
# and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested).
@ -267,7 +257,19 @@ def get_alias_info(func) -> SchemaInfo:
# For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
# properly for some ops that output tensorlists)
if func.namespace == "aten":
torchgen_schema = torchgen.model.FunctionSchema.parse(str(func._schema))
torchgen_schema_str = str(func._schema)
assert torchgen_schema_str.startswith("aten::")
# remove the aten:: namespace, which is added by the torchscript parser,
# and torchgen doesn't know how to handle
torchgen_schema_str = torchgen_schema_str[6:]
import re
# the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1],
# which torchgen chokes on.
torchgen_schema_str = re.sub(r'=\[[0, ]+\]', '=0', torchgen_schema_str)
torchgen_schema_str = re.sub(r'=\[[1, ]+\]', '=1', torchgen_schema_str)
# for aten::rot90
torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]")
torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
arg_schemas = [AliasInfo(
alias_set=set() if a.annotation is None else set(a.annotation.alias_set),
is_write=a.annotation is not None and a.annotation.is_write
@ -331,7 +333,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
# Fix up the storages of any outs so that they point to the same storage as the input,
# if func is a view op.
_correct_storage_aliasing(func, schema_info, args, [out] if not isinstance(out, (list, tuple)) else out)
_correct_storage_aliasing(func, schema_info, args, (out,) if not isinstance(out, tuple) else out)
# For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
# metadata is set correctly.