mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update
[ghstack-poisoned]
This commit is contained in:
parent
8f7d763781
commit
d0ce07aa3b
2 changed files with 4 additions and 4 deletions
|
|
@ -385,9 +385,6 @@ class TestUnbackedSymints(InductorTestCase):
|
|||
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
|
||||
@dynamo_config.patch(capture_dynamic_output_shape_ops=True)
|
||||
def test_issue_143498(self, device):
|
||||
if device == "cpu":
|
||||
raise unittest.SkipTest("CPU Failure")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -1006,15 +1006,18 @@ def squeeze(x, dim=None):
|
|||
dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim)
|
||||
|
||||
new_shape = []
|
||||
new_stride = []
|
||||
original_stride = x.get_stride()
|
||||
for d, s in enumerate(x.get_size()):
|
||||
if not (
|
||||
d in dims
|
||||
and V.graph.sizevars.evaluate_expr(sympy.Eq(s, 1), size_oblivious=True)
|
||||
):
|
||||
new_shape.append(s)
|
||||
new_stride.append(original_stride[d])
|
||||
|
||||
# squeeze does nothing if the size isn't 1
|
||||
return view(x, new_shape) if new_shape != x.get_size() else x
|
||||
return as_strided(x, new_shape, new_stride) if new_shape != x.get_size() else x
|
||||
|
||||
|
||||
@register_lowering(aten.squeeze_copy, type_promotion_kind=None)
|
||||
|
|
|
|||
Loading…
Reference in a new issue