Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).
Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
* Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
* Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called (see the sketch after this list)
* Addresses this: https://github.com/pytorch/pytorch/blob/6a86cf00ad/torch/_dynamo/variables/builder.py#L1750
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
* Signatures now:
```python
# attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
# ctx is anything useful for rebuilding the class we want to guard on
attrs, ctx = x.__tensor_flatten__()
...
# inner_tensors is a dict of {attr -> tensor}
# ctx is taken unmodified from flattening and (eventually) guarded on
# outer_size is the expected size of the output; possibly symbolic
# outer_stride is the expected strides of the output; possibly symbolic
y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
# at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
# the assert simplifies symbols when there are relationships between outer and inner symbols
```
* Size info is needed at least for `NestedTensor`; stride info is needed at least for `DTensor`
* Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
* Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
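A minimal sketch of the preserved `mark_dynamic()` behavior, assuming the `TwoTensor` test subclass from the file below (it lives under `torch.testing._internal.two_tensor` in the PyTorch tree); the tensor values are arbitrary:
```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = TwoTensor(torch.randn(4, 8), torch.randn(4, 8))

# Marking a dim dynamic on the outer subclass recursively marks the same dim
# on the inner tensors t.a and t.b, since they share the outer tensor's dim.
torch._dynamo.mark_dynamic(t, 0)

@torch.compile
def f(x):
    return x * 2

f(t)  # dim 0 is treated as dynamic for the outer wrapper and its inner tensors
```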
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
import torch
import torch.utils._pytree as pytree

from torch.utils._python_dispatch import return_and_correct_aliasing


# A simple tensor subclass that holds two tensors internally, and runs every op on both tensors.
class TwoTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b):
        assert (
            a.device == b.device
            and a.layout == b.layout
            and a.requires_grad == b.requires_grad
            and a.dtype == b.dtype
        )
        # I guess it would be more accurate to represent the shape as torch.cat([a, b]).shape
        shape = a.shape
        kwargs = {}
        kwargs["strides"] = a.stride()
        kwargs["storage_offset"] = a.storage_offset()
        kwargs["device"] = a.device
        kwargs["layout"] = a.layout
        kwargs["requires_grad"] = a.requires_grad
        kwargs["dtype"] = a.dtype
        # The outer wrapper carries the metadata above but allocates no storage
        # of its own; the actual data lives in the inner tensors a and b.
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)

        assert a.shape == b.shape
        assert a.stride() == b.stride()
        assert a.storage_offset() == b.storage_offset()
        return out

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __repr__(self):
        a_repr = repr(self.a)
        b_repr = repr(self.b)
        return f"TwoTensor({a_repr}, {b_repr})"

    def __tensor_flatten__(self):
        # The inner tensors live at attributes "a" and "b"; there is no extra
        # context needed to rebuild the class, so the ctx is None.
        return ["a", "b"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert meta is None
        a, b = inner_tensors["a"], inner_tensors["b"]
        # outer_size / outer_stride are unused here: TwoTensor's outer metadata
        # always matches its inner tensors', so the reconstructed tensor already
        # compares equal to the expected outer symbols.
        return TwoTensor(a, b)
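
    # At the __tensor_unflatten__() call-site, PT2 checks the result against the
    # expected (possibly symbolic) outer metadata, roughly:
    #   y = TwoTensor.__tensor_unflatten__(inner_tensors, None, outer_size, outer_stride)
    #   assert y.shape == outer_size and y.stride() == outer_stride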

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}
        # Unwrap: swap every TwoTensor in args/kwargs for its inner tensors
        args_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, args)
        args_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, args)

        kwargs_a = pytree.tree_map_only(TwoTensor, lambda x: x.a, kwargs)
        kwargs_b = pytree.tree_map_only(TwoTensor, lambda x: x.b, kwargs)

        # Run the op once per inner tensor, then rewrap the tensor outputs
        out_a = func(*args_a, **kwargs_a)
        out_b = func(*args_b, **kwargs_b)
        assert type(out_a) == type(out_b)
        out_a_flat, spec = pytree.tree_flatten(out_a)
        out_b_flat = pytree.tree_leaves(out_b)
        # for aten ops that return non-tensors, just assume that
        # our two inner tensors return the same value
        out_flat = [
            TwoTensor(o_a, o_b) if isinstance(o_a, torch.Tensor) else o_a
            for o_a, o_b in zip(out_a_flat, out_b_flat)
        ]
        out = pytree.tree_unflatten(out_flat, spec)
        return return_and_correct_aliasing(func, args, kwargs, out)


class TwoTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        out = func(*args, **kwargs)
        if torch._subclasses.fake_tensor._is_tensor_constructor(func):
            # Tensor constructors run under this mode return a TwoTensor
            # wrapping two copies of the constructed result.
            out = TwoTensor(out, out.clone())
        return out
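
For reference, a minimal usage sketch of the file above (not part of the file itself; tensor values are arbitrary, and `TwoTensor` / `TwoTensorMode` are assumed to be in scope):
```python
import torch

t = TwoTensor(torch.ones(2, 3), torch.zeros(2, 3))

# Every op runs on both inner tensors
out = t + 1
print(out.a)  # all 2s
print(out.b)  # all 1s

# Under TwoTensorMode, tensor constructors produce TwoTensors automatically
with TwoTensorMode():
    u = torch.randn(2, 2)
print(type(u))  # TwoTensor
```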