diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index e4eb73dc624..bc1ccc6a20b 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1414,13 +1414,9 @@ def split_with_sizes( start_idx = 0 # Avoid importing sympy at a module level - from torch.fx.experimental.symbolic_shapes import expect_true for i in range(num_splits): length = split_sizes[i] - # We know this is true thanks to the sum, but this assertion helps - # out our internal reasoning - expect_true(start_idx + length <= self.shape[dim]) splits.append(self.narrow(dim, start_idx, length)) start_idx += length return splits