Replace call to deprecated torch.norm (#16758)

### Description
torch.norm is deprecated as mentioned in issue #16751. This PR replaces
the call to torch.norm by the options suggested by torch documentation.
This commit is contained in:
Xavier Dupré 2023-07-21 04:52:19 +02:00 committed by GitHub
parent b7176f9826
commit b508c7236f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -137,7 +137,13 @@ def clip_grad_norm_fp32(
else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
# torch.norm is deprecated and moved to torch.linalg.norm
# with a different signature
# see https://pytorch.org/docs/stable/generated/torch.norm.html
if norm_type in {"fro", "nuc"}:
grad_norm = torch.linalg.matrix_norm(grad, norm_type)
else:
grad_norm = torch.linalg.norm(grad, norm_type)
total_norm += grad_norm**norm_type
if horizontal_model_parallel_grad_norm_aggregation: