[ghstack-poisoned]
This commit is contained in:
leslie-fang-intel 2025-02-07 19:50:16 -08:00
parent 8f7d763781
commit d0ce07aa3b
2 changed files with 4 additions and 4 deletions

View file

@ -385,9 +385,6 @@ class TestUnbackedSymints(InductorTestCase):
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
@dynamo_config.patch(capture_dynamic_output_shape_ops=True)
def test_issue_143498(self, device):
if device == "cpu":
raise unittest.SkipTest("CPU Failure")
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

View file

@ -1006,15 +1006,18 @@ def squeeze(x, dim=None):
dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim)
new_shape = []
new_stride = []
original_stride = x.get_stride()
for d, s in enumerate(x.get_size()):
if not (
d in dims
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
):
new_shape.append(s)
new_stride.append(original_stride[d])
# squeeze does nothing if the size isn't 1
return view(x, new_shape) if new_shape != x.get_size() else x
return as_strided(x, new_shape, new_stride) if new_shape != x.get_size() else x
@register_lowering(aten.squeeze_copy, type_promotion_kind=None)