diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 75bd3c049f0..03b0bc99dba 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -107,6 +107,9 @@ class TestValueRanges(TestCase): self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r, r.lower) + def test_pow_half(self): + ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5)) + @parametrize("fn", BINARY_OPS) def test_binary_ref(self, fn): for a, b in itertools.product(CONSTANTS, repeat=2): diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 08d34f15e21..900fbd1ea7b 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -360,6 +360,11 @@ class ValueRangeAnalysis: @classmethod def pow(cls, a, b): + def is_integer(val): + return isinstance(val, int) or ( + hasattr(val, "is_integer") and val.is_integer + ) + a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if a.is_singleton() and b.is_singleton(): @@ -367,7 +372,7 @@ class ValueRangeAnalysis: if r == sympy.zoo: return ValueRanges.unknown() return ValueRanges.wrap(r) - elif b.is_singleton() and b.lower >= 0 and isinstance(b.lower, int): + elif b.is_singleton() and is_integer(b.lower) and b.lower >= 0: i = ValueRanges.wrap(1) for _ in range(b.lower): i = cls.mul(i, a)