pytorch/test/distributed/_tensor
Brian Hirsh 63dcb5b0f2 make sure dynamo doesn't inline DTensor __new__ or __torch_dispatch__ (#123347)
Fixes https://github.com/pytorch/pytorch/issues/122459, https://github.com/pytorch/torchtrain/issues/61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above, with logs showing that dynamo is inlining `__new__`.

I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.

Why does having dynamo inline `Subclass.__new__` cause problems? Morally, dynamo probably shouldn't be inlining `__new__` ("creating a subclass" is a black-box operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph whose example value is a "partially initialized tensor subclass": the subclass has been created, but its fields have not been assigned yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the example values of FX nodes in the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace `__new__` at all. Hence the `torch._dynamo.disable`.
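
As a rough sketch of what that looks like (this is not the actual DTensor code; `MyWrapperTensor` and `_local` are made-up names for a generic wrapper subclass):

```
import torch

class MyWrapperTensor(torch.Tensor):
    @staticmethod
    @torch._dynamo.disable  # never let dynamo trace the half-constructed subclass
    def __new__(cls, local):
        # The wrapper is created first and its fields are filled in afterwards,
        # which is exactly the "partially initialized" state described above.
        r = torch.Tensor._make_wrapper_subclass(
            cls, local.shape, dtype=local.dtype, device=local.device,
            requires_grad=local.requires_grad,
        )
        r._local = local
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Minimal dispatch for the sketch: unwrap positional args, run the op
        # eagerly, and re-wrap plain tensor outputs.
        kwargs = kwargs or {}
        unwrapped = tuple(a._local if isinstance(a, MyWrapperTensor) else a for a in args)
        out = func(*unwrapped, **kwargs)
        return MyWrapperTensor(out) if isinstance(out, torch.Tensor) else out

x = MyWrapperTensor(torch.randn(4))
y = x + 1  # routes through __torch_dispatch__
```

DTensor's real `__new__` additionally takes its sharding metadata; the sketch is only meant to show where the `torch._dynamo.disable` goes.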

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with @anijain2305, he explained that with code like this:
```
@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
    return out
```

Dynamo never gets the chance to interpose on the subclass constructor. Instead, what happens is:
(1) Dynamo hands control back to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` runs in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor.__new__`
(4) This is a new frame, which cpython then allows dynamo to intercept and start compiling
So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`.
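
A small standalone illustration of that mechanism (function names are hypothetical): `disable(recursive=False)` only skips the decorated frame itself, not the frames it calls into.

```
import torch

def constructor_like(x):
    # Stands in for Subclass.__new__: a callee of the skipped frame that dynamo
    # is still allowed to intercept and compile on its own.
    return x + 1

@torch._dynamo.disable(recursive=False)
def skipped(x):
    # Only this frame is excluded from compilation; it runs in eager mode.
    return constructor_like(x) * 2

@torch.compile
def entry(x):
    # The compiled caller graph-breaks around `skipped`; when `skipped` then
    # runs eagerly, cpython hands the nested frames back to dynamo.
    return skipped(x)

entry(torch.randn(4))
```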

All of the above does not explain the story for `__torch_dispatch__`, though. Empirically, I have a repro in torchtrain where, looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`:
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```
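
For reference, debug output like the above comes from dynamo's verbose logging, which can be enabled via `TORCH_LOGS`, e.g. (the script name is a placeholder):

```
TORCH_LOGS="+dynamo" python your_training_script.py
```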

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if dynamo were to inline a `torch.*` op where one of the inputs is a subclass, the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`; regardless, a `_dynamo.disable()` fixes the problem.
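
On the hypothetical `MyWrapperTensor` from the earlier sketch, the corresponding change (only the method is shown) would look like:

```
    @classmethod
    @torch._dynamo.disable  # if cpython ever hands this frame to dynamo, skip it
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrapped = tuple(a._local if isinstance(a, MyWrapperTensor) else a for a in args)
        out = func(*unwrapped, **kwargs)
        return MyWrapperTensor(out) if isinstance(out, torch.Tensor) else out
```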

I asked Animesh whether dynamo could automatically apply this behavior to subclasses instead of requiring it to be added explicitly. He pointed out that, at least for the `disable(recursive=False)` case, we can't really do this from within dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123347
Approved by: https://github.com/zou3519
ghstack dependencies: #122502, #122751, #123348
2024-04-15 17:23:20 +00:00
| File | Last commit | Date |
|------|-------------|------|
| debug/ | get CommsDebugMode to work with DTensor (#118769) | 2024-02-29 01:11:05 +00:00 |
| experimental/ | [functional collective] change the Python APIs to only use the native funcol ops (#123777) | 2024-04-13 03:08:36 +00:00 |
| __init__.py | | |
| README.md | | |
| test_api.py | nn.Module: use swap_tensors for Tensor subclasses (#122755) | 2024-03-28 02:03:09 +00:00 |
| test_attention.py | DTensor: add ring attention for _scaled_dot_product_flash_attention (#122460) | 2024-04-03 06:45:00 +00:00 |
| test_common_rules.py | [dtensor] refactor schema suggestions in output sharding (#122929) | 2024-04-01 17:39:39 +00:00 |
| test_convolution_ops.py | | |
| test_dtensor.py | [functional collective] change the Python APIs to only use the native funcol ops (#123777) | 2024-04-13 03:08:36 +00:00 |
| test_dtensor_compile.py | make sure dynamo doesn't inline DTensor __new__ or __torch_dispatch__ (#123347) | 2024-04-15 17:23:20 +00:00 |
| test_dtensor_ops.py | [dtensor] switch aten.t to use op strategy (#122950) | 2024-04-01 17:39:43 +00:00 |
| test_embedding_ops.py | [dtensor] implement dim-0 (row) embedding sharding with MaskPartial (#118080) | 2024-01-26 19:01:24 +00:00 |
| test_experimental_ops.py | | |
| test_init.py | | |
| test_math_ops.py | [dtensor][TP] check funcol calls and improve doc for loss parallel (#121366) | 2024-03-08 01:41:31 +00:00 |
| test_matrix_ops.py | [functional collective] change the Python APIs to only use the native funcol ops (#123777) | 2024-04-13 03:08:36 +00:00 |
| test_op_strategy.py | [dtensor] refactor sharding cost model to count for latency (#119897) | 2024-02-15 00:35:56 +00:00 |
| test_optimizers.py | [DTensor] Enable ASGD foreach optimizer and add the associated unit test (#121942) | 2024-03-15 20:21:27 +00:00 |
| test_pointwise_ops.py | [nit][DTensor][Test] Update test name to reflect the actual test (#118960) | 2024-02-18 08:23:06 +00:00 |
| test_random_ops.py | | |
| test_redistribute.py | [dtensor] move early return check into redistribute autograd function (#121653) | 2024-03-12 17:37:30 +00:00 |
| test_tensor_ops.py | [dtensor] refactor and generalize stack strategy (#121869) | 2024-03-15 00:34:25 +00:00 |
| test_utils.py | | |
| test_view_ops.py | [dtensor] add op support for view_as_complex and view_as_real (#122569) | 2024-03-26 03:32:04 +00:00 |
| test_xla_integration.py | [DTensor][XLA] support XLA backend in distirbute_module API (#121355) | 2024-03-08 15:47:33 +00:00 |

Run distributed tensor tests:

From the repo root, run (on either CPU or GPU):

pytest test/spmd/tensor/test_tensor.py

pytest test/spmd/tensor/test_ddp.py

Run a specific test case and print stdout/stderr:

pytest test/spmd/tensor/test_tensor.py -s -k test_tensor_from_local
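
The test files in this directory (listed above) can be run the same way from the repo root, e.g.:

pytest test/distributed/_tensor/test_dtensor.py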