Commit graph

9 commits

Author SHA1 Message Date
fduwjj
23b7035b3c [TP] Add an input resharding wrapper for TP and unit test for 2D + AC (#103334)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103334
Approved by: https://github.com/kumpera
2023-06-23 04:05:01 +00:00
fduwjj
953aa6d90e [TP] Enable more generic attn in Tensor Parallelism (#100508)
To make TP more generic for the Attention module, we introduce a new col/rowwise parallel style.

The idea is:
We only use DTensor ops for the Col/Rowwise sharded part. The remaining ATen ops are left to plain Tensor ops.

This behavior is now the default for the Colwise and Rowwise parallel styles. To customize it, pass in a different prepare_input or prepare_output.
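
As a rough illustration of why a column-wise sharded layer pairs naturally with a row-wise sharded one, the sketch below uses plain Python lists rather than PyTorch. All function names are hypothetical, the "ranks" are simulated in a loop, and the final sum stands in for the all-reduce a real implementation would perform.

```python
# Pure-Python sketch (no PyTorch): column-wise sharding of the first weight
# followed by row-wise sharding of the second. All names are illustrative.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def split_cols(m, parts):
    """Split a matrix into `parts` equal column blocks (colwise shards)."""
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]

def split_rows(m, parts):
    """Split a matrix into `parts` equal row blocks (rowwise shards)."""
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]

def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def two_layer_sharded(x, w1, w2, world_size):
    """Each simulated rank holds one column block of w1 and one row block of w2."""
    w1_shards = split_cols(w1, world_size)
    w2_shards = split_rows(w2, world_size)
    # Local compute per rank: no communication is needed between the two matmuls.
    partials = [matmul(matmul(x, a), b) for a, b in zip(w1_shards, w2_shards)]
    # Summing the partial results stands in for an all-reduce across ranks.
    out = partials[0]
    for partial in partials[1:]:
        out = add(out, partial)
    return out

x = [[1, 2], [3, 4]]
w1 = [[1, 0, 2, 1], [0, 1, 1, 2]]        # 2x4, sharded by columns
w2 = [[1, 2], [0, 1], [1, 0], [2, 1]]    # 4x2, sharded by rows
full = matmul(matmul(x, w1), w2)
sharded = two_layer_sharded(x, w1, w2, world_size=2)
```

Here `sharded` matches `full`, which is the property that lets only the sharded matmuls need distributed handling.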

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100508
Approved by: https://github.com/wanchaol
2023-05-07 18:15:49 +00:00
Xilun Wu
ce60997376 [BE][DTensor] validate the mesh argument in DeviceMesh construction (#99094)
## What's in this PR
DeviceMesh's __init__ function now requires all calling ranks to pass the same `mesh` argument.

## Why
We want to enforce an SPMD style for programs using DTensor. Before this PR, the 2-D Parallel API (e.g. _create_1d_device_mesh) defined a different DeviceMesh on different ranks. After this PR, it defines each sub-mesh and simply performs communication on the one it is associated with.
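
The check described above can be pictured with a small single-process simulation. This is only a sketch under stated assumptions: `validate_mesh` and `sub_mesh_for_rank` are hypothetical helpers, ranks are simulated as list entries, and nothing here reflects how DeviceMesh actually performs the comparison.

```python
# Single-process simulation of "every rank must pass the same mesh".
# Helper names are hypothetical and are not part of the DeviceMesh API.

def validate_mesh(meshes_by_rank):
    """Raise if any simulated rank passed a mesh different from rank 0's."""
    reference = meshes_by_rank[0]
    for rank, mesh in enumerate(meshes_by_rank):
        if mesh != reference:
            raise RuntimeError(
                f"rank {rank} passed mesh {mesh}, expected {reference}"
            )
    return reference

def sub_mesh_for_rank(mesh, rank, dim):
    """Return the 1-D sub-mesh of a 2-D mesh along `dim` that contains `rank`."""
    for row in mesh:
        if rank in row:
            col = row.index(rank)
            if dim == 1:
                return list(row)                       # ranks in the same row
            return [mesh_row[col] for mesh_row in mesh]  # ranks in the same column
    raise ValueError(f"rank {rank} is not in mesh {mesh}")

mesh = [[0, 1], [2, 3]]
# All four simulated ranks pass the same 2-D mesh, so validation succeeds.
shared = validate_mesh([mesh, mesh, mesh, mesh])
# Each rank then communicates only on the sub-mesh it belongs to.
rank2_dim0 = sub_mesh_for_rank(shared, rank=2, dim=0)
rank2_dim1 = sub_mesh_for_rank(shared, rank=2, dim=1)
```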

Differential Revision: [D45165511](https://our.internmc.facebook.com/intern/diff/D45165511)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99094
Approved by: https://github.com/wanchaol
2023-04-21 23:47:51 +00:00
Wanchao Liang
16e7e5a24b [dtensor] lazy init process groups in device mesh (#96700)
This PR adds a private flag to allow lazy initialization of process groups. It
replaces the previous `dim_groups` arg, which no one uses anymore.

This can help avoid creating process groups when they are not necessary.
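
The general pattern of deferring group creation can be sketched as follows. This is a minimal illustration, not the DeviceMesh implementation: the class `LazyMesh`, the flag name `init_groups`, and the `create_group` callback are all hypothetical.

```python
# Sketch of lazy process-group creation. `LazyMesh`, `init_groups`, and
# `create_group` are hypothetical names used only for illustration.

class LazyMesh:
    def __init__(self, mesh, create_group, init_groups=True):
        self.mesh = mesh
        self._create_group = create_group
        self._groups = None
        if init_groups:
            # Eager path: build every group at construction time.
            self._groups = self._build_groups()

    def _build_groups(self):
        return [self._create_group(ranks) for ranks in self.mesh]

    def get_groups(self):
        """Build the groups on first use, then reuse the cached result."""
        if self._groups is None:
            self._groups = self._build_groups()
        return self._groups

created = []

def fake_create_group(ranks):
    """Stand-in for an expensive group constructor; records each call."""
    created.append(tuple(ranks))
    return tuple(ranks)

lazy = LazyMesh([[0, 1], [2, 3]], fake_create_group, init_groups=False)
calls_after_construction = len(created)   # nothing has been created yet
groups = lazy.get_groups()                # first access triggers creation
calls_after_first_access = len(created)
lazy.get_groups()                         # second access reuses the cache
calls_after_second_access = len(created)
```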

Differential Revision: [D44044664](https://our.internmc.facebook.com/intern/diff/D44044664)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96700
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-03-20 17:50:04 +00:00
Xuehai Pan
046e88a291 [BE] [3/3] Rewrite super() calls in test (#94592)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```
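
For readers unfamiliar with the two spellings, the following sketch shows the kind of rewrite described above. The class names are made up for illustration; in Python 3 the explicit two-argument form and the zero-argument form behave the same in an ordinary method like this one.

```python
# Illustrative classes only; none of these names come from the PR.

class Base:
    def __init__(self):
        self.ready = True

class OldStyle(Base):
    def __init__(self):
        # Explicit two-argument form, as written before the rewrite.
        super(OldStyle, self).__init__()
        self.name = "old"

class NewStyle(Base):
    def __init__(self):
        # Zero-argument form, as written after the rewrite.
        super().__init__()
        self.name = "new"

class NoInit(Base):
    # An __init__ that only called super().__init__() can be dropped:
    # Base.__init__ is inherited and runs anyway.
    def forward(self):
        return self.ready

old, new, bare = OldStyle(), NewStyle(), NoInit()
```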

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94592
Approved by: https://github.com/ezyang, https://github.com/seemethere
2023-02-12 22:20:53 +00:00
fduwjj
3fb6e119e2 [PT-D][TP] Fix the module registration in TP API (#93412)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93412
Approved by: https://github.com/XilunWu
2023-02-01 21:03:56 +00:00
fduwjj
913866efbf [PT-D][TP] Fix TP API for FQN path based parallelization (#93029)
We had not tested dict-based parallelize_module, and it turns out we had mistakes here.

1. Fix the error.
2. Add unit test cases for it.
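
To make the idea of an FQN-keyed plan concrete, here is a toy sketch that resolves dotted paths against a module tree and tags each target with a style. The `Module` class, `apply_plan`, and the style strings are hypothetical stand-ins and do not reproduce the parallelize_module implementation.

```python
# Toy module tree with dotted-path (FQN) lookup. All names are hypothetical.

class Module:
    def __init__(self, **children):
        self.children = dict(children)
        self.style = None  # records which parallel style was applied, if any

    def get_submodule(self, fqn):
        """Walk the tree one path segment at a time, e.g. "block.attn"."""
        module = self
        for name in fqn.split("."):
            if name not in module.children:
                raise KeyError(f"no submodule {name!r} while resolving {fqn!r}")
            module = module.children[name]
        return module

def apply_plan(root, plan):
    """Apply each style in a dict plan to the submodule named by its FQN key."""
    for fqn, style in plan.items():
        root.get_submodule(fqn).style = style
    return root

model = Module(net1=Module(), block=Module(attn=Module(), mlp=Module()))
apply_plan(model, {"net1": "colwise", "block.attn": "rowwise"})
```

Submodules not named in the plan, such as `block.mlp` here, are left untouched, and an unknown path raises a `KeyError` instead of being silently ignored.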

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93029
Approved by: https://github.com/wz337
2023-01-26 09:10:21 +00:00
Wanchao Liang
ca5526cf1f [tp] ufmt test/distributed/tensor (#89970)
Formatting stack to align dtensor and tp with the PyTorch format standard.

cmd: `ufmt format test/distributed/tensor`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89970
Approved by: https://github.com/fduwjj
2022-12-01 20:58:16 +00:00
Wanchao Liang
4451eb24e6 Move tensor_parallel out to distributed.tensor folder (#89878)
This PR moves tensor parallel from torch.distributed._tensor.parallel
to torch.distributed.tensor.parallel, to prepare for the beta release.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89878
Approved by: https://github.com/fduwjj
2022-11-30 22:13:10 +00:00
Renamed from test/distributed/_tensor/parallel/test_parallelize_api.py