Summary:
Reland of https://github.com/pytorch/pytorch/pull/72578.
**Overview**
Windows CI was failing due to the multi-rank single-GPU case (see [here](https://github.com/pytorch/pytorch/runs/5204906995?check_suite_focus=true)).
To address this, I
- added `common_distributed.skip_if_no_gpu` for `test_multiple_param_groups()` to ensure that each rank can safely call `to(self.device)` -- this targets the expected SPSD use case where each rank has its own GPU;
- moved `test_constructor()` back to `TestZeroRedundancyOptimizerSingleRank` to check that the multiple parameter group method for construction works even on a single rank.
**Test Plan**
- I checked both tests for CPU, 1 GPU, 2 GPUs, 4 GPUs, and 8 GPUs.
- I added the `ciflow/win` label to run the failing Windows CI test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72932
Reviewed By: rohan-varma
Differential Revision: D34281482
Pulled By: awgu
fbshipit-source-id: c4fe604ddd9d2c123c3071249741e6b8a6454b6e
(cherry picked from commit 6bea9bcc6349ff1aad403563206fb170a3af0c70)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72469
1. Implement the framework to allow users to choose among `state_dict`, `local_state_dict`, and `sharded_state_dict`.
2. Implement ShardedTensor-compatible `local_state_dict()` and `load_local_state_dict()`.
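The selection framework in item 1 can be sketched with a simple enum dispatch (all names here are illustrative stand-ins, not the actual API):

```python
from enum import Enum, auto

class StateDictType(Enum):
    FULL = auto()      # `state_dict`: full, unflattened parameters
    LOCAL = auto()     # `local_state_dict`: this rank's local shard
    SHARDED = auto()   # `sharded_state_dict`: ShardedTensor entries

def state_dict_for(module_state, kind):
    """Dispatch to one of three (stub) serializers based on the chosen type."""
    handlers = {
        StateDictType.FULL: lambda s: dict(s),
        StateDictType.LOCAL: lambda s: {k: ("local", v) for k, v in s.items()},
        StateDictType.SHARDED: lambda s: {k: ("sharded", v) for k, v in s.items()},
    }
    return handlers[kind](module_state)

state = {"weight": [1.0, 2.0]}
assert state_dict_for(state, StateDictType.LOCAL)["weight"][0] == "local"
```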
ghstack-source-id: 149559985
Test Plan: CI
Reviewed By: rohan-varma
Differential Revision: D33919683
fbshipit-source-id: c9f1b43ce04da7db65c4aebf6ac2c7a0ac5e9de8
(cherry picked from commit 55fd6230c9656fdf30a70dcd8071d094d2e67022)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72925
Reland with a fix to add the owner string in the test file.
ghstack-source-id: 149280348
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D34273858
fbshipit-source-id: 2174c1d71fcc5148282d94e375071a50b92114f2
(cherry picked from commit 158762bbb36f9652d93b3f23beca51c319435cc7)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72733
This reduces the communication cost incurred while initializing a sharded tensor. There are two changes in this PR/diff:
1. We create a new API named `_init_from_local_tensor` so that if we have only one local tensor, we can initialize a sharded tensor directly from it. (GH issue: https://github.com/pytorch/pytorch/issues/72092)
2. We create a new API to infer the sharding spec from global metadata, so we don't have to manually set the sharding spec when it is not an `EnumerableShardingSpec`. (GH issue: https://github.com/pytorch/pytorch/issues/67244)
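The spec inference in change 2 boils down to finding the one dimension along which the shard offsets differ. A plain-Python sketch (hypothetical helper name, simplified metadata):

```python
def infer_chunk_dim(shard_offsets):
    """Given per-shard offsets (one tuple per shard), find the single
    dimension along which the offsets vary -- that is the sharding dim."""
    dims = {
        d
        for offsets in shard_offsets
        for d, off in enumerate(offsets)
        if off != 0
    }
    if len(dims) > 1:
        raise ValueError("shards vary along more than one dimension")
    # All offsets zero means a single shard; default to dim 0.
    return dims.pop() if dims else 0

# Four shards of an (8, 4) tensor, split along dim 0:
offsets = [(0, 0), (2, 0), (4, 0), (6, 0)]
assert infer_chunk_dim(offsets) == 0
```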
ghstack-source-id: 149229259
Test Plan: CI
Reviewed By: wanchaol
Differential Revision: D34132739
fbshipit-source-id: 3a60135761bcc19d6020b6c45cb2979869645ce6
(cherry picked from commit af569325e2794309a4a86e51749642a062a25f6e)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72600
Implements `apply()`, which applies a callable of signature `f(m: Module) -> None` recursively to every submodule. The main difference from `nn.Module.apply` is that this version summons the full parameters before applying the callable, so it works correctly with FSDP.
ghstack-source-id: 149217423
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D34111109
fbshipit-source-id: 60d9d3f5c4d6c27763f5d68728dfb0bae3d9f644
(cherry picked from commit b20c65e06070f27fda0e5260f5cbbb41e3e33f46)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72572
Use `continue` instead of `pass`, which would otherwise result in an `AttributeError` because `_full_param_padded` is not created for an unsharded parameter when `world_size == 1`. Add a test to cover this case.
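The distinction matters because `pass` is a no-op that falls through to the statements below it, whereas `continue` skips to the next iteration. A minimal stand-in for the loop (not the actual FSDP code):

```python
class Param:
    pass  # a sharded param would additionally carry `_full_param_padded`

params = [Param(), Param()]  # world_size == 1: nothing was sharded

freed = []
for p in params:
    if not hasattr(p, "_full_param_padded"):
        continue  # with `pass` here, the next line would raise AttributeError
    freed.append(p._full_param_padded)

assert freed == []  # all unsharded params were skipped safely
```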
ghstack-source-id: 149111044
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D34101124
fbshipit-source-id: 71d82bf94a091ef90f52b31c213192a5dd547332
(cherry picked from commit cc7899a5eaf5bc091eb772ade68a0a24a1fdab80)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72578
**Overview**
This adds `ZeroRedundancyOptimizer` constructor support for multiple parameter groups (i.e. passing an `iterable` of `dict`s instead of an `iterable` of `torch.Tensor` as the `parameters` argument) to mirror the API for non-sharded optimizers.
Fixes https://github.com/pytorch/pytorch/issues/71347 and https://github.com/pytorch/pytorch/issues/59973.
This modifies `test_collect_shards()` to skip if ROCm.
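For context, ZeRO-style sharding partitions the optimizer state across ranks, roughly by greedily assigning each parameter to the least-loaded rank. A plain-Python sketch of that idea (illustrative only; `partition_params` is not the real API):

```python
def partition_params(param_sizes, world_size):
    """Greedily assign params (given by size) to the least-loaded rank,
    as a ZeRO-style sharded optimizer might. Illustrative sketch only."""
    buckets = [[] for _ in range(world_size)]
    loads = [0] * world_size
    # Largest-first keeps the per-rank loads balanced.
    for i, size in sorted(enumerate(param_sizes), key=lambda t: -t[1]):
        r = loads.index(min(loads))
        buckets[r].append(i)
        loads[r] += size
    return buckets

# Two param groups flattened into one size list:
sizes = [400, 100, 300, 200]
buckets = partition_params(sizes, world_size=2)
assert sorted(i for b in buckets for i in b) == [0, 1, 2, 3]
```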
**Test Plan**
I adjusted the existing constructor test, and I added a test for parity between (a) constructing with two parameter groups up front, (b) constructing with one parameter group and adding the second afterward via `add_param_group()`, and (c) a non-sharded optimizer.
Test Plan: Imported from OSS
Reviewed By: rohan-varma
Differential Revision: D34106940
Pulled By: awgu
fbshipit-source-id: 7e70fc0b3cec891646e0698eaedf02ff4354c128
(cherry picked from commit 40f2d45172ba3286b64000a466e42c055cca8ddc)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72467
`_unflatten_params_as_views` does not (and cannot) delete `flat_param`. Using `_unflatten_params_as_views` to implement `_summon_full_parameters` would therefore make traversal of the full parameters fail -- it would also yield the redundant `flat_param`.
ghstack-source-id: 148959167
Test Plan: CI
Reviewed By: rohan-varma
Differential Revision: D33989893
fbshipit-source-id: 698c97766266be01d5b567b5d5f3b2fdbf24063d
(cherry picked from commit f464c6f3909e408807594907896fd66fe639e3cb)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72446
**Overview**
This addresses https://github.com/pytorch/pytorch/issues/72183 and upstreams the `no_sync()` context manager for `FullyShardedDataParallel`.
**Test Plan**
- `_test_no_sync()` is generalized from Fairscale (see [here](89e1ae5f16/tests/nn/data_parallel/test_fsdp_grad_acc.py (L66))).
- `test_communication()` is generalized from Fairscale (see [here](89e1ae5f16/tests/nn/data_parallel/test_fsdp_grad_acc.py (L128))).
I tested for world sizes of 2 and 4 on the AWS cluster:
```
gpurun python test/distributed/fsdp/test_fsdp_no_sync.py
gpurun4 python test/distributed/fsdp/test_fsdp_no_sync.py
gpurun python test/distributed/fsdp/test_fsdp_comm.py
gpurun4 python test/distributed/fsdp/test_fsdp_comm.py
```
Test Plan: Imported from OSS
Reviewed By: rohan-varma
Differential Revision: D34085750
Pulled By: awgu
fbshipit-source-id: 8b492d8e941049a7f5ae211f3bb4042a57f5c217
(cherry picked from commit e14f1dce1a43c6a5389e534a8a176fc39ddb7396)
Summary:
The only difference from a plain list/dict now is that `nn.Parameter`s are
handled specially and registered as parameters properly.
test_nn and parametrization work locally.
Will see in CI if DP is fixed as well.
Tentative fix for https://github.com/pytorch/pytorch/issues/36035
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70499
Reviewed By: jbschlosser, alexeib
Differential Revision: D34005332
Pulled By: albanD
fbshipit-source-id: 7e76b0873d0fec345cb537e2a6ecba0258e662b9
(cherry picked from commit dc1e6f8d86e60c9bdab9271826789c2e71a013e2)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69241
Implement FlatParameter to track the information of a flat parameter, including the sharding information.
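The core bookkeeping can be sketched in plain Python: concatenate the parameter data into one flat buffer while recording per-parameter numels, so the originals can be viewed back out (illustrative, not the actual `FlatParameter` API):

```python
class FlatParam:
    """Concatenate parameter data into one flat buffer and remember the
    per-parameter numels so the originals can be recovered as views."""
    def __init__(self, params):
        self.numels = [len(p) for p in params]
        self.flat = [x for p in params for x in p]

    def views(self):
        """Slice the flat buffer back into the original parameter shapes."""
        out, start = [], 0
        for n in self.numels:
            out.append(self.flat[start:start + n])
            start += n
        return out

fp = FlatParam([[1, 2, 3], [4, 5]])
assert fp.flat == [1, 2, 3, 4, 5]
assert fp.views() == [[1, 2, 3], [4, 5]]
```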
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D32432503
fbshipit-source-id: b4aabba6cef29e825b45869895709c79e69c211d
(cherry picked from commit 0e5505f70b69ecfeeec2c70688d4fc8f4a35e417)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71225
Brings `FSDP.summon_full_params` over from FairScale.
It does not support summoning with full precision, as this mode is not yet supported by other parts of PyTorch's FSDP port.
One thing that still needs figuring out is the semantics we want w.r.t. `reshard_after_forward`. Right now I'm always discarding the full tensor at the end of `_summon_full_params`.
Fixes: https://github.com/pytorch/pytorch/issues/69779
Test Plan: Ported the fairscale tests plus added a few more.
Reviewed By: zhaojuanmao, rohan-varma
Differential Revision: D33350378
fbshipit-source-id: d826b7cc1762baa1e6a820651beb715c6428482a
(cherry picked from commit 23c78adda226e57528b3c48238f35ca55d04ba05)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72164
The `torch.Tensor` ctor creates an empty tensor, and this PR makes
ShardedTensor on par with that.
In particular, we remove TensorInitParams and instead always create an empty
tensor and then fill it in for things like ones, zeros, full, etc. This is
in line with `torch.ones` etc. as well, since even for those APIs we first create
an empty tensor and then fill it out.
ghstack-source-id: 148318045
Test Plan: waitforbuildbot
Reviewed By: wanchaol
Differential Revision: D33934603
fbshipit-source-id: 5655bbd726f29e74600ebe9f33f9dc5952b528f4
(cherry picked from commit 78b301c78c9d5046e2f0a9818dcbc2cc45e7cdd0)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69735
We want to build a prototype of Megatron-LM so that we can apply PT-D ops to models like transformers and other Meta flagship models.
The basic idea of Megatron-LM is as follows:
1. Col-wise sharding of linear weight. Perform the linear op for the first layer.
2. Perform a math op (optional), such as ReLU or GeLU. We use GeLU in our example unit test. The input is from step 1.
3. Row-wise sharding of linear weight. Perform the linear op for the second layer. The input is from step 2.
We thereby save the communication that would otherwise concatenate the col-wise sharding results and spread the input to different ranks for the row-wise sharding.
The changes are as follows:
1. Return a ShardedTensor for the col-wise sharding in the sharded_linear op.
2. Return a `PartialTensor` for the row-wise sharding in the sharded_linear op.
3. Leverage APIs already defined for `reshard` to merge/aggregate local results to a fully sync local result if needed.
4. Add helper function to create sharded tensor based on the local result.
5. Add a unit test to test the Megatron-LM idea mentioned above and compare with local ops, including the grad and optimizer so that we can ensure the correctness of the implementation.
6. Refactor the unit test of sharded linear to reflect the changes in the code.
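The col-wise/row-wise composition in steps 1-3 can be checked with a tiny plain-Python example (the nonlinearity is omitted and all names are illustrative):

```python
def matmul(a, b):
    # Naive matrix multiply on nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

x = [[1, 2]]                              # 1x2 input
W1 = [[1, 0, 1, 0], [0, 1, 0, 1]]         # 2x4 weight, col-wise sharded
W2 = [[1, 0], [0, 1], [1, 0], [0, 1]]     # 4x2 weight, row-wise sharded

# Col-wise: each "rank" holds half of W1's columns and computes its slice
# of the activation.
W1_shards = [[row[:2] for row in W1], [row[2:] for row in W1]]
h_parts = [matmul(x, w) for w in W1_shards]
h = [h_parts[0][0] + h_parts[1][0]]       # concatenate along columns

# Row-wise: each rank holds half of W2's rows plus the matching slice of h;
# per-rank outputs are partial sums that must be aggregated (PartialTensor).
W2_shards = [W2[:2], W2[2:]]
h_shards = [[h[0][:2]], [h[0][2:]]]
partials = [matmul(hs, ws) for hs, ws in zip(h_shards, W2_shards)]
out = [[p + q for p, q in zip(partials[0][0], partials[1][0])]]

assert out == matmul(matmul(x, W1), W2)   # matches the unsharded result
```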
ghstack-source-id: 148273049
Test Plan: Unit test + CI
Reviewed By: pritamdamania87
Differential Revision: D32978221
fbshipit-source-id: 565fc92e7807e19d53b0261f8ace3945bef69e3e
(cherry picked from commit 344abe75202493c8313502e1b22d634568e1b225)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70079
We define a new concept named `PartialTensor`, which is an abstraction representing tensors that need aggregation across multiple devices and multiple processes.
We also define an API `reshard_output` to reshard a `PartialTensor` to a `Tensor`, or a `ShardedTensor` to a `ShardedTensor`/`Tensor`. This is done via the class `ModuleResharder`, which acts as a wrapper around the original module plus a reshard in the final step.
The `reshard` logic is defined in each class (`ShardedTensor` and `PartialTensor`).
ghstack-source-id: 148273050
Test Plan: Unit test is in the next PR.
Reviewed By: pritamdamania87
Differential Revision: D33121037
fbshipit-source-id: 5f56617ea526b857c5b73df6e069697d428ec359
(cherry picked from commit 58b1457cbcfc9c0bfb3083ef07fbc9e60f0ba51e)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72084
make fsdp folder to be public
ghstack-source-id: 148173447
Test Plan: unit tests
Reviewed By: mrshenli
Differential Revision: D33903417
fbshipit-source-id: 7852a2adc4af09af48a5ffa52ebf210489f834d5
(cherry picked from commit bd06513cfe2f391941bb0afa611dd39994585513)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72141
We have many sharding components currently:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer and more coming.
As a result, organizing all of this under the `torch.distributed._shard`
package. For BC reasons, I'm still keeping the old packages and have them just
reference the new package.
ghstack-source-id: 148150861
Test Plan: waitforbuildbot
Reviewed By: fduwjj
Differential Revision: D33904585
fbshipit-source-id: 057e847eb7521b536a3ee4e0f94871aacc752062
(cherry picked from commit 29a70dd7afde6083bab942081020a13278f38e52)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71742
We have many sharding components currently:
torch.distributed._sharded_tensor, torch.distributed._sharding_spec,
torch.distributed._sharded_optimizer and more coming.
As a result, organizing all of this under the `torch.distributed.shard`
package. For BC reasons, I'm still keeping the old packages and have them just
reference the new package.
ghstack-source-id: 147899768
Test Plan: waitforbuildbot
Reviewed By: fduwjj, wanchaol
Differential Revision: D33755913
fbshipit-source-id: dc692b31e2607063d55dfcb3db33ec53961d5a5b
(cherry picked from commit 5b6885f3587786217f8ce143f2329ceec618404e)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71803
1. Add an extra check for wrapping with override args.
2. Enhance the unit test to make sure `wrap` doesn't wrap outside of the context.
ghstack-source-id: 147753225
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D33774512
fbshipit-source-id: 1f8d60bdf9b3ba257fee465064a0e25235b3622b
(cherry picked from commit 9ab775b29eddcd193c11398184bee8beffed0327)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70029
This PR implements NCCL scatter and adds scatter to ProcessGroupNCCL.
NCCL doesn't directly provide a primitive for scatter, so it needs to be implemented on top of NCCL's send/recv API.
1. In ProcessGroupNCCL.cpp, the inputTensors are first flattened, then outputTensors and inputFlattened are passed by the collective class to scatter() function in nccl.cpp.
2. In nccl.cpp, scatter is implemented using ncclSend/ncclRecv: the root rank uses a for loop to send(distribute) the inputTensors to each rank, then all the ranks receive the inputTensor from the root rank.
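The root's send loop from step 2 can be sketched in plain Python (a simulation of the data movement only, not the actual C++ code):

```python
def scatter_via_send_recv(input_flat, world_size):
    """Root-side send loop: slice the flattened input and 'send' chunk r to
    rank r -- one ncclSend per rank, grouped so NCCL runs them together."""
    chunk = len(input_flat) // world_size
    return [input_flat[r * chunk:(r + 1) * chunk] for r in range(world_size)]

# Root distributes a flattened input across 3 ranks:
assert scatter_via_send_recv([0, 1, 2, 3, 4, 5], 3) == [[0, 1], [2, 3], [4, 5]]
```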
ghstack-source-id: 147754837
Test Plan:
test_scatter_ops
test_scatter_stress
test_scatter_checks
Reviewed By: pritamdamania87
Differential Revision: D33154823
fbshipit-source-id: 4513e7eaf7d47a60eb67da99dc6c2e9a2882f3fd
(cherry picked from commit 93201f9d4a87c556110e60ceb93826abd71cf518)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66745
This PR implements NCCL gather and adds gather to ProcessGroupNCCL.
NCCL doesn't directly provide a primitive for gather, so it needs to be implemented on top of NCCL's send/recv API.
1. In ProcessGroupNCCL.cpp, the outputTensors are first flattened, then inputTensors and outputFlattened are passed by the collective class to gather() function in nccl.cpp.
2. In nccl.cpp, gather is implemented using ncclSend/ncclRecv: all the ranks send their inputTensor to the root rank, and the root rank uses a for loop to receive these inputTensors.
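The root's receive loop can likewise be simulated in plain Python, writing each rank's contribution into the flattened output buffer at its rank offset (illustrative only):

```python
def gather_via_send_recv(local_tensors, world_size):
    """Root-side receive loop: each rank 'sends' its tensor; the root copies
    rank r's contribution into the flattened output at offset r * chunk."""
    chunk = len(local_tensors[0])
    output_flat = [0] * (world_size * chunk)
    for r in range(world_size):  # one ncclRecv per rank, in rank order
        output_flat[r * chunk:(r + 1) * chunk] = local_tensors[r]
    return output_flat

assert gather_via_send_recv([[1, 2], [3, 4]], 2) == [1, 2, 3, 4]
```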
ghstack-source-id: 147754838
Test Plan:
test_gather_ops
test_gather_checks
test_gather_stress
Reviewed By: pritamdamania87
Differential Revision: D29616361
fbshipit-source-id: b500d9b8e67113194c5cc6575fb0e5d806dc7782
(cherry picked from commit d560ee732eb559782a2d1d88b3cf118dcfc404bc)
Summary:
Implements allreduce_coalesced for ProcessGroupNCCL as an NCCL group of allreduces on separate tensors, as proposed in https://github.com/pytorch/pytorch/issues/38995#issuecomment-882804595. In recent versions of NCCL, performance of grouped comms has improved significantly. A group can execute with just one kernel, so a grouped comm on a set of unflattened tensors can be more performant than flattening+a single flat nccl call.
The same approach can easily extend to broadcast_coalesced and reduce_coalesced.
I'm still not sure how (hypothetical) all_gather_coalesced and reduce_scatter_coalesced ops should be exposed or implemented, because we need to consider "_base" variants where the output or input tensor is pre-flattened. For example, https://github.com/pytorch/pytorch/issues/61781 effectively wants "allgather_base_coalesced".
I'm also not sure how the _multigpu variants should enter the picture. With the approach I've written here, ProcessGroupNCCL::allreduce accepts a vector of tensors that are either all on the same device (in which case it'll do an allreduce_coalesced) or all on different devices (in which case it'll do an allreduce_multigpu). In other words it can do _coalesced or _multigpu but not both at once.
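That dispatch rule can be sketched as a small helper (illustrative; `allreduce_mode` is not a real API):

```python
def allreduce_mode(devices):
    """Decide between coalesced and multigpu semantics from the tensor
    devices, per the rule above: all on one device -> coalesced, all on
    distinct devices -> multigpu; anything in between is rejected."""
    if len(set(devices)) == 1:
        return "allreduce_coalesced"
    if len(set(devices)) == len(devices):
        return "allreduce_multigpu"
    raise ValueError("tensors must be all on one device or all on distinct devices")

assert allreduce_mode(["cuda:0", "cuda:0"]) == "allreduce_coalesced"
assert allreduce_mode(["cuda:0", "cuda:1"]) == "allreduce_multigpu"
```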
For some reason GitHub won't let me add agolynski to the reviewers.
cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62140
Reviewed By: fduwjj
Differential Revision: D33781010
Pulled By: cbalioglu
fbshipit-source-id: f0c233da9ebae57d7ccecf6d8dc432d936d4d3ce
(cherry picked from commit e43cb81d300bd9e9926f6e01ae77f4accb12c258)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71600
These tests in test_c10d_nccl cover a subset of functionality that's
already covered by distributed_test.py; there is no need for these additional tests.
ghstack-source-id: 147458823
Test Plan: CI
Reviewed By: cbalioglu
Differential Revision: D33662679
fbshipit-source-id: 2d1c1223fdd72a851c537b4793a71d65190d2553
(cherry picked from commit 14565ac5a6e059ec06af8583fcefa80626c95990)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71525
Closes https://github.com/pytorch/pytorch/issues/71496. Use file init
for the test as opposed to TCP init, which runs into some port race conditions, as
seen in the failures for that issue.
ghstack-source-id: 147300691
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D33676165
fbshipit-source-id: fcf83f7c7541d3521d3e38481195b0c7cb081691
(cherry picked from commit ea091c4af7d864e4d2ebcda6f72d04e17ae7bd82)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71445
A reference to the ShardedTensor was always added to the global map
`_sharded_tensor_map`, which never got cleaned up since the map always held a
reference to the ShardedTensor.
A couple of fixes for this:
1) Add to the global map only when `init_rrefs=True`, since only that codepath
requires it.
2) Add a `weakref` to the global map to avoid holding a reference to the
ShardedTensor forever that never gets cleaned up.
ghstack-source-id: 147299580
Test Plan: waitforbuildbot
Reviewed By: fduwjj
Differential Revision: D33641013
fbshipit-source-id: c552fa3359186514445fd5715bec93f67dc2262d
(cherry picked from commit d25f1a645313dcbf8c37158d80c42c983262cec2)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71459
1. Add the `static_graph` feature to the DDP constructor.
2. Still keep the `_set_static_graph()` API, so that existing use cases are not affected; it can also be called internally by the DDP constructor.
3. Four cases are covered:
   - static_graph = False, `_set_static_graph()` is called;
   - static_graph = False, `_set_static_graph()` is not called;
   - static_graph = True, `_set_static_graph()` is not called;
   - static_graph = True, `_set_static_graph()` is called.
ghstack-source-id: 147263797
Test Plan: unit tests
Reviewed By: rohan-varma
Differential Revision: D33646738
fbshipit-source-id: 8c1730591152aab91afce7133d2adf1efd723855
(cherry picked from commit dc246a1129a8ce5f70e551d7d8e00e0dab8ec6af)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70476
1) Support a single dimension for inputs
2) Test several error cases
Partially addresses https://github.com/pytorch/pytorch/issues/65638
ghstack-source-id: 146307607
Test Plan: waitforbuildbot
Reviewed By: fduwjj
Differential Revision: D33344357
fbshipit-source-id: 4de7a7177452951dbcce76f27441703447609e6f
(cherry picked from commit 96dfded5697e451b54f113f99b6d0da6f6af500d)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69060
Saved-variable-hooks checkpointing was added in https://github.com/pytorch/pytorch/pull/69508; this PR adds some tests for DDP.
Specifically, we can support almost all DDP use cases with this new API, such as a dynamic module with find_unused_parameters=True. One case remains to be supported: static_graph + non-reentrant-based checkpointing. The underlying reason this does not work is https://github.com/pytorch/pytorch/issues/58111.
ghstack-source-id: 147219887
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D32712126
fbshipit-source-id: ba5ae9ca77fd8929ee020c7dc97838bae9a1931b
(cherry picked from commit 9c7f93e21728d1627d85c351a21e7c8da832bff7)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70266
Addresses some of the issues mentioned in
https://github.com/pytorch/pytorch/issues/65638. The ShardedLinear implementation
only supports 2D inputs.
On the other hand `nn.Linear` supports arbitrary dimensions for inputs and
outputs. As a result, in this PR I've added support to ensure that
ShardedLinear supports arbitrary input dims as well.
ghstack-source-id: 147206607
Test Plan: waitforbuildbot
Reviewed By: wanchaol
Differential Revision: D33267630
fbshipit-source-id: 0460994c3aa33348b80547d9274206ef90cb29b6
(cherry picked from commit 7c289e1dbf491008e091ed0a49f98f2ebcfb4175)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70235
Addresses comments in https://github.com/pytorch/pytorch/pull/69282:
Fixed a few corner cases for prefetching full parameters in the post-backward hook.
After benchmarking, prefetching full parameters in the pre-backward hook has the best and most stable performance, but at the cost of increased memory; prefetching full parameters in the post-backward hook did not see the expected performance and also failed in a few corner cases (now fixed), although there is no memory increase. The main issue is that the post-backward hook firing order is not consistent with the reverse of the forward computation order, so an incorrectly prefetched all-gather could delay the truly needed all-gather in the single NCCL stream and delay some layers' computation.
So for now these two approaches are exposed as two configurable experimental algorithms.
Prefetching full parameters in the pre-backward hook:
It is observed from past traces that all-gather ops are not triggered until the current layer's backward pass starts to compute; also, for some models, previous layers' reduce-scatter is scheduled before the next layer's all-gather ops. Since all-gather and reduce-scatter are in the same NCCL stream, this can result in a backward pass with no overlap between communication and computation.
To explicitly schedule the next layers' all-gathers while previous layers' backward computation is running, we can prefetch the next layers' all-gather full params. This helps 1) deterministically overlap both all-gather and reduce-scatter with computation, and 2) prefetch only one layer's all-gather full parameters at a time, to avoid increasing memory too much.
The implementation borrows the idea from facebookresearch/fairscale#865, where the forward graph order is recorded in the forward pass.
In the backward pass, this PR prefetches the all-gather full parameters in the current layer's pre-backward hook, instead of in the current layer's post-backward hook as in facebookresearch/fairscale#865. It also makes sure the all-gather streams are synced properly.
Experiments showed a 10% memory increase and a 20% latency speedup for a 1GB RoBERTa model in a slow network environment.
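The record-forward-order/prefetch-in-reverse idea can be sketched in plain Python (hypothetical class, heavily simplified):

```python
class PrefetchSchedule:
    """Record module execution order in forward; in backward, each module's
    pre-backward hook prefetches the *previous* forward module, i.e. the
    next one whose backward will run. Illustrative sketch only."""
    def __init__(self):
        self.forward_order = []

    def on_forward(self, module):
        self.forward_order.append(module)

    def prefetch_target(self, module):
        i = self.forward_order.index(module)
        return self.forward_order[i - 1] if i > 0 else None

sched = PrefetchSchedule()
for m in ["layer0", "layer1", "layer2"]:
    sched.on_forward(m)

# Backward runs in reverse; layer2's pre-backward hook prefetches layer1.
assert sched.prefetch_target("layer2") == "layer1"
assert sched.prefetch_target("layer0") is None
```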
Test Plan: unit tests
Reviewed By: rohan-varma
Differential Revision: D33252795
fbshipit-source-id: 4e2f47225ba223e7429b0dcaa89df3634bb70050
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70228
Fix the `named_params_with_sharded_tensor` implementation: `named_parameters` already loops over the submodules recursively, so we shouldn't call it inside the submodule loop.
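A minimal reproduction of the bug pattern in plain Python (illustrative module class, not the real `nn.Module`):

```python
class Module:
    def __init__(self, params, children=()):
        self._params = params
        self._children = list(children)

    def named_parameters(self, prefix=""):
        # Already recursive: yields this module's params and all children's.
        for name, p in self._params.items():
            yield prefix + name, p
        for i, child in enumerate(self._children):
            yield from child.named_parameters(prefix + f"{i}.")

def buggy_collect(module):
    # Calling the recursive API inside a submodule loop double-counts.
    out = list(module.named_parameters())
    for child in module._children:
        out += list(child.named_parameters())
    return out

root = Module({"w": 1}, [Module({"b": 2})])
assert len(list(root.named_parameters())) == 2
assert len(buggy_collect(root)) == 3  # the child's param appears twice
```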
ghstack-source-id: 146076471
Test Plan: Added more complicated test cases (that involves multiple submodules) to capture this issue.
Reviewed By: pritamdamania87
Differential Revision: D33251428
fbshipit-source-id: cf24ca7fbe4a5e485fedd2614d00cdea2898239e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70145
Added support for torch.equal to ShardedTensor. This is really
helpful in terms of comparing two ShardedTensors.
ghstack-source-id: 146066939
Test Plan: waitforbuildbot
Reviewed By: wanchaol
Differential Revision: D33201714
fbshipit-source-id: 56adfc36e345d512c9901c56c07759bf658c745b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70165
Implements activation offload support in the `checkpoint_wrapper` API via
`save_on_cpu` hooks. We avoid modifying the `torch.utils.checkpoint` implementation
and instead compose offload + checkpoint using the `save_on_cpu` hook for the
former.
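The composition pattern (wrap the checkpointed forward in the hook context rather than modifying checkpoint itself) can be sketched with a stand-in context manager (everything here is a stub, not the real `save_on_cpu`):

```python
from contextlib import contextmanager

events = []

@contextmanager
def save_on_cpu_stub():
    """Stand-in for a saved-tensor-hooks context: while active, saved
    activations would be packed to CPU and unpacked back on use."""
    events.append("offload on")
    try:
        yield
    finally:
        events.append("offload off")

def checkpoint_wrapper(fn, offload_to_cpu=False):
    """Compose offload with checkpointing by entering the hook context
    around the (stub) checkpointed forward, without modifying it."""
    def wrapped(*args):
        if offload_to_cpu:
            with save_on_cpu_stub():
                return fn(*args)
        return fn(*args)
    return wrapped

forward = checkpoint_wrapper(lambda x: x * 2, offload_to_cpu=True)
assert forward(3) == 6
assert events == ["offload on", "offload off"]
```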
ghstack-source-id: 146078900
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D33228820
fbshipit-source-id: 98b4da0828462c41c381689ee07360ad014e808a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69874
We have a handful of ops supported for ShardedTensor via
`__torch_function__` dispatch. However, we currently can't cover all torch
operators, and giving users a way to extend this support will make it
much more general.
In this PR, I've introduced a custom_sharded_op decorator which can be used to
register a custom sharded op implementation.
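A registry-backed decorator of this shape might look like the following sketch (names and dispatch are illustrative, not the actual implementation):

```python
_CUSTOM_SHARDED_OPS = {}

def custom_sharded_op(op):
    """Decorator registering an implementation for `op` in a dispatch table
    that a __torch_function__-style handler would consult. Sketch only."""
    def decorator(fn):
        _CUSTOM_SHARDED_OPS[op] = fn
        return fn
    return decorator

@custom_sharded_op("linear")
def sharded_linear(x, w):
    # Toy stand-in: scale each local element by the weight.
    return [xi * w for xi in x]

def dispatch(op, *args):
    if op in _CUSTOM_SHARDED_OPS:
        return _CUSTOM_SHARDED_OPS[op](*args)
    raise NotImplementedError(op)

assert dispatch("linear", [1, 2], 3) == [3, 6]
```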
ghstack-source-id: 145841141
Test Plan: waitforbuildbot
Reviewed By: wanchaol
Differential Revision: D33078587
fbshipit-source-id: 5936b7ac25582e613653c19afa559219719ee54b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69955
Implements a `checkpoint_wrapper` function, which wraps an `nn.Module` with checkpointing so users won't have to call `checkpoint()` every time they want to checkpoint the module.
Currently only support for reentrant-based checkpointing is added, and it is only tested with FSDP to unblock a use case.
Future work is to add support for the new checkpointing API, add more tests, and upstream to `torch.utils.checkpoint`.
ghstack-source-id: 145811242
Test Plan: CI
Reviewed By: mrshenli
Differential Revision: D33107276
fbshipit-source-id: c4a1c68d71d65713a929994940a8750f73fbdbdb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69569
Since ShardedOptimizer was added in https://github.com/pytorch/pytorch/pull/68607, we now integrate it into our unit test for Sharded Linear.
ghstack-source-id: 145773749
Test Plan: CI + Unit test
Reviewed By: wanchaol
Differential Revision: D32777020
fbshipit-source-id: eb6b1bb0f6234976f024273833154cab274fed25
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69734
Added support for `torch.equal` to ShardedTensor. This is really
helpful in terms of comparing two ShardedTensors.
Will implement `allclose` in a follow-up PR.
ghstack-source-id: 145301451
Test Plan: waitforbuildbot
Reviewed By: fduwjj, wanchaol
Differential Revision: D33004315
fbshipit-source-id: 786fe26baf82e1bb4fecfdbfc9ad4b64e704877f