From 22ea21da3dc2ce5038a09db4ecf1379790f7c559 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Tue, 21 Mar 2023 18:24:15 +0000
Subject: [PATCH] Change 1D Tensor of 1 element to 0D Tensor (#96994)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

add 0d tensor to graph adam/adamw test

Affected:
- `torch.cuda.amp.GradScaler`'s `found_inf`, `_scale`, and `_growth_tracker`
- `step` of Adam & AdamW of `capturable`

Fixes #96776 🤞

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96994
Approved by: https://github.com/janeyx99
---
 test/test_cuda.py             |  4 +++-
 test/test_optim.py            | 17 ++++-------------
 torch/cuda/amp/grad_scaler.py | 10 +++++-----
 torch/optim/adam.py           |  4 +++-
 torch/optim/adamw.py          |  4 +++-
 5 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/test/test_cuda.py b/test/test_cuda.py
index a81722d9d75..6ab5c4162c2 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -4230,7 +4230,9 @@ exit(2)
 
     def _test_graphed_optimizer(self, steps_warmup, steps_train, optimizer_ctor, kwargs):
         for actually_do_graphs in (True, False):
-            params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
+            params = [
+                torch.randn((i + 5, i + 5), device="cuda") for i in range(2)
+            ] + [torch.randn((), device="cuda")]
             params_control = [p.clone().requires_grad_() for p in params]
             params_graphed = [p.clone().requires_grad_() for p in params]
 
diff --git a/test/test_optim.py b/test/test_optim.py
index e2156ba2c3c..c108044a3c4 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -777,16 +777,7 @@ class TestOptim(TestCase):
                 mt_p_state = mt_state[mt_p]
 
                 for k in st_p_state:
-                    actual = mt_p_state[k]
-                    # If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`,
-                    # `step` Tensor is 1D while usually it's 0D.
-                    if (
-                        k == "step"
-                        and isinstance(actual, torch.Tensor)
-                        and actual.ndim == 1
-                    ):
-                        actual = actual[0]
-                    self.assertEqual(st_p_state[k], actual)
+                    self.assertEqual(st_p_state[k], mt_p_state[k])
 
     def test_multi_tensor_optimizers(self):
         optimizer_pairs_with_flags = [
@@ -1623,9 +1614,9 @@ class TestOptim(TestCase):
             [torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4)]
         prev_params = [t.clone().detach() for t in params]
         max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else []
-        state_steps = [torch.ones((1,), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
+        state_steps = [torch.ones((), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
         grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda")
-        found_inf = torch.ones((1,), dtype=torch.float32, device="cuda")
+        found_inf = torch.ones((), dtype=torch.float32, device="cuda")
 
         functional_optim(
             params,
@@ -1651,7 +1642,7 @@ class TestOptim(TestCase):
         self.assertEqual(
            state_steps,
            [
-               torch.ones((1,), dtype=torch.float32, device="cuda")
+               torch.ones((), dtype=torch.float32, device="cuda")
                for _ in range(num_tensors)
            ],
        )
diff --git a/torch/cuda/amp/grad_scaler.py b/torch/cuda/amp/grad_scaler.py
index 1e826f676d2..a6a1ddbe3ba 100644
--- a/torch/cuda/amp/grad_scaler.py
+++ b/torch/cuda/amp/grad_scaler.py
@@ -145,8 +145,8 @@ class GradScaler:
 
     def _lazy_init_scale_growth_tracker(self, dev):
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
-        self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
-        self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
+        self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
+        self._growth_tracker = torch.full((), self._init_growth_tracker, dtype=torch.int32, device=dev)
 
     def scale(self, outputs):
         """
@@ -279,7 +279,7 @@ class GradScaler:
         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
         assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)
 
         optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
         optimizer_state["stage"] = OptState.UNSCALED
@@ -564,8 +564,8 @@ class GradScaler:
 
     def _check_inf_per_device(self, optimizer):
         _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
-        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
-        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)
+        dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)
 
         self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
             self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index 52390588a1c..a5862b76636 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -83,8 +83,10 @@ class Adam(Optimizer):
                 state = self.state[p]
                 # Lazy state initialization
                 if len(state) == 0:
+                    # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                    # This is because kernel launches are costly on CUDA and XLA.
                     state['step'] = (
-                        torch.zeros((1,), dtype=torch.float, device=p.device)
+                        torch.zeros((), dtype=torch.float, device=p.device)
                         if group['capturable'] or group['fused']
                         else torch.tensor(0.)
                     )
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index bbb420861cf..aed6b240ea0 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -105,8 +105,10 @@ class AdamW(Optimizer):
 
                 # State initialization
                 if len(state) == 0:
+                    # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
+                    # This is because kernel launches are costly on CUDA and XLA.
                     state["step"] = (
-                        torch.zeros((1,), dtype=torch.float, device=p.device)
+                        torch.zeros((), dtype=torch.float, device=p.device)
                         if group["capturable"] or group["fused"]
                         else torch.tensor(0.0)
                     )
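
Illustration (not part of the patch): the sketch below contrasts the old 1-element 1D representation with the new 0D one and, assuming a CUDA-capable build with this change applied, peeks at `GradScaler`'s private `_scale` and `_growth_tracker` attributes that the diff above touches; those are implementation details, not public API.

# Minimal sketch: 1-element 1D tensor vs. the 0D tensors this change switches to.
import torch

one_d = torch.zeros((1,))   # shape torch.Size([1]), ndim == 1 -- old representation
zero_d = torch.zeros(())    # shape torch.Size([]),  ndim == 0 -- new representation
print(one_d.shape, one_d.ndim, zero_d.shape, zero_d.ndim)

if torch.cuda.is_available():
    scaler = torch.cuda.amp.GradScaler()
    # scale() lazily creates _scale and _growth_tracker on first use.
    scaler.scale(torch.ones((), device="cuda"))
    # With this patch applied, both internals are expected to be 0D; before it they were 1D.
    print(scaler._scale.ndim, scaler._growth_tracker.ndim)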
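
Illustration (not part of the patch): per the comments added to `adam.py`/`adamw.py`, `step` stays a CPU scalar tensor unless `capturable` or `fused` is set, in which case it lives on the parameter's device and, after this change, is 0D. A hedged sketch, assuming a CUDA device for the capturable path:

# Sketch: inspect Adam's `step` state under the two paths touched above.
# The state-dict layout is an implementation detail of torch.optim.
import torch

# Default path: `step` is hosted on CPU as a scalar tensor.
p_cpu = torch.randn(4, requires_grad=True)
opt_cpu = torch.optim.Adam([p_cpu])
p_cpu.sum().backward()
opt_cpu.step()
print(opt_cpu.state[p_cpu]["step"].shape, opt_cpu.state[p_cpu]["step"].device)

if torch.cuda.is_available():
    # Capturable path: `step` lives on the parameter's CUDA device; with this patch
    # applied it is expected to be 0D (shape ()) rather than shape (1,).
    p_gpu = torch.randn(4, device="cuda", requires_grad=True)
    opt_gpu = torch.optim.Adam([p_gpu], capturable=True)
    p_gpu.sum().backward()
    opt_gpu.step()
    print(opt_gpu.state[p_gpu]["step"].shape, opt_gpu.state[p_gpu]["step"].device)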