From b508c7236f98d1fc200cec6a6df97dfd49c558d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Fri, 21 Jul 2023 04:52:19 +0200
Subject: [PATCH] Replace call to deprecated torch.norm (#16758)

### Description
torch.norm is deprecated as mentioned in issue #16751. This PR replaces
the call to torch.norm with the equivalent suggested by the torch
documentation. torch.linalg.vector_norm is used (rather than
torch.linalg.norm) because it reproduces torch.norm's flatten-then-norm
behavior for tensors of any rank, whereas torch.linalg.norm would
compute a matrix norm for 2-D gradients and reject higher-rank ones.

---
 .../orttraining/python/training/optim/_modifier.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/orttraining/orttraining/python/training/optim/_modifier.py b/orttraining/orttraining/python/training/optim/_modifier.py
index e9296bc63d..30178bf484 100644
--- a/orttraining/orttraining/python/training/optim/_modifier.py
+++ b/orttraining/orttraining/python/training/optim/_modifier.py
@@ -137,6 +137,9 @@ def clip_grad_norm_fp32(
     else:
         for grad in grads_for_norm:
-            grad_norm = torch.norm(grad, norm_type)
+            # torch.norm is deprecated; torch.linalg.vector_norm matches its
+            # flatten-then-norm semantics for tensors of any rank.
+            # see https://pytorch.org/docs/stable/generated/torch.norm.html
+            grad_norm = torch.linalg.vector_norm(grad, norm_type)
             total_norm += grad_norm**norm_type
 
     if horizontal_model_parallel_grad_norm_aggregation: