mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
return_and_correct_aliasing: massage some schemas to work with torchgen (#108897)
This issue is that `str(torch.ops.aten.conv2d.default._schema)` does not return the same schema that is in native_functions.yaml ([link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L1654)). Torchscript appears to change the default arg string `int[2] strides=1` to `int[2] strides=[1, 1]`. If you try to parse that with torchgen, torchgen is unhappy (it tries to split arguments on comma, but now we have a comma inside of the default argument). Fixing the issue directly in torchgen was a bit more painful, so I opted just to undo the transformation that torchscript made: convert `=[1, 1]` back into `=1`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108897 Approved by: https://github.com/ezyang ghstack dependencies: #106404, #107917
This commit is contained in:
parent
0ad595954a
commit
71b4b32014
2 changed files with 33 additions and 21 deletions
|
|
@ -2249,6 +2249,16 @@ class TestWrapperSubclassAliasing(TestCase):
|
|||
kwargs = sample.kwargs
|
||||
self._test_wrapper_subclass_aliasing(op, args, kwargs)
|
||||
|
||||
def test_wrapper_subclass_aliasing_conv2d(self, device):
|
||||
args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4))
|
||||
kwargs = {}
|
||||
# conv2d has a default arg 'int[2] strides=0',
|
||||
# which torchscript expands into 'int[2] strides=[0, 0]'
|
||||
# Make sure that _return_and_correct_aliasing can handle this case
|
||||
# (I'm using inference_mode to make sure conv2d doesn't decompose and goes to torch_dispatch)
|
||||
with torch.inference_mode():
|
||||
self._test_wrapper_subclass_aliasing(torch.ops.aten.conv2d.default, args, kwargs)
|
||||
|
||||
instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -192,7 +192,9 @@ def _correct_storage_aliasing(func, schema_info, args, outs):
|
|||
# plain tensors, we could remove the assert and just not perform the aliasing,
|
||||
# but it seems safer to learn more about this case first.
|
||||
if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret):
|
||||
assert type(arg) == type(ret), f"""Called {str(func)} with input of type {type(arg)}
|
||||
ret_list = ret if isinstance(ret, list) else [ret]
|
||||
for r in ret_list:
|
||||
assert type(arg) == type(r), f"""Called {str(func)} with input of type {type(arg)}
|
||||
and output of type {type(ret)}. But expected types to match."""
|
||||
# Need to run under no_dispatch, because we explicitly do **not**
|
||||
# want our subclass to intercept the set_() call.
|
||||
|
|
@ -211,7 +213,12 @@ and output of type {type(ret)}. But expected types to match."""
|
|||
# Example: out = inp.expand(inp.shape[0], inp.shape[0])
|
||||
# This requires swapping the storage of out to be the same as inp,
|
||||
# but we do *not* want it to change the sizes/strides that were compute for out.
|
||||
torch.ops.aten.set_.source_Storage_storage_offset(ret, arg.untyped_storage(), ret.storage_offset(), ret.shape)
|
||||
if isinstance(ret, list):
|
||||
for r in ret:
|
||||
torch.ops.aten.set_.source_Storage_storage_offset(r, arg.untyped_storage(), r.storage_offset(), r.shape)
|
||||
else:
|
||||
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
|
||||
torch.ops.aten.set_.source_Storage_storage_offset(ret, arg.untyped_storage(), ret.storage_offset(), ret.shape)
|
||||
finally:
|
||||
torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
|
||||
|
||||
|
|
@ -226,23 +233,6 @@ and output of type {type(ret)}. But expected types to match."""
|
|||
if is_read_only_alias_match(schema_info.args[arg_idx], schema_info.outs[return_idx]):
|
||||
alias_non_inplace_storage(args[arg_idx], outs[return_idx])
|
||||
|
||||
# Sigh... the torchscript parser has a bug where alias annotations for Tensor[](a) don't show up properly
|
||||
# See https://github.com/pytorch/pytorch/issues/106173
|
||||
if func.overloadpacket in [
|
||||
torch.ops.aten.chunk,
|
||||
torch.ops.aten.tensor_split,
|
||||
torch.ops.aten.split,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
torch.ops.aten.hsplit,
|
||||
torch.ops.aten.vsplit,
|
||||
torch.ops.aten.dsplit,
|
||||
torch.ops.aten.unbind,
|
||||
]:
|
||||
assert isinstance(outs, list) and all(isinstance(x, torch.Tensor) for x in outs)
|
||||
for o in outs:
|
||||
# For lists of outputs, need to alias every individual tensor to the input
|
||||
alias_non_inplace_storage(args[0], o)
|
||||
|
||||
# This abstracts over the fact that in return_and_correct_aliasing,
|
||||
# we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy),
|
||||
# and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested).
|
||||
|
|
@ -267,7 +257,19 @@ def get_alias_info(func) -> SchemaInfo:
|
|||
# For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
|
||||
# properly for some ops that output tensorlists)
|
||||
if func.namespace == "aten":
|
||||
torchgen_schema = torchgen.model.FunctionSchema.parse(str(func._schema))
|
||||
torchgen_schema_str = str(func._schema)
|
||||
assert torchgen_schema_str.startswith("aten::")
|
||||
# remove the aten:: namespace, which is added by the torchscript parser,
|
||||
# and torchgen doesn't know how to handle
|
||||
torchgen_schema_str = torchgen_schema_str[6:]
|
||||
import re
|
||||
# the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1],
|
||||
# which torchgen chokes on.
|
||||
torchgen_schema_str = re.sub(r'=\[[0, ]+\]', '=0', torchgen_schema_str)
|
||||
torchgen_schema_str = re.sub(r'=\[[1, ]+\]', '=1', torchgen_schema_str)
|
||||
# for aten::rot90
|
||||
torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]")
|
||||
torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
|
||||
arg_schemas = [AliasInfo(
|
||||
alias_set=set() if a.annotation is None else set(a.annotation.alias_set),
|
||||
is_write=a.annotation is not None and a.annotation.is_write
|
||||
|
|
@ -331,7 +333,7 @@ def return_and_correct_aliasing(func, args, kwargs, out):
|
|||
|
||||
# Fix up the storages of any outs so that they point to the same storage as the input,
|
||||
# if func is a view op.
|
||||
_correct_storage_aliasing(func, schema_info, args, [out] if not isinstance(out, (list, tuple)) else out)
|
||||
_correct_storage_aliasing(func, schema_info, args, (out,) if not isinstance(out, tuple) else out)
|
||||
|
||||
# For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
|
||||
# metadata is set correctly.
|
||||
|
|
|
|||
Loading…
Reference in a new issue