mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Moved .all() checks for distributions to _is_all_true (#145029)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145029 Approved by: https://github.com/Skylion007, https://github.com/zou3519
This commit is contained in:
parent
2bf772d1ba
commit
5e4cf3e6ad
1 changed files with 2 additions and 2 deletions
|
|
@ -68,7 +68,7 @@ class Distribution:
|
|||
continue # skip checking lazily-constructed args
|
||||
value = getattr(self, param)
|
||||
valid = constraint.check(value)
|
||||
if not valid.all():
|
||||
if not torch._is_all_true(valid):
|
||||
raise ValueError(
|
||||
f"Expected parameter {param} "
|
||||
f"({type(value).__name__} of shape {tuple(value.shape)}) "
|
||||
|
|
@ -313,7 +313,7 @@ class Distribution:
|
|||
return
|
||||
assert support is not None
|
||||
valid = support.check(value)
|
||||
if not valid.all():
|
||||
if not torch._is_all_true(valid):
|
||||
raise ValueError(
|
||||
"Expected value argument "
|
||||
f"({type(value).__name__} of shape {tuple(value.shape)}) "
|
||||
|
|
|
|||
Loading…
Reference in a new issue