This PR adds support for calling `parallelize_module` from within a model definition, making the model itself a parallel one.
Calling `parallelize_module` this way is an alternative to maintaining a set of custom modules such as `ColumnWiseLinear`, `RowWiseLinear`, etc., while still letting you author a parallel model directly.
(The motivation for authoring a parallel model is that there may be other distributed operations which are not easily captured by any module; see the forward function below. Put differently, the purpose is to exploit the expressiveness of DTensor: we need to create DTensors before calling ops on them, and having parallelized modules in the model is one way of creating them.)
For example:
```python
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.tensor import DTensor, Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        # No device_mesh argument: it is picked up from the ambient context.
        self.w1 = parallelize_module(w1, parallelize_plan=ColwiseParallel())
        self.w2 = parallelize_module(w2, parallelize_plan=RowwiseParallel())
        self.w3 = parallelize_module(w3, parallelize_plan=ColwiseParallel())

    def forward(self, x: Tensor) -> Tensor:
        y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # y is a DTensor with Partial placement; we can return it as is...
        return y
        # ...or convert it to Replicate -- there is modeling flexibility here:
        # return y.redistribute(placements=[Replicate()])

with device_mesh:
    model = FeedForward(config)

# Now model is a model parallelized onto device_mesh
y = model(x)
```
The `device_mesh` actually used by `parallelize_module` is retrieved from the ambient context, i.e. the enclosing `with device_mesh:` block; a sketch of setting up such a mesh follows.
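For reference, a minimal sketch of how the ambient mesh might be set up, assuming a 1-D tensor-parallel mesh created with `init_device_mesh` (the device type and mesh shape here are illustrative):
```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumes the default process group is already initialized (e.g. via torchrun).
world_size = dist.get_world_size()

# A 1-D mesh spanning all ranks. Entering it installs it as the ambient mesh
# that parallelize_module falls back to when no device_mesh is passed.
device_mesh = init_device_mesh("cuda", (world_size,))

with device_mesh:
    model = FeedForward(config)  # the FeedForward from the example above
```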
Calling `parallelize_module` from within the model hierarchy also avoids the *FQN*-keyed plans needed in the out-of-model annotation case, contrasted below.
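For contrast, a sketch of the out-of-model annotation style, where the parallelize plan is a dict keyed by submodule FQNs (assuming a variant of `FeedForward` whose `w1`/`w2`/`w3` are plain `nn.Linear` modules):
```python
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Parallelism annotated from outside the model: each target submodule
# must be named by its fully qualified name (FQN).
model = FeedForward(config)  # plain, non-parallel definition
model = parallelize_module(
    model,
    device_mesh,
    {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    },
)
```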
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134247
Approved by: https://github.com/tianyu-l