Change 1D Tensor of 1 element to 0D Tensor (#96994)

Also add a 0D tensor to the CUDA-graph Adam/AdamW test (`_test_graphed_optimizer`).

Affected:
- `torch.cuda.amp.GradScaler`'s `found_inf`, `_scale`, and `_growth_tracker`
- `step` of Adam & AdamW when `capturable` or `fused` is set

Fixes #96776 🤞
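
For context, a 1-element 1D tensor and a 0D tensor both hold a single value but differ in shape and `ndim`; a minimal illustration in plain PyTorch (not code from this PR):

    import torch

    one_d = torch.zeros((1,))   # shape (1,), ndim == 1
    zero_d = torch.zeros(())    # shape (), ndim == 0

    print(one_d.shape, one_d.ndim)    # torch.Size([1]) 1
    print(zero_d.shape, zero_d.ndim)  # torch.Size([]) 0

    # Scalar-style bookkeeping (counters, flags) works the same on a 0D tensor:
    zero_d.add_(1)
    assert zero_d.item() == 1.0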

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96994
Approved by: https://github.com/janeyx99
Masaki Kozuki, 2023-03-21 18:24:15 +00:00, committed by PyTorch MergeBot
parent c47cf9bc7f
commit 22ea21da3d
5 changed files with 18 additions and 21 deletions


@@ -4230,7 +4230,9 @@ exit(2)
     def _test_graphed_optimizer(self, steps_warmup, steps_train, optimizer_ctor, kwargs):
         for actually_do_graphs in (True, False):
-            params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
+            params = [
+                torch.randn((i + 5, i + 5), device="cuda") for i in range(2)
+            ] + [torch.randn((), device="cuda")]
             params_control = [p.clone().requires_grad_() for p in params]
             params_graphed = [p.clone().requires_grad_() for p in params]
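
The extra `torch.randn((), device="cuda")` entry makes the graphed-optimizer test cover a 0D parameter. As a rough standalone sketch of the same idea (CPU, plain Adam, not the graphed test harness):

    import torch

    p = torch.randn((), requires_grad=True)  # 0D parameter
    opt = torch.optim.Adam([p], lr=0.1)

    (p ** 2).backward()  # 0D loss
    opt.step()

    assert p.ndim == 0 and p.grad.ndim == 0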


@@ -777,16 +777,7 @@ class TestOptim(TestCase):
                 mt_p_state = mt_state[mt_p]
                 for k in st_p_state:
-                    actual = mt_p_state[k]
-                    # If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`,
-                    # `step` Tensor is 1D while usually it's 0D.
-                    if (
-                        k == "step"
-                        and isinstance(actual, torch.Tensor)
-                        and actual.ndim == 1
-                    ):
-                        actual = actual[0]
-                    self.assertEqual(st_p_state[k], actual)
+                    self.assertEqual(st_p_state[k], mt_p_state[k])

     def test_multi_tensor_optimizers(self):
         optimizer_pairs_with_flags = [
@@ -1623,9 +1614,9 @@ class TestOptim(TestCase):
             [torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4)]
         prev_params = [t.clone().detach() for t in params]
         max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else []
-        state_steps = [torch.ones((1,), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
+        state_steps = [torch.ones((), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
         grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda")
-        found_inf = torch.ones((1,), dtype=torch.float32, device="cuda")
+        found_inf = torch.ones((), dtype=torch.float32, device="cuda")

         functional_optim(
             params,
@@ -1651,7 +1642,7 @@ class TestOptim(TestCase):
         self.assertEqual(
             state_steps,
             [
-                torch.ones((1,), dtype=torch.float32, device="cuda")
+                torch.ones((), dtype=torch.float32, device="cuda")
                 for _ in range(num_tensors)
             ],
         )
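
With `step` now 0D on every code path (plain, `capturable`, `fused`), the special-casing removed above is unnecessary and the states compare directly. A quick CPU-only sanity check, as a standalone sketch rather than part of `TestOptim`:

    import torch

    p = torch.randn(4, requires_grad=True)
    opt = torch.optim.Adam([p])

    p.sum().backward()
    opt.step()

    step = opt.state[p]["step"]
    assert isinstance(step, torch.Tensor) and step.ndim == 0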


@@ -145,8 +145,8 @@ class GradScaler:
     def _lazy_init_scale_growth_tracker(self, dev):
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
-        self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
-        self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
+        self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
+        self._growth_tracker = torch.full((), self._init_growth_tracker, dtype=torch.int32, device=dev)

     def scale(self, outputs):
         """
@@ -279,7 +279,7 @@ class GradScaler:
         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
         assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

         optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
         optimizer_state["stage"] = OptState.UNSCALED
@@ -564,8 +564,8 @@ class GradScaler:
     def _check_inf_per_device(self, optimizer):
         _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
-        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)
+        dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)

         self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
             self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
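
The scaler's bookkeeping tensors (`_scale`, `_growth_tracker`, `found_inf`, `dummy_inv_scale`) become 0D. A minimal sketch of the same pattern outside `GradScaler` (variable names are illustrative only):

    import torch

    scale = torch.full((), 65536.0, dtype=torch.float32)
    growth_tracker = torch.full((), 0, dtype=torch.int32)
    found_inf = torch.full((), 0.0, dtype=torch.float32)

    # The scalar-style updates used by the scaler work the same on 0D tensors:
    inv_scale = scale.double().reciprocal().float()
    found_inf.fill_(0.0)
    assert inv_scale.ndim == 0 and found_inf.ndim == 0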


@@ -83,8 +83,10 @@ class Adam(Optimizer):
                 state = self.state[p]
                 # Lazy state initialization
                 if len(state) == 0:
+                    # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                    # This is because kernel launches are costly on CUDA and XLA.
                     state['step'] = (
-                        torch.zeros((1,), dtype=torch.float, device=p.device)
+                        torch.zeros((), dtype=torch.float, device=p.device)
                         if group['capturable'] or group['fused']
                         else torch.tensor(0.)
                     )
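
The note above explains the conditional: keep `step` as a 0D on-device tensor only when `capturable` or `fused` requires it, and otherwise as a 0D CPU tensor to avoid per-step kernel launches. A hedged standalone sketch of that selection logic (`make_step_state` is a hypothetical helper, not the optimizer's actual code path):

    import torch

    def make_step_state(device, capturable=False, fused=False):
        # On-device 0D step only for the capturable/fused paths;
        # otherwise a cheap 0D CPU tensor.
        return (
            torch.zeros((), dtype=torch.float, device=device)
            if capturable or fused
            else torch.tensor(0.0)
        )

    step = make_step_state("cpu")
    assert step.ndim == 0 and step.device.type == "cpu"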


@@ -105,8 +105,10 @@ class AdamW(Optimizer):
                 # State initialization
                 if len(state) == 0:
+                    # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                    # This is because kernel launches are costly on CUDA and XLA.
                     state["step"] = (
-                        torch.zeros((1,), dtype=torch.float, device=p.device)
+                        torch.zeros((), dtype=torch.float, device=p.device)
                         if group["capturable"] or group["fused"]
                         else torch.tensor(0.0)
                     )