pytorch/torch/distributed
Xilun Wu c4d835fbab [DTensor][conv] add DTensor convolution_backward op support for case where the input Tensor has requires_grad=False (#142278)
Fixes #142058

## Summary
The DTensor `convolution_backward` op throws an exception when the input tensor has `requires_grad=False`, which happens when the conv layer is the first layer in the model.

The ATen `convolution_backward` op usually returns 3 tensors (`grad_input`, `grad_weight`, `grad_bias`), but `grad_input` is actually an `Optional[Tensor]` and can be `None` in the case described above.
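The behavior described above can be reproduced without DTensor. The following is a minimal sketch, not code from this PR: the tensor shapes and variable names are illustrative assumptions, and the direct call to `torch.ops.aten.convolution_backward` passes `output_mask=[False, True, True]` to mimic what autograd requests when the input does not need a gradient.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes (assumed): batch 2, 3 input channels, 4 output channels.
x = torch.randn(2, 3, 8, 8)  # requires_grad=False, like a first-layer input
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.randn(4, requires_grad=True)

# Autograd path: gradients are produced for weight and bias, but not for x.
F.conv2d(x, w, b).sum().backward()
print(x.grad is None, w.grad.shape, b.grad.shape)

# Direct call to the ATen op, skipping grad_input via output_mask[0] = False.
grad_out = torch.ones(2, 4, 6, 6)  # same shape as conv2d(x, w, b)
grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
    grad_out,
    x,
    w.detach(),
    [4],      # bias_sizes
    [1, 1],   # stride
    [0, 0],   # padding
    [1, 1],   # dilation
    False,    # transposed
    [0, 0],   # output_padding
    1,        # groups
    [False, True, True],  # output_mask: (grad_input, grad_weight, grad_bias)
)
print(grad_input is None, grad_weight.shape, grad_bias.shape)
```

Any code that consumes the three results therefore has to treat the first element as optional.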

However, the DTensor sharding propagation rule and the corresponding TP conv backward implementation both assume that `grad_input` is always present.

## Fix
Allow `grad_input` to be `None` for the `convolution_backward` op.
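The general shape of such a fix is to pass `None` through wherever per-output metadata is built, instead of assuming every output exists. The sketch below is purely illustrative and is not the DTensor implementation: `map_optional_outputs` is a hypothetical helper, and plain strings stand in for tensors and placements.

```python
from typing import Optional, Sequence, Tuple


def map_optional_outputs(
    outputs: Sequence[Optional[str]], placement: str
) -> Tuple[Optional[str], ...]:
    """Hypothetical helper: give each existing output a placement, keep None as None."""
    return tuple(None if out is None else placement for out in outputs)


# grad_input is missing; grad_weight and grad_bias are present.
specs = map_optional_outputs([None, "grad_weight", "grad_bias"], "Replicate")
print(specs)
```

The position of each entry is preserved, so downstream code can still index the result as (`grad_input`, `grad_weight`, `grad_bias`).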

## Test
`pytest test/distributed/tensor/test_convolution_ops.py`

## Follow-up
The current implementation of the DTensor conv op also ignores `output_mask`, which may need further attention.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142278
Approved by: https://github.com/bdhirsh
2025-02-10 07:06:40 +00:00
..
_composable
_shard [BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408) 2025-02-04 19:07:04 +00:00
_sharded_tensor
_sharding_spec
_symmetric_memory
_tensor
_tools
algorithms
autograd
benchmarks
checkpoint [DCP] Remove all-gather of state dict keys (#145998) 2025-02-04 03:16:13 +00:00
elastic [BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408) 2025-02-04 19:07:04 +00:00
examples
fsdp update _unsafe_set_version_counter to accept lists of tensors (#137921) 2025-02-04 04:51:11 +00:00
launcher
nn
optim [BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408) 2025-02-04 19:07:04 +00:00
pipelining Remove stage_index_to_group_rank from schedule (#146217) 2025-02-05 21:26:45 +00:00
rpc [BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408) 2025-02-04 19:07:04 +00:00
tensor [DTensor][conv] add DTensor convolution_backward op support for case where the input Tensor has requires_grad=False (#142278) 2025-02-10 07:06:40 +00:00
__init__.py
_checkpointable.py
_composable_state.py
_functional_collectives.py
_functional_collectives_impl.py
_serialization.py distributed/serialization: add experimental streaming torch.save/load methods (#146555) 2025-02-07 18:08:11 +00:00
_state_dict_utils.py
argparse_util.py
c10d_logger.py
collective_utils.py
constants.py
CONTRIBUTING.md
device_mesh.py
distributed_c10d.py [BE]: Enable ruff SLOT checks (#146276) 2025-02-04 19:18:23 +00:00
launch.py
logging_handlers.py
remote_device.py
rendezvous.py
run.py
utils.py