mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[FSDP] Option to keep grads in lower prec (#85134)
Differential Revision: [D39565189](https://our.internmc.facebook.com/intern/diff/D39565189) Rehash of a similar PR from a month ago that got stale. Adds a config to FSDP MP so that gradients can be kept in lower precision, to support optimizers such as AnyPrecisionOptimizer which would like to keep grads in bf16. To do this, for sharded cases, we cannot simply omit the cast back to the full precision param dtype, otherwise when setting `p.grad = p._saved_grad_shard` in finalize_params, autograd will throw an error indicating that the grad dtype should match the param dtype when it is being set. As a workaround, we re-cast after setting this. Although, this means that for cases that use gradient accumulation, p._saved_grad_shard will be of the reduced dtype because it is set to p.grad in `_prep_grad_for_backward`. As a result, add a check + recast here as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/85134 Approved by: https://github.com/awgu
This commit is contained in:
parent
7e5616c9ff
commit
607eccb13c
3 changed files with 81 additions and 2 deletions
|
|
@ -290,6 +290,33 @@ class TestFSDPMixedPrecision(FSDPTest):
|
|||
|
||||
return orig_reduce_scatter(*args, **kwargs)
|
||||
|
||||
def _test_grads_reduced_precision(self):
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lin1 = nn.Linear(10, 10)
|
||||
self.lin2 = nn.Linear(10, 10)
|
||||
|
||||
def forward(self, x):
|
||||
return self.lin2(self.lin1(x))
|
||||
|
||||
m = MyModel().cuda()
|
||||
mp = MixedPrecision(
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.bfloat16,
|
||||
buffer_dtype=torch.bfloat16,
|
||||
keep_casted_gradients=True,
|
||||
)
|
||||
m.lin1 = FSDP(m.lin1, mixed_precision=mp)
|
||||
m = FSDP(m, mixed_precision=mp)
|
||||
for _ in range(6):
|
||||
inp = torch.ones(1, 10)
|
||||
m(inp).sum().backward()
|
||||
for param in m.parameters():
|
||||
self.assertEqual(torch.bfloat16, param.grad.dtype)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
def _run_test_mixed_precision_e2e(
|
||||
self,
|
||||
mp_config,
|
||||
|
|
@ -576,6 +603,10 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
|
|||
loss = fsdp(inp).sum()
|
||||
loss.backward()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_grads_reduced_precision(self):
|
||||
self._test_grads_reduced_precision()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@parametrize("convert_sync_bn", [True, False])
|
||||
def test_mp_batchnorm(self, convert_sync_bn):
|
||||
|
|
@ -641,6 +672,10 @@ class TestFSDPMixedPrecisionUnsharded(TestFSDPMixedPrecision):
|
|||
def world_size(self):
|
||||
return 1
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_grads_reduced_precision(self):
|
||||
return self._test_grads_reduced_precision()
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_mixed_precision_no_reshard_after_forward(self):
|
||||
# Note that we don't exercise all possible different configs so as to
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ class HandleConfig:
|
|||
offload_params: bool
|
||||
param_dtype: Optional[torch.dtype]
|
||||
reduce_dtype: Optional[torch.dtype]
|
||||
keep_low_precision_grads: Optional[bool] = False
|
||||
|
||||
|
||||
class FlatParameter(nn.Parameter):
|
||||
|
|
@ -793,6 +794,21 @@ class FlatParamHandle:
|
|||
# a GPU tensor (the new sharded gradient).
|
||||
if not grad_offloaded:
|
||||
flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined]
|
||||
# If we're using mixed precision with keeping grads
|
||||
# casted, gradient here might still be of the reduced
|
||||
# dtype if we didn't clear / set the gradients to None
|
||||
# after previous forward. In that case, make sure
|
||||
# p._saved_grad_shard is cast to the full precision type
|
||||
# so that we can accumulate in full precision in
|
||||
# _post_backward_hook and assign back in full precision
|
||||
# in _wait_for_post_backward.
|
||||
if (
|
||||
self._config.keep_low_precision_grads and
|
||||
flat_param._saved_grad_shard.dtype != flat_param._local_shard.dtype # type: ignore[attr-defined]
|
||||
):
|
||||
flat_param._saved_grad_shard = ( # type: ignore[attr-defined]
|
||||
flat_param._saved_grad_shard.to(flat_param._local_shard.dtype) # type: ignore[attr-defined]
|
||||
)
|
||||
else:
|
||||
padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined]
|
||||
p_assert(
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ class ShardingStrategy(Enum):
|
|||
class MixedPrecision:
|
||||
"""
|
||||
A config to enable mixed precision training with FullyShardedDataParallel.
|
||||
This class can be constructed with three flags:
|
||||
This class can be constructed with several flags:
|
||||
``param_dtype`` controls the precision of model parameters, inputs, and
|
||||
therefore the precision under which computation happens. After forward
|
||||
and backward passes, FSDP parameters point to full precision shards
|
||||
|
|
@ -180,6 +180,11 @@ class MixedPrecision:
|
|||
are checkpointed in their full precision (and then restored back to
|
||||
to their reduced precision) as expected. Note that this checkpoint
|
||||
support is currently limited to ``StateDictType.FULL_STATE_DICT``.
|
||||
``keep_casted_gradients``: Whether to upcast gradients back to the
|
||||
full parameter precision after backwards or not. This can be disabled
|
||||
to keep the gradients in the lower precision, which can potentially
|
||||
save memory if custom Optimizers are able to perform parameter updates
|
||||
effectively with lower precision grads.
|
||||
|
||||
.. note:: In ``summon_full_params``, parameters are summoned in full
|
||||
precision but buffers are not.
|
||||
|
|
@ -207,6 +212,12 @@ class MixedPrecision:
|
|||
# TODO: buffer + param are usually of the same type, if user specifies
|
||||
# param but not buffer, should we automatically make buffer be the same?
|
||||
buffer_dtype: Optional[torch.dtype] = None
|
||||
# Whether to upcast gradients back to the full parameter precision after
|
||||
# backwards or not. This can be disabled to keep the gradients in the
|
||||
# lower precision, which can potentially save memory if custom Optimizers
|
||||
# are able to perform parameter updates effectively with lower precision
|
||||
# grads.
|
||||
keep_casted_gradients: Optional[bool] = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -1027,6 +1038,7 @@ class FullyShardedDataParallel(nn.Module):
|
|||
self.cpu_offload.offload_params,
|
||||
self.mixed_precision.param_dtype,
|
||||
self.mixed_precision.reduce_dtype,
|
||||
self.mixed_precision.keep_casted_gradients,
|
||||
)
|
||||
self._fsdp_wrapped_module = FlattenParamsWrapper(
|
||||
module,
|
||||
|
|
@ -1642,6 +1654,12 @@ class FullyShardedDataParallel(nn.Module):
|
|||
"""
|
||||
return self.mixed_precision.reduce_dtype is not None
|
||||
|
||||
def _mixed_precision_keep_low_precision_grads(self) -> bool:
|
||||
return (
|
||||
self.mixed_precision is not None
|
||||
and self.mixed_precision.keep_casted_gradients
|
||||
)
|
||||
|
||||
def _low_precision_hook_enabled(self) -> bool:
|
||||
"""
|
||||
Wether a low precision hook is registered or not.
|
||||
|
|
@ -3241,7 +3259,13 @@ class FullyShardedDataParallel(nn.Module):
|
|||
if self.sharding_strategy == ShardingStrategy.NO_SHARD:
|
||||
self._communication_hook(self._communication_hook_state, param.grad)
|
||||
|
||||
self._cast_grad_to_param_dtype(param.grad, param)
|
||||
# For NO_SHARD keeping grads in the reduced precision, we
|
||||
# can simply omit the cast as needed, we can't do this for
|
||||
# other sharding strategies because grad field is assigned
|
||||
# in _finalize_params. TODO (rvarm1) this divergence in
|
||||
# logic is not ideal.
|
||||
if not self._mixed_precision_keep_low_precision_grads():
|
||||
self._cast_grad_to_param_dtype(param.grad, param)
|
||||
|
||||
# Regardless of sharding or not, offload the grad to CPU if we are
|
||||
# offloading params. This is so param and grad reside on same device
|
||||
|
|
@ -3408,6 +3432,10 @@ class FullyShardedDataParallel(nn.Module):
|
|||
# lands. If it was not called, there is no new gradient to accumulate
|
||||
if p._post_backward_called:
|
||||
p.grad = p._saved_grad_shard
|
||||
if fsdp_module._mixed_precision_keep_low_precision_grads():
|
||||
p.grad.data = p.grad.to(
|
||||
fsdp_module.mixed_precision.param_dtype
|
||||
)
|
||||
else:
|
||||
p_assert(
|
||||
not handle.uses_sharded_strategy or not p._post_backward_called,
|
||||
|
|
|
|||
Loading…
Reference in a new issue