mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
### Background:
`set_(x, y)` changes the untyped storage of x to be the same as that of y.
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
x1 = torch.ones(2,3)
y1 = torch.ones(2,3)
z1 = torch.ops.aten.set_.source_Tensor(x1, y1)
fake_tensor_mode = FakeTensorMode()
x2 = fake_tensor_mode.from_tensor(torch.ones(2,3))
y2 = fake_tensor_mode.from_tensor(torch.ones(2,3))
z2 = torch.ops.aten.set_.source_Tensor(x2, y2)
print(f"x1: {x1.untyped_storage()._cdata}, y1: {y1.untyped_storage()._cdata}, z1: {z1.untyped_storage()._cdata}")
print(f"x2: {x2.untyped_storage()._cdata}, y2: {y2.untyped_storage()._cdata}, z2: {z2.untyped_storage()._cdata}")
# x1: 99973024, y1: 99973024, z1: 99973024
# x2: 112107232, y2: 112107232, z2: 112107232
```
### Error before this diff
Consider this example:
```python
import torch
def fn(x):
p = torch.nn.Parameter(x + 123)
return p, p.sin()
opt = torch.compile(fn, fullgraph=True)
x = torch.ones(16, device="cuda", requires_grad=True)
p, r = opt(x)
r.sum().backward()
```
When running with `TORCH_LOGS=aot`, we have `set_` in the graph.
```
def forward(self, primals_1: "f32[16][1]cuda:0", primals_2: "f32[16][1]cuda:0"):
# File: /home/boyuan/playground/inductor/donated_buffer.py:4 in fn, code: p = torch.nn.Parameter(x + 123)
add: "f32[16][1]cuda:0" = torch.ops.aten.add.Tensor(primals_1, 123); primals_1 = None
# File: /home/boyuan/playground/inductor/donated_buffer.py:5 in fn, code: return p, p.sin()
sin: "f32[16][1]cuda:0" = torch.ops.aten.sin.default(add)
# No stacktrace found for following nodes
set_: "f32[16][1]cuda:0" = torch.ops.aten.set_.source_Tensor(primals_2, add); primals_2 = set_ = None
return (sin, add)
```
`set_: "f32[16][1]cuda:0" = torch.ops.aten.set_.source_Tensor(primals_2, add)` should change the storage of `primals_2` to be the same as that of `add`. However, this was not the case before this diff: we found different untyped_storage() values for meta['val'] of `set_`, `add`, and `primals_2`.
This also leads to an error with donated buffers (#130580), which detect aliasing by comparing untyped_storage. Since `add` and `primals_2` have different untyped_storage (which is wrong), `add` is wrongly marked as a donated buffer.
### Root Cause
During tracing, we have args, kwargs, out, and proxy_args, proxy_kwargs, proxy_out.
We use args and kwargs to compute `out = func(*args, **kwargs)` ([Here](https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/proxy_tensor.py#L912)). Later, we set out to its proxy, essentially calling `proxy_out.node.meta["val"] = out.detach()`.
Due to the detach, the storage change happens on args but not on proxy_args.node.meta["val"] when func is torch.ops.aten.set_. I repro'ed this behavior of detach in eager code.
```python
import torch
x = torch.ones(2,3)
x_detach = x.detach()
y = torch.ones(2,3)
z = torch.ops.aten.set_.source_Tensor(x_detach, y)
print(f"x: {x.untyped_storage()._cdata}, x_detach: {x_detach.untyped_storage()._cdata}, y: {y.untyped_storage()._cdata}, z: {z.untyped_storage()._cdata}")
# x: 97023632, x_detach: 97026480, y: 97026480, z: 97026480
```
To fix the issue, this PR manually resets node.meta["val"] if the storage has changed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141308
Approved by: https://github.com/bdhirsh
67 lines
2.4 KiB
Python
67 lines
2.4 KiB
Python
import threading
|
|
from contextlib import contextmanager
|
|
from typing import Any, Generator, Tuple
|
|
|
|
import torch
|
|
|
|
|
|
# The sacrificial parameter mutates metadata during proxy tracing; see
# [Note: Metadata mutation in proxy tracing] for the reasons, and for why the
# sacrificial parameter logic should be removed.
doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd.  We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it.  At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it.  This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()
|
|
|
|
|
|
class TracableCreateParameter(torch.autograd.Function):
    """Autograd function that mutates a placeholder into the given tensor.

    The forward pass points ``placeholder`` at ``tensor`` via ``set_``; the
    backward pass hands the incoming gradient to ``placeholder`` and nothing
    to ``tensor``.
    """

    @staticmethod
    def forward(ctx: Any, tensor: Any, placeholder: Any) -> torch.nn.Parameter:
        """Make ``placeholder`` share ``tensor``'s data and return the result.

        ``tensor`` must not require grad; this is asserted.
        """
        assert not tensor.requires_grad
        mutated = placeholder.set_(tensor)
        return mutated

    @staticmethod
    def backward(ctx: Any, *grad_outputs: torch.Tensor) -> Tuple[None, torch.Tensor]:
        """Return ``(None, grad)``: no gradient for ``tensor``, all of it for ``placeholder``."""
        # Only the first incoming gradient is used; it flows to the placeholder.
        return None, grad_outputs[0]
|
|
|
|
|
|
def tracable_create_parameter(
    tensor: torch.Tensor, placeholder: torch.nn.Parameter
) -> torch.nn.Parameter:
    """Apply ``TracableCreateParameter`` to ``tensor`` and ``placeholder``.

    Grad mode is set to ``placeholder.requires_grad`` for the duration of the
    call, and the value returned by ``TracableCreateParameter.apply`` is
    returned to the caller.
    """
    track_grad = placeholder.requires_grad
    with torch.set_grad_enabled(track_grad):
        return TracableCreateParameter.apply(tensor, placeholder)
|
|
|
|
|
|
def new_parameter_placeholder(
    size: Tuple[int, ...], dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> torch.nn.Parameter:
    """Build a storage-less ``torch.nn.Parameter`` to serve as a placeholder.

    The parameter is created with the requested ``size``, ``dtype``,
    ``device`` and ``requires_grad``; its untyped storage is then resized to
    zero bytes before it is returned.
    """
    backing = torch.empty(size, dtype=dtype, device=device)
    placeholder = torch.nn.Parameter(backing, requires_grad=requires_grad)
    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
    # Allocating a zero tensor would cause assert failures in autograd.
    placeholder.untyped_storage().resize_(0)
    return placeholder
|
|
|
|
|
|
# Thread-local holder for the ``convert_tracable_parameter`` flag.
_TLS = threading.local()


@contextmanager
def do_not_convert_to_tracable_parameter() -> Generator[bool, None, None]:
    """Set the thread-local ``convert_tracable_parameter`` flag to False for a ``with`` block.

    Yields ``False``. On exit, including on an exception, the flag is put
    back to the value it had on entry (``True`` if it had never been set).
    """
    previous = getattr(_TLS, "convert_tracable_parameter", True)
    _TLS.convert_tracable_parameter = False
    try:
        yield False
    finally:
        # Restore rather than reset, so nested uses unwind correctly.
        _TLS.convert_tracable_parameter = previous
|
|
|
|
|
|
def can_convert_to_tracable_parameter() -> bool:
    """Return the thread-local ``convert_tracable_parameter`` flag.

    If the flag has never been set on the current thread, ``True`` is returned.
    """
    try:
        return _TLS.convert_tracable_parameter
    except AttributeError:
        # Flag not set on this thread yet: fall back to the default.
        return True
|