From 5e4cf3e6ad6f1f06436f409b394ae02e5ed5583d Mon Sep 17 00:00:00 2001 From: chilli Date: Thu, 16 Jan 2025 17:04:42 -0800 Subject: [PATCH] Moved .all() checks for distributions to _is_all_true (#145029) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145029 Approved by: https://github.com/Skylion007, https://github.com/zou3519 --- torch/distributions/distribution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index f6d02f2737c..75ea50d2486 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -68,7 +68,7 @@ class Distribution: continue # skip checking lazily-constructed args value = getattr(self, param) valid = constraint.check(value) - if not valid.all(): + if not torch._is_all_true(valid): raise ValueError( f"Expected parameter {param} " f"({type(value).__name__} of shape {tuple(value.shape)}) " @@ -313,7 +313,7 @@ class Distribution: return assert support is not None valid = support.check(value) - if not valid.all(): + if not torch._is_all_true(valid): raise ValueError( "Expected value argument " f"({type(value).__name__} of shape {tuple(value.shape)}) "