pytorch/torch/_dynamo/create_parameter_op.py
Boyuan Feng 3ef031909f [Donated Buffer] support metadata mutation ops (#141308)
### Background:

`set_(x, y)` (i.e. `torch.ops.aten.set_.source_Tensor`) changes the untyped storage of `x` to be the same as that of `y`.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

x1 = torch.ones(2,3)
y1 = torch.ones(2,3)
z1 = torch.ops.aten.set_.source_Tensor(x1, y1)

fake_tensor_mode = FakeTensorMode()
x2 = fake_tensor_mode.from_tensor(torch.ones(2,3))
y2 = fake_tensor_mode.from_tensor(torch.ones(2,3))
z2 = torch.ops.aten.set_.source_Tensor(x2, y2)

print(f"x1: {x1.untyped_storage()._cdata}, y1: {y1.untyped_storage()._cdata}, z1: {z1.untyped_storage()._cdata}")
print(f"x2: {x2.untyped_storage()._cdata}, y2: {y2.untyped_storage()._cdata}, z2: {z2.untyped_storage()._cdata}")
# x1: 99973024, y1: 99973024, z1: 99973024
# x2: 112107232, y2: 112107232, z2: 112107232
```
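
As a quick eager check of the storage swap, an in-place write to `y` after `set_` is visible when reading `x`:

```python
import torch

x = torch.zeros(2, 3)
y = torch.ones(2, 3)
torch.ops.aten.set_.source_Tensor(x, y)

y.add_(1.0)  # mutate y in place; x shares y's storage now
assert torch.equal(x, torch.full((2, 3), 2.0))  # x sees the write
```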

### Error before this diff

Consider this example:
```python
import torch

def fn(x):
    p = torch.nn.Parameter(x + 123)
    return p, p.sin()

opt = torch.compile(fn, fullgraph=True)
x = torch.ones(16, device="cuda", requires_grad=True)

p, r = opt(x)
r.sum().backward()
```

When running with `TORCH_LOGS=aot`, we see a `set_` node in the graph:
```
def forward(self, primals_1: "f32[16][1]cuda:0", primals_2: "f32[16][1]cuda:0"):
   # File: /home/boyuan/playground/inductor/donated_buffer.py:4 in fn, code: p = torch.nn.Parameter(x + 123)
  add: "f32[16][1]cuda:0" = torch.ops.aten.add.Tensor(primals_1, 123);  primals_1 = None

   # File: /home/boyuan/playground/inductor/donated_buffer.py:5 in fn, code: return p, p.sin()
  sin: "f32[16][1]cuda:0" = torch.ops.aten.sin.default(add)

  # No stacktrace found for following nodes
  set_: "f32[16][1]cuda:0" = torch.ops.aten.set_.source_Tensor(primals_2, add);  primals_2 = set_ = None
  return (sin, add)
```

`set_: "f32[16][1]cuda:0" = torch.ops.aten.set_.source_Tensor(primals_2, add)` should change the storage of `primals_2` to be the same as that of `add`. However, this was not true before this diff: the `meta['val']` entries of `set_`, `add`, and `primals_2` all had different `untyped_storage()`s.

This also leads to an error with donated buffers (#130580), which detect aliasing by comparing untyped storages. Since `add` and `primals_2` have different untyped storages (which is wrong), `add` is wrongly marked as a donated buffer.
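
The aliasing test that donated-buffer analysis relies on can be illustrated with a minimal sketch (`is_alias` is a hypothetical helper for illustration; the actual pass compares storage weak refs):

```python
import torch

def is_alias(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Two tensors are treated as aliases iff they share an untyped storage,
    # so a stale storage on meta["val"] breaks this classification.
    return a.untyped_storage()._cdata == b.untyped_storage()._cdata

base = torch.ones(4)
view = base[:2]        # a view aliases its base
fresh = torch.ones(4)  # a fresh tensor does not
assert is_alias(base, view)
assert not is_alias(base, fresh)
```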

### Root Cause

During tracing, we have `args`, `kwargs`, and `out`, along with `proxy_args`, `proxy_kwargs`, and `proxy_out`.

We use `args` and `kwargs` to compute `out = func(*args, **kwargs)` ([here](https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py#L912)). Later, we attach `out` to its proxy, essentially calling `proxy_out.node.meta["val"] = out.detach()`.

Because of the detach, when `func` is `torch.ops.aten.set_`, the storage change happens on `args` but not on `proxy_args.node.meta["val"]`. I reproduced this behavior of `detach` in eager code:

```python
import torch

x = torch.ones(2,3)
x_detach = x.detach()
y = torch.ones(2,3)
z = torch.ops.aten.set_.source_Tensor(x_detach, y)

print(f"x: {x.untyped_storage()._cdata}, x_detach: {x_detach.untyped_storage()._cdata}, y: {y.untyped_storage()._cdata}, z: {z.untyped_storage()._cdata}")
# x: 97023632, x_detach: 97026480, y: 97026480, z: 97026480
```

To fix the issue, this PR manually resets `node.meta["val"]` when the storage has changed.
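
The idea behind the fix can be sketched in isolation (this is not the actual `proxy_tensor` code; `run_and_refresh` and `cached_vals` are hypothetical stand-ins for the tracer and the nodes' `meta["val"]` entries):

```python
import torch

def run_and_refresh(func, arg, source, cached_vals):
    out = func(arg, source)
    # The cached detached copy kept its old storage, so a storage swap
    # on `arg` (e.g. from aten.set_) makes the cache stale; refresh it.
    if cached_vals["arg"].untyped_storage()._cdata != arg.untyped_storage()._cdata:
        cached_vals["arg"] = arg.detach()
    return out

x = torch.ones(2, 3)
y = torch.zeros(2, 3)
cached = {"arg": x.detach()}  # simulates node.meta["val"] for x
run_and_refresh(torch.ops.aten.set_.source_Tensor, x, y, cached)

# After the refresh, the cached value aliases the new storage again.
assert cached["arg"].untyped_storage()._cdata == y.untyped_storage()._cdata
```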

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141308
Approved by: https://github.com/bdhirsh
2024-11-26 17:06:46 +00:00


```python
import threading
from contextlib import contextmanager
from typing import Any, Generator, Tuple

import torch

# See [Note: Metadata mutation in proxy tracing] for why sacrificial parameter mutates
# metadata during proxy tracing and we should remove the sacrificial parameter logic.
doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()


class TracableCreateParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter:
        assert not tensor.requires_grad
        return placeholder.set_(tensor)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: torch.Tensor) -> Tuple[None, torch.Tensor]:
        grad = grad_outputs[0]
        return None, grad  # grad flows to placeholder


def tracable_create_parameter(
    tensor: torch.Tensor, placeholder: torch.nn.Parameter
) -> torch.nn.Parameter:
    with torch.set_grad_enabled(placeholder.requires_grad):
        out = TracableCreateParameter.apply(tensor, placeholder)
    return out


def new_parameter_placeholder(
    size: Tuple[int, ...], dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> torch.nn.Parameter:
    """Create a placeholder to be passed to the above functions"""
    result = torch.nn.Parameter(
        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
    )
    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
    # Allocating a zero tensor would cause assert failures in autograd.
    result.untyped_storage().resize_(0)
    return result


_TLS = threading.local()


@contextmanager
def do_not_convert_to_tracable_parameter() -> Generator[bool, None, None]:
    old_flag = getattr(_TLS, "convert_tracable_parameter", True)
    _TLS.convert_tracable_parameter = False
    try:
        yield False
    finally:
        _TLS.convert_tracable_parameter = old_flag


def can_convert_to_tracable_parameter() -> bool:
    return getattr(_TLS, "convert_tracable_parameter", True)
```
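
To see what the sacrificial placeholder looks like, here is a hedged standalone re-creation of `new_parameter_placeholder` (`make_placeholder` is a local copy for illustration): the result keeps its shape/dtype metadata but carries zero bytes of storage.

```python
import torch

def make_placeholder(size, dtype, device, requires_grad):
    # Same recipe as new_parameter_placeholder above: allocate a
    # Parameter, then immediately free its backing storage.
    result = torch.nn.Parameter(
        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
    )
    result.untyped_storage().resize_(0)
    return result

p = make_placeholder((16,), torch.float32, torch.device("cpu"), True)
assert p.untyped_storage().size() == 0  # no bytes backing it
assert p.shape == (16,) and p.requires_grad  # metadata preserved
```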