Change 1D Tensor of 1 element to 0D Tensor (#96994)
Add a 0D tensor to the graphed Adam/AdamW test.

Affected:
- `torch.cuda.amp.GradScaler`'s `found_inf`, `_scale`, and `_growth_tracker`
- `step` of Adam & AdamW when `capturable=True`

Fixes #96776 🤞

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96994
Approved by: https://github.com/janeyx99
commit 22ea21da3d
parent c47cf9bc7f
5 changed files with 18 additions and 21 deletions
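For context, a minimal sketch (not part of the diff) of what "1D Tensor of 1 element" versus "0D Tensor" means in practice; the values and calls below are purely illustrative:

    import torch

    one_d = torch.full((1,), 2.0)   # shape (1,), ndim 1 -- the old representation
    zero_d = torch.full((), 2.0)    # shape (),  ndim 0 -- the new representation

    assert one_d.shape == torch.Size([1]) and one_d.ndim == 1
    assert zero_d.shape == torch.Size([]) and zero_d.ndim == 0
    assert one_d[0].ndim == 0       # indexing the old form already yielded a 0-D view
    assert zero_d.item() == 2.0     # .item() reads a Python scalar from either form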
@@ -4230,7 +4230,9 @@ exit(2)
     def _test_graphed_optimizer(self, steps_warmup, steps_train, optimizer_ctor, kwargs):
         for actually_do_graphs in (True, False):
-            params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
+            params = [
+                torch.randn((i + 5, i + 5), device="cuda") for i in range(2)
+            ] + [torch.randn((), device="cuda")]
             params_control = [p.clone().requires_grad_() for p in params]
             params_graphed = [p.clone().requires_grad_() for p in params]
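The hunk above extends the graphed-optimizer test params with a 0D tensor so the scalar case is exercised. A rough CPU-only sketch (not the actual CUDA-graph test; optimizer and shapes are my own choice) of a 0D tensor acting as an ordinary learnable parameter:

    import torch

    # A 0-D tensor is a valid leaf parameter alongside matrix-shaped ones.
    params = [torch.randn((3, 3)), torch.randn(())]
    params = [p.clone().requires_grad_() for p in params]

    opt = torch.optim.Adam(params, lr=0.1)
    loss = sum(p.square().sum() for p in params)
    loss.backward()
    opt.step()

    assert params[1].shape == torch.Size([])  # still 0-D after the update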
@@ -777,16 +777,7 @@ class TestOptim(TestCase):
                 mt_p_state = mt_state[mt_p]

                 for k in st_p_state:
-                    actual = mt_p_state[k]
-                    # If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`,
-                    # `step` Tensor is 1D while usually it's 0D.
-                    if (
-                        k == "step"
-                        and isinstance(actual, torch.Tensor)
-                        and actual.ndim == 1
-                    ):
-                        actual = actual[0]
-                    self.assertEqual(st_p_state[k], actual)
+                    self.assertEqual(st_p_state[k], mt_p_state[k])

     def test_multi_tensor_optimizers(self):
         optimizer_pairs_with_flags = [
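The deleted branch existed because `step` was 1D only on the fused/capturable path, and strict comparisons distinguish shape (1,) from shape (). A small sketch (my own, using torch.testing.assert_close rather than the test's assertEqual) of why the old workaround indexed `actual[0]`:

    import torch
    from torch.testing import assert_close

    step_0d = torch.tensor(1.0)   # single-tensor path: 0-D
    step_1d = torch.ones((1,))    # old fused/capturable path: 1-D with one element

    try:
        assert_close(step_0d, step_1d)      # raises: shapes () vs (1,) do not match
    except AssertionError:
        pass

    assert_close(step_0d, step_1d[0])       # old workaround: index down to 0-D
    # After this PR both paths produce a 0-D `step`, so the plain comparison suffices.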
@@ -1623,9 +1614,9 @@ class TestOptim(TestCase):
             [torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4)]
         prev_params = [t.clone().detach() for t in params]
         max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else []
-        state_steps = [torch.ones((1,), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
+        state_steps = [torch.ones((), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
         grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda")
-        found_inf = torch.ones((1,), dtype=torch.float32, device="cuda")
+        found_inf = torch.ones((), dtype=torch.float32, device="cuda")

         functional_optim(
             params,
@@ -1651,7 +1642,7 @@ class TestOptim(TestCase):
         self.assertEqual(
             state_steps,
             [
-                torch.ones((1,), dtype=torch.float32, device="cuda")
+                torch.ones((), dtype=torch.float32, device="cuda")
                 for _ in range(num_tensors)
            ],
        )
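These two test hunks swap the 1-element `state_steps` and `found_inf` tensors for 0D ones. A CPU-only sketch (illustration only, not the functional optimizer itself) showing that a 0D step counter still supports the in-place updates and the inf-skip logic the test relies on:

    import torch

    num_tensors = 3
    state_steps = [torch.ones((), dtype=torch.float32) for _ in range(num_tensors)]
    found_inf = torch.ones((), dtype=torch.float32)   # nonzero means "inf/nan found"

    if not found_inf.item():        # with found_inf set, steps stay untouched, as asserted above
        for step in state_steps:
            step += 1               # in-place add keeps shape ()

    assert all(s.shape == torch.Size([]) and s.item() == 1.0 for s in state_steps)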
@@ -145,8 +145,8 @@ class GradScaler:

     def _lazy_init_scale_growth_tracker(self, dev):
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
-        self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
-        self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
+        self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
+        self._growth_tracker = torch.full((), self._init_growth_tracker, dtype=torch.int32, device=dev)

     def scale(self, outputs):
         """
@@ -279,7 +279,7 @@ class GradScaler:
         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
         assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

         optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
         optimizer_state["stage"] = OptState.UNSCALED
@@ -564,8 +564,8 @@ class GradScaler:
     def _check_inf_per_device(self, optimizer):
         _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

-        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)
+        dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)

         self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
             self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
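After these GradScaler hunks, `_scale`, `_growth_tracker`, and the locally built `found_inf`/`dummy_inv_scale` are all 0D. A hedged sketch that peeks at the private attributes purely for illustration (not a supported API), assuming a CUDA build:

    import torch

    if torch.cuda.is_available():
        scaler = torch.cuda.amp.GradScaler()
        loss = torch.ones((), device="cuda") * 2.0
        scaled = scaler.scale(loss)   # first scale() call triggers _lazy_init_scale_growth_tracker
        assert scaler._scale.shape == torch.Size([])
        assert scaler._growth_tracker.shape == torch.Size([])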
@@ -83,8 +83,10 @@ class Adam(Optimizer):
                 state = self.state[p]
                 # Lazy state initialization
                 if len(state) == 0:
+                    # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                    # This is because kernel launches are costly on CUDA and XLA.
                     state['step'] = (
-                        torch.zeros((1,), dtype=torch.float, device=p.device)
+                        torch.zeros((), dtype=torch.float, device=p.device)
                         if group['capturable'] or group['fused']
                         else torch.tensor(0.)
                     )
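With the Adam hunk above, `step` is 0D on every path: a CPU-hosted `torch.tensor(0.)` by default and a device-resident `torch.zeros(())` when capturable or fused. A sketch of the resulting state layout (assumes a PyTorch build that has the `capturable` flag; the state access is shown only for inspection):

    import torch

    p = torch.randn(4, requires_grad=True)
    opt = torch.optim.Adam([p], lr=1e-3)       # capturable/fused off -> step hosted on CPU
    p.sum().backward()
    opt.step()
    step = opt.state[p]["step"]
    assert step.shape == torch.Size([]) and step.device.type == "cpu"

    if torch.cuda.is_available():
        q = torch.randn(4, device="cuda", requires_grad=True)
        opt_c = torch.optim.Adam([q], lr=1e-3, capturable=True)  # step lives on q's device
        q.sum().backward()
        opt_c.step()
        assert opt_c.state[q]["step"].shape == torch.Size([])
        assert opt_c.state[q]["step"].device.type == "cuda"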
@@ -105,8 +105,10 @@ class AdamW(Optimizer):

                 # State initialization
                 if len(state) == 0:
+                    # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                    # This is because kernel launches are costly on CUDA and XLA.
                     state["step"] = (
-                        torch.zeros((1,), dtype=torch.float, device=p.device)
+                        torch.zeros((), dtype=torch.float, device=p.device)
                         if group["capturable"] or group["fused"]
                         else torch.tensor(0.0)
                     )
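AdamW gets the identical treatment. The conditional can be read as a small helper; `init_step` below is a hypothetical name I introduce only to spell out the two branches:

    import torch

    def init_step(capturable_or_fused: bool, device: torch.device) -> torch.Tensor:
        # Device-resident 0-D float when capturable/fused, CPU-hosted 0-D float otherwise
        # (keeping `step` on CPU avoids per-step kernel launches on CUDA and XLA).
        return (
            torch.zeros((), dtype=torch.float, device=device)
            if capturable_or_fused
            else torch.tensor(0.0)
        )

    assert init_step(False, torch.device("cpu")).device.type == "cpu"
    assert init_step(True, torch.device("cpu")).ndim == 0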