Commit graph

585 commits

Andrew Gu
c30659ffcc [ZeRO] (Reland) Add ctor support for multiple param groups (#72932)
Summary:
Reland of https://github.com/pytorch/pytorch/pull/72578.

**Overview**
Windows CI was failing due to the multi-rank single-GPU case (see [here](https://github.com/pytorch/pytorch/runs/5204906995?check_suite_focus=true)).

To address this, I
- added `common_distributed.skip_if_no_gpu` for `test_multiple_param_groups()` to ensure that each rank can safely call `to(self.device)` -- this targets the expected SPSD use case where each rank has its own GPU;
- moved `test_constructor()` back to `TestZeroRedundancyOptimizerSingleRank` to check that the multiple parameter group method for construction works even on a single rank.

**Test Plan**
- I checked both tests for CPU, 1 GPU, 2 GPUs, 4 GPUs, and 8 GPUs.
- I added the `ciflow/win` label to run the failing Windows CI test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72932

Reviewed By: rohan-varma

Differential Revision: D34281482

Pulled By: awgu

fbshipit-source-id: c4fe604ddd9d2c123c3071249741e6b8a6454b6e
(cherry picked from commit 6bea9bcc6349ff1aad403563206fb170a3af0c70)
2022-02-22 16:29:55 +00:00
Michael Suo
bf03d93496 Revert D33919683: [FSDP] Implement local_state_dict and load_local_state_dict
Test Plan: revert-hammer

Differential Revision:
D33919683 (d50643adcd)

Original commit changeset: c9f1b43ce04d

Original Phabricator Diff: D33919683 (d50643adcd)

fbshipit-source-id: c54c181edf8eb6a3bc509ed54d34ffdce11b93f5
(cherry picked from commit 4dfb50cd0d86abfb17fcfbecd1f42a2dc633afb9)
2022-02-20 02:32:48 +00:00
Chien-Chin Huang
d50643adcd [FSDP] Implement local_state_dict and load_local_state_dict (#72469)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72469

1. Implement the framework to allow users to choose among `state_dict`, `local_state_dict`, and `sharded_state_dict`.
2. Implement ShardedTensor-compatible local_state_dict() and load_local_state_dict().
ghstack-source-id: 149559985

Test Plan: CI

Reviewed By: rohan-varma

Differential Revision: D33919683

fbshipit-source-id: c9f1b43ce04da7db65c4aebf6ac2c7a0ac5e9de8
(cherry picked from commit 55fd6230c9656fdf30a70dcd8071d094d2e67022)
2022-02-19 20:29:27 +00:00
Rohan Varma
209a948896 [Reland][FSDP] Implement apply() (#72925)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72925

Reland with fix to add the owner string in test file
ghstack-source-id: 149280348

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D34273858

fbshipit-source-id: 2174c1d71fcc5148282d94e375071a50b92114f2
(cherry picked from commit 158762bbb36f9652d93b3f23beca51c319435cc7)
2022-02-17 21:50:03 +00:00
Philip Meier
b5f2574f36 no longer coalesce sparse COO tensors before comparison (#69751)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69751

cc nikitaved pearu cpuhrsch IvanYashchuk

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D34262453

Pulled By: ezyang

fbshipit-source-id: e2e62d2aa03fc569d2951c880960b256f5dc4aaa
(cherry picked from commit cb6b0ef7198c5252c51a8fec1c19e3c17b33cc87)
2022-02-17 02:33:08 +00:00
Junjie Wang (PyTorch)
b02c514764 [PT-D][Sharded Tensor] new init api for local tensor and sharding spec auto inference (#72733)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72733

To reduce the perf cost due to communication when initializing a sharded tensor, there are two changes in this PR/diff:

1. We create a new API named `_init_from_local_tensor` so that if we only have one local tensor, we can initialize a sharded tensor directly from it; see the sketch after this list. (GH issue: https://github.com/pytorch/pytorch/issues/72092)
2. We create a new API to infer the sharding spec from global metadata, so we don't have to manually set the sharding spec when it's not `EnumerableShardingSpec`. (GH issue: https://github.com/pytorch/pytorch/issues/67244)
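
A rough usage sketch of (1); the import paths and the `_init_from_local_tensor` signature (local shard, sharding spec, global size) are assumptions for illustration, not the confirmed API:
```
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Each rank already holds its local shard; building the ShardedTensor directly
# from it avoids the communication that a fresh allocation + scatter would need.
rank, world_size = dist.get_rank(), dist.get_world_size()
local_shard = torch.randn(4, 8, device=f"cuda:{rank}")
spec = ChunkShardingSpec(
    dim=0,
    placements=[f"rank:{r}/cuda:{r}" for r in range(world_size)],
)
st = ShardedTensor._init_from_local_tensor(local_shard, spec, (4 * world_size, 8))
```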
ghstack-source-id: 149229259

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D34132739

fbshipit-source-id: 3a60135761bcc19d6020b6c45cb2979869645ce6
(cherry picked from commit af569325e2794309a4a86e51749642a062a25f6e)
2022-02-16 17:42:39 +00:00
Nikita Shulga
ccdff4c480 Revert D34111109: [FSDP] Implement apply()
Test Plan: revert-hammer

Differential Revision:
D34111109 (1f29b3130a)

Original commit changeset: 60d9d3f5c4d6

Original Phabricator Diff: D34111109 (1f29b3130a)

fbshipit-source-id: d959533f656a1fa69b2af7c029130f674fdd6023
(cherry picked from commit b0d3e2b1c368dea84b94cfa2a06c9e02c5a66906)
2022-02-16 15:49:04 +00:00
Rohan Varma
1f29b3130a [FSDP] Implement apply() (#72600)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72600

Implements `apply()`, which applies a `callable` of signature `f(m: Module) -> None` recursively to every submodule. The main difference from `nn.Module.apply` is that this version summons the full parameters before applying the callable, so it works appropriately with FSDP.
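
A minimal usage sketch of the behavior described above, assuming an already-initialized process group; the module and init function are made up for illustration:
```
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def reset_weights(m: nn.Module) -> None:
    # Runs on full (unsharded) parameters because FSDP's apply() summons them first.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)

fsdp_model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda())
fsdp_model.apply(reset_weights)
```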
ghstack-source-id: 149217423

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D34111109

fbshipit-source-id: 60d9d3f5c4d6c27763f5d68728dfb0bae3d9f644
(cherry picked from commit b20c65e06070f27fda0e5260f5cbbb41e3e33f46)
2022-02-16 08:35:35 +00:00
Nikita Shulga
84cb810b3f Revert D34106940: [ZeRO] Add ctor support for multiple param groups
Test Plan: revert-hammer

Differential Revision:
D34106940 (5dd0732457)

Original commit changeset: 7e70fc0b3cec

Original Phabricator Diff: D34106940 (5dd0732457)

fbshipit-source-id: 08f846c9c02be8756475f4e0b57eb381f10c27bd
(cherry picked from commit 7675497d8358cb289549539dae98579353d85834)
2022-02-16 03:45:15 +00:00
Rohan Varma
aeacf910b5 [Checkpoint] Rename file (#72748)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72748

Removes underscore from file/class as directory is already private
ghstack-source-id: 149109295

Test Plan: CI

Reviewed By: samdow

Differential Revision: D34179308

fbshipit-source-id: 8e956f3c83f21159c5e0fcdce09624ecb8a73ac0
(cherry picked from commit adfd8bc357b2ee4920054a3c984464b51daf0e35)
2022-02-16 00:08:23 +00:00
Rohan Varma
08889b24df [FSDP] Improved shape unflattening test (#72573)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72573

Verify shapes are restored appropriately in this test.
ghstack-source-id: 149111043

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D34101125

fbshipit-source-id: 94260da2b7420cf58c5569e596885aa65fe7726e
(cherry picked from commit e57a30e8e4caea0593836e52084194a3d3497b72)
2022-02-16 00:03:45 +00:00
Rohan Varma
b01d1ad171 [FSDP] Fix summon_full_params when not sharded (#72572)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72572

Use `continue` instead of `pass`, which would otherwise result in an AttributeError because `_full_param_padded` is not created for an unsharded parameter when world_size == 1. Add a test to cover this case.
ghstack-source-id: 149111044

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D34101124

fbshipit-source-id: 71d82bf94a091ef90f52b31c213192a5dd547332
(cherry picked from commit cc7899a5eaf5bc091eb772ade68a0a24a1fdab80)
2022-02-16 00:03:45 +00:00
Andrew Gu
5dd0732457 [ZeRO] Add ctor support for multiple param groups (#72578)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72578

**Overview**
This adds `ZeroRedundancyOptimizer` constructor support for multiple parameter groups (i.e. passing an `iterable` of `dict`s instead of an `iterable` of `torch.Tensor` as the `parameters` argument) to mirror the API for non-sharded optimizers.
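
A minimal sketch of the new constructor usage, mirroring how a non-sharded optimizer takes parameter groups (the toy model, hyperparameters, and an already-initialized process group are assumed for illustration):
```
import torch
from torch import nn
from torch.distributed.optim import ZeroRedundancyOptimizer

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
param_groups = [
    {"params": model[0].parameters(), "lr": 1e-2},  # per-group hyperparameters
    {"params": model[1].parameters(), "lr": 1e-3},
]
# An iterable of dicts instead of an iterable of tensors, as with torch.optim.
opt = ZeroRedundancyOptimizer(param_groups, optimizer_class=torch.optim.SGD, lr=1e-2, momentum=0.9)
```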

Fixes https://github.com/pytorch/pytorch/issues/71347 and https://github.com/pytorch/pytorch/issues/59973.

This modifies `test_collect_shards()` to skip if ROCm.

**Test Plan**
- I adjusted the existing constructor test, and I added a test for parity between constructing with two parameter groups up front versus constructing with one parameter group and adding the second afterward (via `add_param_group()`) versus a non-sharded optimizer.

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D34106940

Pulled By: awgu

fbshipit-source-id: 7e70fc0b3cec891646e0698eaedf02ff4354c128
(cherry picked from commit 40f2d45172ba3286b64000a466e42c055cca8ddc)
2022-02-15 16:51:30 +00:00
Chien-Chin Huang
c73cc92eff [FSDP] Use unflatten_parameter in _summon_full_parameters (#72467)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72467

`_unflatten_params_as_views` does not (and cannot) delete `flat_param`. Using `_unflatten_params_as_views` to implement `_summon_full_parameters` would make traversal of the full parameters fail -- it would still encounter the redundant `flat_param`.
ghstack-source-id: 148959167

Test Plan: CI

Reviewed By: rohan-varma

Differential Revision: D33989893

fbshipit-source-id: 698c97766266be01d5b567b5d5f3b2fdbf24063d
(cherry picked from commit f464c6f3909e408807594907896fd66fe639e3cb)
2022-02-12 09:13:41 +00:00
Andrew Gu
426f50e5b2 [FSDP] Add no_sync() context manager (#72446)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72446

**Overview**
This addresses https://github.com/pytorch/pytorch/issues/72183 and upstreams the `no_sync()` context manager for `FullyShardedDataParallel`.
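
A minimal gradient-accumulation sketch with the upstreamed context manager; the toy model and schedule are made up for illustration, and an initialized CUDA process group is assumed:
```
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(nn.Linear(8, 8).cuda())
opt = torch.optim.SGD(fsdp_model.parameters(), lr=1e-2)
accum_steps = 4

for step in range(8):
    x = torch.randn(2, 8, device="cuda")
    if (step + 1) % accum_steps != 0:
        with fsdp_model.no_sync():          # accumulate grads locally, skip communication
            fsdp_model(x).sum().backward()
    else:
        fsdp_model(x).sum().backward()      # gradients are synchronized on this step
        opt.step()
        opt.zero_grad()
```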

**Test Plan**
- `_test_no_sync()` is generalized from Fairscale (see [here](89e1ae5f16/tests/nn/data_parallel/test_fsdp_grad_acc.py (L66))).
- `test_communication()` is generalized from Fairscale (see [here](89e1ae5f16/tests/nn/data_parallel/test_fsdp_grad_acc.py (L128))).

I tested for world sizes of 2 and 4 on the AWS cluster:
```
gpurun python test/distributed/fsdp/test_fsdp_no_sync.py
gpurun4 python test/distributed/fsdp/test_fsdp_no_sync.py
gpurun python test/distributed/fsdp/test_fsdp_comm.py
gpurun4 python test/distributed/fsdp/test_fsdp_comm.py
```

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D34085750

Pulled By: awgu

fbshipit-source-id: 8b492d8e941049a7f5ae211f3bb4042a57f5c217
(cherry picked from commit e14f1dce1a43c6a5389e534a8a176fc39ddb7396)
2022-02-11 05:16:17 +00:00
Alban Desmaison
7035738b50 Change ParameterList and ParameterDict to be able to contain any kind of objects (#70499)
Summary:
The only difference from a plain list/dict now is that nn.Parameter entries are
handled specially and registered as parameters properly.
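
A small sketch of the intended behavior, assuming non-Parameter values are simply stored and not registered as parameters:
```
import torch
from torch import nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.pd = nn.ParameterDict({"weight": nn.Parameter(torch.randn(3))})
        self.pd["scale"] = torch.tensor(2.0)  # plain tensor: held, but not a parameter

m = M()
print([name for name, _ in m.named_parameters()])  # expected: ['pd.weight']
```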

test_nn and parametrization work locally.
Will see in CI if DP is fixed as well.

Tentative fix for https://github.com/pytorch/pytorch/issues/36035

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70499

Reviewed By: jbschlosser, alexeib

Differential Revision: D34005332

Pulled By: albanD

fbshipit-source-id: 7e76b0873d0fec345cb537e2a6ecba0258e662b9
(cherry picked from commit dc1e6f8d86e60c9bdab9271826789c2e71a013e2)
2022-02-09 18:52:29 +00:00
Chien-Chin Huang
224093db11 [FSDP] Add FlatParameter to track the information of a flat parameter (#69241)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69241

Implement FlatParameter to track the information of a flat parameter, including the sharding information.

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32432503

fbshipit-source-id: b4aabba6cef29e825b45869895709c79e69c211d
(cherry picked from commit 0e5505f70b69ecfeeec2c70688d4fc8f4a35e417)
2022-02-07 18:51:17 +00:00
Rodrigo Kumpera
b2116f5847 Port FSDP::summon_full_params from fairscale to pytorch. (#71225)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71225

Bring FSDP::summon_full_params from fairscale.

It doesn't support summoning with full precision as this mode is not yet supported by other parts of PT's FSDP port.

One thing that needs figuring out is the semantics we want W.R.T. reshard_after_forward. Right now I'm always discarding the full tensor at the end of _summon_full_params.

Fixes: https://github.com/pytorch/pytorch/issues/69779

Test Plan: Ported the fairscale tests plus added a few more.

Reviewed By: zhaojuanmao, rohan-varma

Differential Revision: D33350378

fbshipit-source-id: d826b7cc1762baa1e6a820651beb715c6428482a
(cherry picked from commit 23c78adda226e57528b3c48238f35ca55d04ba05)
2022-02-04 19:03:31 +00:00
Pritam Damania
8c505bbc86 Make ShardedTensor ctor more inline with torch.Tensor ctor (#72164)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72164

The torch.Tensor ctor creates an empty tensor, and this PR makes the
ShardedTensor ctor on par with that.

In particular, we remove TensorInitParams and instead always create an empty
tensor and then fill it in for things like ones, zeros, full, etc. This is in
line with torch.ones etc. as well, since even for those APIs we first create
an empty tensor and then fill it out.
ghstack-source-id: 148318045

Test Plan: waitforbuildbot

Reviewed By: wanchaol

Differential Revision: D33934603

fbshipit-source-id: 5655bbd726f29e74600ebe9f33f9dc5952b528f4
(cherry picked from commit 78b301c78c9d5046e2f0a9818dcbc2cc45e7cdd0)
2022-02-04 01:16:25 +00:00
Junjie Wang (PyTorch)
88547396eb [PT-D] Enable megatron-lm style MLP layers (Changes mainly on sharded linear op) (#69735)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69735

We want to build a prototype of Megatron-LM so that we can apply PT-D ops to models such as transformers and other Meta flagship models.

The basic idea of Megatron-LM is as follows:
1. Col-wise sharding of the linear weight. Perform the linear op for the first layer.
2. Perform a math op (optional), such as ReLU or GeLU. We use GeLU in our example unit test. The input is from step 1.
3. Row-wise sharding of the linear weight. Perform the linear op for the second layer. The input is from step 2.

We thus save the communication needed to concatenate the col-wise sharding results and to spread the input across ranks for row-wise sharding.
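
A minimal single-process sketch of the math above, using hand-split plain tensors (no ShardedTensor API) to show why col-wise followed by row-wise sharding needs only a final sum/all-reduce and no intermediate concatenation:
```
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)    # input activations
w1 = torch.randn(16, 8)  # first linear weight (col-wise sharded along its output dim)
w2 = torch.randn(8, 16)  # second linear weight (row-wise sharded along its input dim)

w1_shards = w1.chunk(2, dim=0)
w2_shards = w2.chunk(2, dim=1)

# Each "rank" computes its partial output independently; only a final sum
# (an all-reduce in the distributed case) is needed.
partials = [F.gelu(x @ a.t()) @ b.t() for a, b in zip(w1_shards, w2_shards)]
out = sum(partials)

expected = F.gelu(x @ w1.t()) @ w2.t()
print(torch.allclose(out, expected, atol=1e-5))  # True
```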

The changes are as follows:
1. Return a ShardedTensor for the col-wise sharding in the sharded_linear op.
2. Return a PartialTensor for the row-wise sharding in the sharded_linear op.
3. Leverage APIs already defined for `reshard` to merge/aggregate local results to a fully sync local result if needed.
4. Add helper function to create sharded tensor based on the local result.
5. Add a unit test to test the Megatron-LM idea mentioned above and compare with local ops, including the grad and optimizer so that we can ensure the correctness of the implementation.
6. Refactor the unit test of sharded linear to reflect the changes in the code.
ghstack-source-id: 148273049

Test Plan: Unit test + CI

Reviewed By: pritamdamania87

Differential Revision: D32978221

fbshipit-source-id: 565fc92e7807e19d53b0261f8ace3945bef69e3e
(cherry picked from commit 344abe75202493c8313502e1b22d634568e1b225)
2022-02-03 06:12:15 +00:00
Junjie Wang (PyTorch)
19d0de8a57 [PT-D][RFC] Resharding related API implement for ShardedTensor and Partial Tensor (#70079)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70079

We defined a new concept named `PartialTensor`, which is an abstraction to represent Tensors that need aggregation across multiple devices and multiple processes.

We also defined an API `reshard_output` to reshard a `PartialTensor` to a `Tensor`, or to reshard a `ShardedTensor` to a `ShardedTensor`/`Tensor`. This is done via the class `ModuleResharder`, which acts as a wrapper around the original module plus a reshard as the final step.

The `reshard` logic is defined in each class (`ShardedTensor` and `PartialTensor`).
ghstack-source-id: 148273050

Test Plan: Unit test is in the next PR.

Reviewed By: pritamdamania87

Differential Revision: D33121037

fbshipit-source-id: 5f56617ea526b857c5b73df6e069697d428ec359
(cherry picked from commit 58b1457cbcfc9c0bfb3083ef07fbc9e60f0ba51e)
2022-02-03 05:26:02 +00:00
Yanli Zhao
2336571cb7 make fsdp folder to be public (#72084)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084

make fsdp folder to be public
ghstack-source-id: 148173447

Test Plan: unit tests

Reviewed By: mrshenli

Differential Revision: D33903417

fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe2f391941bb0afa611dd39994585513)
2022-02-02 15:50:14 +00:00
Pritam Damania
64670e414e [reland] Create torch.distributed._shard package. (#72141)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72141

We have many sharding components currently:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer and more coming.

As a result, we are organizing all of this under the `torch.distributed._shard`
package. For BC reasons, I'm still keeping the old packages and have them just
reference the new package.
ghstack-source-id: 148150861

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33904585

fbshipit-source-id: 057e847eb7521b536a3ee4e0f94871aacc752062
(cherry picked from commit 29a70dd7afde6083bab942081020a13278f38e52)
2022-02-02 06:58:20 +00:00
Nikita Shulga
34494e6252 Back out "Create torch.distributed.shard package." (#72062)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72062

Original commit changeset: dc692b31e260

Original Phabricator Diff: D33755913 (87bbcf70f7)

Test Plan: CI

Reviewed By: pbelevich

Differential Revision: D33891115

fbshipit-source-id: 37286e03d743d8691319f07c95e9561d54f3d6d0
(cherry picked from commit 0c1b3fe00848a275d44d8c91fba91d3df6d4927f)
2022-01-31 18:29:27 +00:00
Pritam Damania
87bbcf70f7 Create torch.distributed.shard package. (#71742)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71742

We have many sharding components currently:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer and more coming.

As a result, we are organizing all of this under the `torch.distributed.shard`
package. For BC reasons, I'm still keeping the old packages and have them just
reference the new package.
ghstack-source-id: 147899768

Test Plan: waitforbuildbot

Reviewed By: fduwjj, wanchaol

Differential Revision: D33755913

fbshipit-source-id: dc692b31e2607063d55dfcb3db33ec53961d5a5b
(cherry picked from commit 5b6885f3587786217f8ce143f2329ceec618404e)
2022-01-29 00:48:06 +00:00
Rohan Varma
d0ff1f0013 [FSDP] Backward prefetch in recursive call (#71804)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71804

Add a backward prefetch arg when using auto_wrap_policy. Unit tests are
updated accordingly.
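
A hedged usage sketch of the flag combined with an auto-wrap policy; the kwarg names and the policy helper below follow later public FSDP releases and may differ slightly from the code at this commit:
```
import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
fsdp_model = FSDP(
    model,
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # issue the next all-gather early
)
```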
ghstack-source-id: 147753214

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33782346

fbshipit-source-id: c0176b48db29c3756a8873e809610ed53480102b
(cherry picked from commit 764acb3f1c8fb9879b6c92a934df1a7d2c9e3f3d)
2022-01-28 00:34:08 +00:00
Rohan Varma
a30b0cf52a [FSDP] Add/refactor unit test for wrap (#71803)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71803

1. Add an extra check for wrapping with override args.

2. Enhance the UT to make sure `wrap` doesn't wrap outside of the ctx.
ghstack-source-id: 147753225

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33774512

fbshipit-source-id: 1f8d60bdf9b3ba257fee465064a0e25235b3622b
(cherry picked from commit 9ab775b29eddcd193c11398184bee8beffed0327)
2022-01-28 00:34:08 +00:00
Wanchao Liang
6feba4bc7e Implement scatter primitive for ProcessGroupNCCL (#70029)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70029

This PR implements NCCL scatter and adds scatter to ProcessGroupNCCL.

NCCL doesn't directly provide a primitive for scatter, so it needs to be implemented on top of NCCL's send/recv API.

1. In ProcessGroupNCCL.cpp, the inputTensors are first flattened, then outputTensors and inputFlattened are passed by the collective class to the scatter() function in nccl.cpp.
2. In nccl.cpp, scatter is implemented using ncclSend/ncclRecv: the root rank uses a for loop to send (distribute) the inputTensors to each rank, then all the other ranks receive their inputTensor from the root rank; see the sketch after this list.
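
A minimal Python sketch of the same send/recv scheme using torch.distributed point-to-point ops (the C++ implementation uses ncclSend/ncclRecv directly); an initialized process group is assumed:
```
import torch.distributed as dist

def scatter_via_p2p(output, inputs, src=0):
    rank, world_size = dist.get_rank(), dist.get_world_size()
    if rank == src:
        for dst in range(world_size):
            if dst == src:
                output.copy_(inputs[dst])     # root keeps its own shard locally
            else:
                dist.send(inputs[dst], dst)   # root distributes one tensor per rank
    else:
        dist.recv(output, src)                # every other rank receives its shard
```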
ghstack-source-id: 147754837

Test Plan:
test_scatter_ops
test_scatter_stress
test_scatter_checks

Reviewed By: pritamdamania87

Differential Revision: D33154823

fbshipit-source-id: 4513e7eaf7d47a60eb67da99dc6c2e9a2882f3fd
(cherry picked from commit 93201f9d4a87c556110e60ceb93826abd71cf518)
2022-01-27 19:37:55 +00:00
Wanchao Liang
9b53d3194c Implement gather primitive for ProcessGroupNCCL (#66745)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66745

This PR implements NCCL gather and adds gather to ProcessGroupNCCL using the NCCL send/recv API.

NCCL doesn't directly provide a primitive for gather, so it needs to be implemented on top of NCCL's send/recv API.
1. In ProcessGroupNCCL.cpp, the outputTensors are first flattened, then inputTensors and outputFlattened are passed by the collective class to the gather() function in nccl.cpp.
2. In nccl.cpp, gather is implemented using ncclSend/ncclRecv: all the ranks send their inputTensor to the root rank, and the root rank uses a for loop to receive these inputTensors.
ghstack-source-id: 147754838

Test Plan:
test_gather_ops
test_gather_checks
test_gather_stress

Reviewed By: pritamdamania87

Differential Revision: D29616361

fbshipit-source-id: b500d9b8e67113194c5cc6575fb0e5d806dc7782
(cherry picked from commit d560ee732eb559782a2d1d88b3cf118dcfc404bc)
2022-01-27 19:37:55 +00:00
Michael Carilli
f37d2046f8 Implements allreduce_coalesced for ProcessGroupNCCL (#62140)
Summary:
Implements allreduce_coalesced for ProcessGroupNCCL as an NCCL group of allreduces on separate tensors, as proposed in https://github.com/pytorch/pytorch/issues/38995#issuecomment-882804595. In recent versions of NCCL, the performance of grouped comms has improved significantly. A group can execute with just one kernel, so a grouped comm on a set of unflattened tensors can be more performant than flattening + a single flat NCCL call.
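
A short usage sketch, assuming the Python-level all_reduce_coalesced binding routes to this ProcessGroupNCCL implementation for a same-device tensor list:
```
import torch
import torch.distributed as dist

# Assumes an initialized NCCL process group; all tensors live on this rank's GPU.
grads = [torch.randn(10, device="cuda"), torch.randn(20, device="cuda")]
# One grouped NCCL call on the unflattened tensors, instead of flatten + one flat all-reduce.
dist.all_reduce_coalesced(grads, op=dist.ReduceOp.SUM)
```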

The same approach can easily extend to broadcast_coalesced and reduce_coalesced.

I'm still not sure how (hypothetical) all_gather_coalesced and reduce_scatter_coalesced ops should be exposed or implemented, because we need to consider "_base" variants where the output or input tensor is pre-flattened. For example, https://github.com/pytorch/pytorch/issues/61781 effectively wants "allgather_base_coalesced".

I'm also not sure how the _multigpu variants should enter the picture. With the approach I've written here, ProcessGroupNCCL::allreduce accepts a vector of tensors that are either all on the same device (in which case it'll do an allreduce_coalesced) or all on different devices (in which case it'll do an allreduce_multigpu). In other words it can do _coalesced or _multigpu but not both at once.

For some reason GitHub won't let me add agolynski to the reviewers.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62140

Reviewed By: fduwjj

Differential Revision: D33781010

Pulled By: cbalioglu

fbshipit-source-id: f0c233da9ebae57d7ccecf6d8dc432d936d4d3ce
(cherry picked from commit e43cb81d300bd9e9926f6e01ae77f4accb12c258)
2022-01-26 13:31:30 +00:00
Rohan Varma
ba08440e88 [Opt Overlap] Remove redundant tests (#71600)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71600

These tests in test_c10d_nccl cover a subset of functionality that's
already covered by distributed_test.py, so there is no need for these additional tests.
ghstack-source-id: 147458823

Test Plan: CI

Reviewed By: cbalioglu

Differential Revision: D33662679

fbshipit-source-id: 2d1c1223fdd72a851c537b4793a71d65190d2553
(cherry picked from commit 14565ac5a6e059ec06af8583fcefa80626c95990)
2022-01-23 00:04:32 +00:00
Rohan Varma
29a7cb41d8 [BE] Fix FSDP flaky test (#71525)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71525

Closes https://github.com/pytorch/pytorch/issues/71496. Use file init
for the test as opposed to TCP init, which runs into port race conditions, as
seen in the failures for that issue.
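
A minimal sketch of file-based initialization as used to avoid the TCP port races mentioned above; the temp-file handling is illustrative, and in a real multi-process launch every rank must see the same path:
```
import tempfile
import torch.distributed as dist

rank, world_size = 0, 1  # per-process values in a real launch
init_file = tempfile.NamedTemporaryFile(delete=False).name
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{init_file}",  # no TCP port to race on
    rank=rank,
    world_size=world_size,
)
```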
ghstack-source-id: 147300691

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33676165

fbshipit-source-id: fcf83f7c7541d3521d3e38481195b0c7cb081691
(cherry picked from commit ea091c4af7d864e4d2ebcda6f72d04e17ae7bd82)
2022-01-21 21:00:13 +00:00
Pritam Damania
53b3904115 Fix memory leak in ShardedTensor. (#71445)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71445

A reference to the ShardedTensor was always added to the global map
`_sharded_tensor_map`, which never got cleaned up since the map always held a
reference to the ShardedTensor.

A couple of fixes for this:
1) Add to the global map only for `init_rrefs=True` since only this codepath
requires this.
2) Store a `weakref` in the global map to avoid holding a strong reference to the
ShardedTensor that never gets cleaned up (see the sketch after this list).
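
A generic Python sketch of fix (2): keeping weak references in a registry so that entries do not keep their objects alive (the names are illustrative, not the actual ShardedTensor internals):
```
import weakref

_registry = weakref.WeakValueDictionary()  # entries disappear once the object is collected

class Sharded:
    def __init__(self, key):
        _registry[key] = self              # stores a weak, not strong, reference

obj = Sharded("st0")
print("st0" in _registry)  # True while obj is alive
del obj
print("st0" in _registry)  # False: no leak, the entry was cleaned up
```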
ghstack-source-id: 147299580

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33641013

fbshipit-source-id: c552fa3359186514445fd5715bec93f67dc2262d
(cherry picked from commit d25f1a645313dcbf8c37158d80c42c983262cec2)
2022-01-20 19:38:41 +00:00
Yanli Zhao
1c61d8c43f [PT1.11] make static graph to be stable (#71459)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71459

1. add the static_graph feature to the DDP constructor (usage sketch below);
2. still keep the _set_static_graph() API, so that existing use cases are not affected; it can also be called internally by the DDP constructor
3. four cases are covered:
    static_graph = False, _set_static_graph() is called;
    static_graph = False, _set_static_graph() is not called;
    static_graph = True, _set_static_graph() is not called;
    static_graph = True, _set_static_graph() is called;
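
A minimal usage sketch of the new constructor flag (an initialized process group is assumed):
```
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

model = nn.Linear(8, 8).cuda()
ddp_model = DDP(model, device_ids=[torch.cuda.current_device()], static_graph=True)
# Equivalent to calling the existing internal API after construction:
# ddp_model._set_static_graph()
```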
ghstack-source-id: 147263797

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D33646738

fbshipit-source-id: 8c1730591152aab91afce7133d2adf1efd723855
(cherry picked from commit dc246a1129a8ce5f70e551d7d8e00e0dab8ec6af)
2022-01-20 19:38:41 +00:00
Pritam Damania
f5b19ba683 Additional unit test for sharded linear. (#70476)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70476

1) Support a single dimension for inputs
2) Test several error cases

Partially addresses https://github.com/pytorch/pytorch/issues/65638
ghstack-source-id: 146307607

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33344357

fbshipit-source-id: 4de7a7177452951dbcce76f27441703447609e6f
(cherry picked from commit 96dfded5697e451b54f113f99b6d0da6f6af500d)
2022-01-20 01:23:44 +00:00
Rohan Varma
3b589c3497 [DDP Checkpointing] non-reentrant checkpoint tests (#69060)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69060

Saved-variable-hooks checkpointing was added in https://github.com/pytorch/pytorch/pull/69508; this PR adds some tests for DDP.

Specifically, we can support almost all DDP use cases with this new API, such as dynamic modules with find_unused_parameters=True. One case remains to be supported, which is static_graph + non-reentrant-based checkpointing. The underlying reason this does not work is https://github.com/pytorch/pytorch/issues/58111.
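
A small sketch of the supported combination described above (dynamic graph, find_unused_parameters=True) with non-reentrant checkpointing; the toy module is illustrative and an initialized process group is assumed:
```
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1, self.l2 = nn.Linear(8, 8), nn.Linear(8, 8)

    def forward(self, x):
        # use_reentrant=False selects the saved-variable-hooks implementation
        h = checkpoint(self.l1, x, use_reentrant=False)
        return checkpoint(self.l2, h, use_reentrant=False)

ddp = DDP(Net().cuda(), find_unused_parameters=True)
ddp(torch.randn(2, 8, device="cuda")).sum().backward()
```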
ghstack-source-id: 147219887

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32712126

fbshipit-source-id: ba5ae9ca77fd8929ee020c7dc97838bae9a1931b
(cherry picked from commit 9c7f93e21728d1627d85c351a21e7c8da832bff7)
2022-01-19 18:09:41 +00:00
Pritam Damania
b56ba296b1 Support multiple input dims for sharded linear. (#70266)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70266

Addresses some of the issues mentioned in
https://github.com/pytorch/pytorch/issues/65638. The ShardedLinear implementation
only supports 2D inputs.

On the other hand `nn.Linear` supports arbitrary dimensions for inputs and
outputs. As a result, in this PR I've added support to ensure that
ShardedLinear supports arbitrary input dims as well.
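
For reference, the `nn.Linear` behavior being matched: only the last dimension is treated as the feature dimension, and all leading dimensions pass through:
```
import torch
from torch import nn

lin = nn.Linear(16, 4)
x = torch.randn(2, 3, 5, 16)  # arbitrary leading dims, last dim = in_features
print(lin(x).shape)           # torch.Size([2, 3, 5, 4])
```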
ghstack-source-id: 147206607

Test Plan: waitforbuildbot

Reviewed By: wanchaol

Differential Revision: D33267630

fbshipit-source-id: 0460994c3aa33348b80547d9274206ef90cb29b6
(cherry picked from commit 7c289e1dbf491008e091ed0a49f98f2ebcfb4175)
2022-01-19 08:07:14 +00:00
Jane Xu
c4400fc431 Retire repeat_test_for_types (#71033)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/69865

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71033

Reviewed By: mruberry

Differential Revision: D33486370

Pulled By: janeyx99

fbshipit-source-id: 71f9383dbc1e00b572f26eb4f04d0a94c6759e35
2022-01-10 09:13:54 -08:00
Rodrigo Kumpera
2378421340 Implement torch.allclose for sharded tensor. (#70331)
Summary:
Implement torch.allclose op for sharded tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70331

Test Plan:
Automated test added.
pritamdamania87
Fixes https://github.com/pytorch/pytorch/issues/67112

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Reviewed By: pritamdamania87

Differential Revision: D33339137

Pulled By: kumpera

fbshipit-source-id: 4263e468eaa117317b190f69877bf3f8bbac5658
2022-01-07 08:37:04 -08:00
Xiang Gao
6e16c9bb1d Add support for deleteKey for FileStore (#69953)
Summary:
torch_ucc uses `deleteKey`, and trying to run PyTorch tests with torch_ucc leads to a failure stating `deleteKey not implemented for FileStore`.
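
A minimal sketch of the Store-level behavior from Python once deleteKey works for FileStore; the file path and single-worker world size are illustrative:
```
import torch.distributed as dist

store = dist.FileStore("/tmp/filestore_example", 1)
store.set("key0", "value0")
print(store.get("key0"))   # b'value0'
store.delete_key("key0")   # previously failed: deleteKey not implemented for FileStore
```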

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69953

Reviewed By: ngimel

Differential Revision: D33458457

Pulled By: H-Huang

fbshipit-source-id: f46afd59f950722ae594d9aafb8843f14019e930
2022-01-07 06:20:59 -08:00
Adnios
a9c7d626e1 Add the maximize flag to AdamW (#70146)
Summary:
Related issue: https://github.com/pytorch/pytorch/issues/68052
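
Usage sketch of the new flag (the model and learning rate are illustrative):
```
import torch
from torch import nn
from torch.optim import AdamW

model = nn.Linear(8, 1)
# maximize=True flips the update direction so the objective is ascended rather than descended.
opt = AdamW(model.parameters(), lr=1e-3, maximize=True)
```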

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70146

Reviewed By: malfet

Differential Revision: D33254561

Pulled By: albanD

fbshipit-source-id: f190c836a4162f936c5953e076747c345df21421
2021-12-23 09:20:29 -08:00
Yanli Zhao
b15212c62b enable backward pass computation and communication overlap by prefetching all gather (#70235)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70235

Address comments in https://github.com/pytorch/pytorch/pull/69282:
Fixed a few corner cases for prefetching full parameters in the post-backward hook.

After benchmarking, prefetching full parameters in the pre-backward hook has the best performance and is stable, but at the cost of increased memory; prefetching full parameters in the post-backward hook did not see the expected performance and also failed in a few corner cases (now fixed), although it has no memory increase. The main issue is that the post-backward hook firing order is not consistent with the reverse of the forward computation order, so an incorrectly prefetched all-gather could delay the actually needed all-gather in the single NCCL stream and delay some layers' computation.

So these two approaches are exposed as two configurable experimental algorithms for now.

Prefetching full parameters in the pre-backward hook:

It is observed from past traces that all-gather ops are not triggered until the current layer's backward pass starts to compute; also, for some models, previous layers' reduce-scatter is scheduled before the next layer's all-gather ops. Since all-gather and reduce-scatter are in the same NCCL stream, this can result in a backward pass with no communication and computation overlap.

To explicitly get the next layers' all-gather scheduled while previous layers' backward computation is running, we can prefetch the next layers' all-gather of full params. This helps ensure that 1) both all-gather and reduce-scatter are deterministically overlapped with computation, and 2) only one layer's all-gather of full parameters is prefetched, to avoid increasing memory too much.

The implementation borrowed the idea from facebookresearch/fairscale#865, where forward graph order is recorded in the forward pass.

In the backward pass, this PR prefetches the all-gather of full parameters in the current layer's pre-backward hook, instead of in the current layer's post-backward hook as in facebookresearch/fairscale#865. It also makes sure the all-gather streams are synced properly.

Experiments showed a 10% memory increase and a 20% latency speedup for a 1GB RoBERTa model in a slow network environment.

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D33252795

fbshipit-source-id: 4e2f47225ba223e7429b0dcaa89df3634bb70050
2021-12-22 23:02:46 -08:00
Wanchao Liang
82c5f298ed [shard] fix named_params_with_sharded_tensor (#70228)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70228

Fix the named_params_with_sharded_tensor impl: `named_parameters` already loops over the submodules recursively, so we shouldn't call it inside the submodule loop.
ghstack-source-id: 146076471

Test Plan: Added more complicated test cases (that involves multiple submodules) to capture this issue.

Reviewed By: pritamdamania87

Differential Revision: D33251428

fbshipit-source-id: cf24ca7fbe4a5e485fedd2614d00cdea2898239e
2021-12-21 15:29:38 -08:00
Pritam Damania
0544f975e1 [reland] Support torch.equal for ShardedTensor. (#70145)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70145

Added support for torch.equal to ShardedTensor. This is really
helpful in terms of comparing two ShardedTensors.
ghstack-source-id: 146066939

Test Plan: waitforbuildbot

Reviewed By: wanchaol

Differential Revision: D33201714

fbshipit-source-id: 56adfc36e345d512c9901c56c07759bf658c745b
2021-12-21 13:22:52 -08:00
Rohan Varma
a197f3fe52 [FSDP/Checkpoint] Activation offload support in checkpoint_wrapper (#70165)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70165

Implements activation offload support in checkpoint_wrapper API via
save_on_cpu hooks. We avoid modifying the torch.utils.checkpoint implementation
and instead compose offload + checkpoint using the save_on_cpu hook for the
former.
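
A minimal sketch of the composition described above, using the save_on_cpu saved-tensor hooks around a checkpoint call; this is an illustration, not the actual checkpoint_wrapper implementation:
```
import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_with_offload(fn, *args):
    # Saved activations inside this scope are moved to (pinned) CPU memory and
    # brought back to the GPU when the backward pass needs them.
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        return checkpoint(fn, *args)
```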
ghstack-source-id: 146078900

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33228820

fbshipit-source-id: 98b4da0828462c41c381689ee07360ad014e808a
2021-12-21 10:08:18 -08:00
Pritam Damania
b199e3c842 Provide functionality to write custom ShardedTensor ops. (#69874)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69874

We have a handful of ops supported for ShardedTensor via
``__torch_function__`` dispatch. However, we currently can't cover all torch
operators, and having a way for users to extend this will make the
functionality much more general.

In this PR, I've introduced a custom_sharded_op decorator which can be used to
register a custom sharded op implementation.
ghstack-source-id: 145841141

Test Plan: waitforbuildbot

Reviewed By: wanchaol

Differential Revision: D33078587

fbshipit-source-id: 5936b7ac25582e613653c19afa559219719ee54b
2021-12-16 12:40:13 -08:00
Rohan Varma
c4281cc92d Prototype checkpoint_wrapper (#69955)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955

Implements a checkpoint_wrapper function, which wraps an nn.Module with checkpointing so users won't have to call checkpoint() every time they want to checkpoint the module.

Currently, only support for reentrant-based checkpointing is added, and it is only tested with FSDP to unblock a use case.

Future work is to add support for the new checkpointing API, add more tests, and upstream to torch.utils.checkpoint.
ghstack-source-id: 145811242

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D33107276

fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb
2021-12-16 09:59:19 -08:00
Junjie Wang (PyTorch)
5cc4037369 [PyTorch][Distributed] Integrate with ShardedOptimizer in the unit test of ShardedLinear (#69569)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69569

Since ShardedOptimizer was added in https://github.com/pytorch/pytorch/pull/68607, we now integrate it into our unit test for ShardedLinear.
ghstack-source-id: 145773749

Test Plan: CI + Unit test

Reviewed By: wanchaol

Differential Revision: D32777020

fbshipit-source-id: eb6b1bb0f6234976f024273833154cab274fed25
2021-12-15 17:55:01 -08:00
Michael Suo
a406a427ae Revert D33004315: Support torch.equal for ShardedTensor.
Test Plan: revert-hammer

Differential Revision:
D33004315 (1c4c81622c)

Original commit changeset: 786fe26baf82

Original Phabricator Diff: D33004315 (1c4c81622c)

fbshipit-source-id: e1dda70fea656834fdf0f2a9f874415f7b460c6e
2021-12-15 14:14:06 -08:00
Pritam Damania
1c4c81622c Support torch.equal for ShardedTensor. (#69734)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69734

Added support for `torch.equal` to ShardedTensor. This is really
helpful in terms of comparing two ShardedTensors.

Will implement `allclose` in a follow-up PR.
ghstack-source-id: 145301451

Test Plan: waitforbuildbot

Reviewed By: fduwjj, wanchaol

Differential Revision: D33004315

fbshipit-source-id: 786fe26baf82e1bb4fecfdbfc9ad4b64e704877f
2021-12-15 13:07:36 -08:00