diff --git a/orttraining/orttraining/python/training/optim/_modifier.py b/orttraining/orttraining/python/training/optim/_modifier.py index e9296bc63d..30178bf484 100644 --- a/orttraining/orttraining/python/training/optim/_modifier.py +++ b/orttraining/orttraining/python/training/optim/_modifier.py @@ -137,7 +137,13 @@ def clip_grad_norm_fp32( else: for grad in grads_for_norm: - grad_norm = torch.norm(grad, norm_type) + # torch.norm is deprecated and moved to torch.linalg.norm + # with a different signature + # see https://pytorch.org/docs/stable/generated/torch.norm.html + if norm_type in {"fro", "nuc"}: + grad_norm = torch.linalg.matrix_norm(grad, norm_type) + else: + grad_norm = torch.linalg.norm(grad, norm_type) total_norm += grad_norm**norm_type if horizontal_model_parallel_grad_norm_aggregation: