From d0ce07aa3b055e7cc6109293ee0f73e326d0bda6 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 7 Feb 2025 19:50:16 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- test/inductor/test_unbacked_symints.py | 3 --- torch/_inductor/lowering.py | 5 ++++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 1fcddbae1e3..5ba1e243f10 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -385,9 +385,6 @@ class TestUnbackedSymints(InductorTestCase): @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch(capture_dynamic_output_shape_ops=True) def test_issue_143498(self, device): - if device == "cpu": - raise unittest.SkipTest("CPU Failure") - class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index a30b9aebc30..b14dac5eb57 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1006,15 +1006,18 @@ def squeeze(x, dim=None): dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim) new_shape = [] + new_stride = [] + original_stride = x.get_stride() for d, s in enumerate(x.get_size()): if not ( d in dims and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True) ): new_shape.append(s) + new_stride.append(original_stride[d]) # squeeze does nothing if the size isn't 1 - return view(x, new_shape) if new_shape != x.get_size() else x + return as_strided(x, new_shape, new_stride) if new_shape != x.get_size() else x @register_lowering(aten.squeeze_copy, type_promotion_kind=None)