To make TP more generic for Attention module, we come up with this new col/rowwise parallel style.
Basically, the idea behind is that:
We only do DTensor op for Col/Rowwise sharded part. For the rest of ATen ops, we will leave it to Tensor ops.
And we set this behavior as default for Colwise and Rowwise parallel style. If people want to customize it, they can always pass in different prepare_input or prepare_output
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100508
Approved by: https://github.com/wanchaol
## What's in this PR
DeviceMesh's __init__ function now requires all calling ranks to pass the same `mesh` argument.
## Why
We want to enforce SPMD style of programs using DTensor. Before this PR, 2-D Parallel API (e.g. _create_1d_device_mesh) defines different DeviceMesh on different ranks. After this PR, it defines each sub-meshes and simply perform communications on the one that it is associated with.
Differential Revision: [D45165511](https://our.internmc.facebook.com/intern/diff/D45165511)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99094
Approved by: https://github.com/wanchaol