From 71b4b320144c011695cf954a092f5d9a52dc1426 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 15 Sep 2023 09:58:21 -0700 Subject: [PATCH] return_and_correct_aliasing: massage some schemas to work with torchgen (#108897) This issue is that `str(torch.ops.aten.conv2d.default._schema)` does not return the same schema that is in native_functions.yaml ([link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L1654)). Torchscript appears to change the default arg string `int[2] strides=1` to `int[2] strides=[1, 1]`. If you try to parse that with torchgen, torchgen is unhappy (it tries to split arguments on comma, but now we have a comma inside of the default argument). Fixing the issue directly in torchgen was a bit more painful, so I opted just to undo the transformation that torchscript made: convert `=[1, 1]` back into `=1`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108897 Approved by: https://github.com/ezyang ghstack dependencies: #106404, #107917 --- test/test_python_dispatch.py | 10 ++++++++ torch/utils/_python_dispatch.py | 44 +++++++++++++++++---------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index e85a9ab85a9..b1716a1a165 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -2249,6 +2249,16 @@ class TestWrapperSubclassAliasing(TestCase): kwargs = sample.kwargs self._test_wrapper_subclass_aliasing(op, args, kwargs) + def test_wrapper_subclass_aliasing_conv2d(self, device): + args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4)) + kwargs = {} + # conv2d has a default arg 'int[2] strides=0', + # which torchscript expands into 'int[2] strides=[0, 0]' + # Make sure that _return_and_correct_aliasing can handle this case + # (I'm using inference_mode to make sure conv2d doesn't decompose and goes to torch_dispatch) + with torch.inference_mode(): + self._test_wrapper_subclass_aliasing(torch.ops.aten.conv2d.default, args, kwargs) + instantiate_device_type_tests(TestWrapperSubclassAliasing, globals()) if __name__ == '__main__': diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 308057fa087..3660ddcb974 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -192,7 +192,9 @@ def _correct_storage_aliasing(func, schema_info, args, outs): # plain tensors, we could remove the assert and just not perform the aliasing, # but it seems safer to learn more about this case first. if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret): - assert type(arg) == type(ret), f"""Called {str(func)} with input of type {type(arg)} + ret_list = ret if isinstance(ret, list) else [ret] + for r in ret_list: + assert type(arg) == type(r), f"""Called {str(func)} with input of type {type(arg)} and output of type {type(ret)}. But expected types to match.""" # Need to run under no_dispatch, because we explicitly do **not** # want our subclass to intercept the set_() call. @@ -211,7 +213,12 @@ and output of type {type(ret)}. But expected types to match.""" # Example: out = inp.expand(inp.shape[0], inp.shape[0]) # This requires swapping the storage of out to be the same as inp, # but we do *not* want it to change the sizes/strides that were compute for out. - torch.ops.aten.set_.source_Storage_storage_offset(ret, arg.untyped_storage(), ret.storage_offset(), ret.shape) + if isinstance(ret, list): + for r in ret: + torch.ops.aten.set_.source_Storage_storage_offset(r, arg.untyped_storage(), r.storage_offset(), r.shape) + else: + assert isinstance(ret, torch.Tensor), f"type: {type(ret)}" + torch.ops.aten.set_.source_Storage_storage_offset(ret, arg.untyped_storage(), ret.storage_offset(), ret.shape) finally: torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) @@ -226,23 +233,6 @@ and output of type {type(ret)}. But expected types to match.""" if is_read_only_alias_match(schema_info.args[arg_idx], schema_info.outs[return_idx]): alias_non_inplace_storage(args[arg_idx], outs[return_idx]) - # Sigh... the torchscript parser has a bug where alias annotations for Tensor[](a) don't show up properly - # See https://github.com/pytorch/pytorch/issues/106173 - if func.overloadpacket in [ - torch.ops.aten.chunk, - torch.ops.aten.tensor_split, - torch.ops.aten.split, - torch.ops.aten.split_with_sizes, - torch.ops.aten.hsplit, - torch.ops.aten.vsplit, - torch.ops.aten.dsplit, - torch.ops.aten.unbind, - ]: - assert isinstance(outs, list) and all(isinstance(x, torch.Tensor) for x in outs) - for o in outs: - # For lists of outputs, need to alias every individual tensor to the input - alias_non_inplace_storage(args[0], o) - # This abstracts over the fact that in return_and_correct_aliasing, # we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy), # and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested). @@ -267,7 +257,19 @@ def get_alias_info(func) -> SchemaInfo: # For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations # properly for some ops that output tensorlists) if func.namespace == "aten": - torchgen_schema = torchgen.model.FunctionSchema.parse(str(func._schema)) + torchgen_schema_str = str(func._schema) + assert torchgen_schema_str.startswith("aten::") + # remove the aten:: namespace, which is added by the torchscript parser, + # and torchgen doesn't know how to handle + torchgen_schema_str = torchgen_schema_str[6:] + import re + # the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1], + # which torchgen chokes on. + torchgen_schema_str = re.sub(r'=\[[0, ]+\]', '=0', torchgen_schema_str) + torchgen_schema_str = re.sub(r'=\[[1, ]+\]', '=1', torchgen_schema_str) + # for aten::rot90 + torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]") + torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str) arg_schemas = [AliasInfo( alias_set=set() if a.annotation is None else set(a.annotation.alias_set), is_write=a.annotation is not None and a.annotation.is_write @@ -331,7 +333,7 @@ def return_and_correct_aliasing(func, args, kwargs, out): # Fix up the storages of any outs so that they point to the same storage as the input, # if func is a view op. - _correct_storage_aliasing(func, schema_info, args, [out] if not isinstance(out, (list, tuple)) else out) + _correct_storage_aliasing(func, schema_info, args, (out,) if not isinstance(out, tuple) else out) # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's # metadata is set correctly.