mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
Replace call to deprecated torch.norm (#16758)
### Description torch.norm is deprecated as mentioned in issue #16751. This PR replaces the call to torch.norm by the options suggested by torch documentation.
This commit is contained in:
parent
b7176f9826
commit
b508c7236f
1 changed files with 7 additions and 1 deletions
|
|
@ -137,7 +137,13 @@ def clip_grad_norm_fp32(
|
|||
|
||||
else:
|
||||
for grad in grads_for_norm:
|
||||
grad_norm = torch.norm(grad, norm_type)
|
||||
# torch.norm is deprecated and moved to torch.linalg.norm
|
||||
# with a different signature
|
||||
# see https://pytorch.org/docs/stable/generated/torch.norm.html
|
||||
if norm_type in {"fro", "nuc"}:
|
||||
grad_norm = torch.linalg.matrix_norm(grad, norm_type)
|
||||
else:
|
||||
grad_norm = torch.linalg.norm(grad, norm_type)
|
||||
total_norm += grad_norm**norm_type
|
||||
|
||||
if horizontal_model_parallel_grad_norm_aggregation:
|
||||
|
|
|
|||
Loading…
Reference in a new issue