Revert "[while_loop][inductor] support sym expression as cond_fn output (#146222)"

This reverts commit 5ecdc428b2.

Reverted https://github.com/pytorch/pytorch/pull/146222 on behalf of https://github.com/atalman due to Internal failure, please see associated diff ([comment](https://github.com/pytorch/pytorch/pull/146222#issuecomment-2643379933))
Author: PyTorch MergeBot
Date:   2025-02-07 16:19:41 +00:00
Parent: 5d7532140f
Commit: 076717785c

5 changed files with 13 additions and 79 deletions
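
For context: with this revert, the `cond_fn` passed to `torch._higher_order_ops.while_loop` must once again produce a scalar `torch.bool` tensor, rather than a symbolic boolean expression built from `.item()` or `nonzero()` results as the deleted `SymExprCond` model below did. A minimal sketch of the restored contract (names are illustrative, not from this commit):

import torch
from torch._higher_order_ops import while_loop

def cond_fn(c, a, b):
    # restored requirement: return a 0-dim torch.bool tensor,
    # not a Python bool or a SymBool expression
    return c < 5

def body_fn(c, a, b):
    # carried values must keep their shapes and dtypes
    return c + 1, a + 1.0, b * 2.0

c0 = torch.tensor(0)       # loop counter, 0-dim int64
a0 = torch.randn(10, 20)
b0 = torch.randn(10, 20)
c, a, b = while_loop(cond_fn, body_fn, (c0, a0, b0))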


@@ -1511,26 +1511,6 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
-    @common_utils.parametrize("dynamic", [False, True])
-    def test_while_loop_with_sym_expr_cond(self, dynamic):
-        inputs = (
-            torch.randn(10, 20, device=self.device),
-            torch.randn(10, 20, device=self.device),
-        )
-        dim0_ab = Dim("s0", min=2, max=1024)
-        dynamic_shapes = None
-        if dynamic:
-            dynamic_shapes = {
-                "c": {},
-                "a": {0: dim0_ab, 1: None},
-                "b": {0: dim0_ab, 1: None},
-            }
-        self.check_model_with_multiple_inputs(
-            WhileLoopModels.SymExprCond(),
-            prepend_counters(inputs),
-            dynamic_shapes=dynamic_shapes,
-        )
-
     @config.patch({"is_predispatch": True})
     def test_constant(self):
         class M(torch.nn.Module):


@@ -876,23 +876,6 @@ class WhileLoopModels:
             [c, a, b],
         )
 
-    class SymExprCond(torch.nn.Module):
-        def forward(self, c, a, b):
-            d = a.sum().to(torch.int64).item()
-            e = torch.nonzero(b).size(0)
-
-            def cond_fn(c, a, b):
-                return d + e + a.shape[0] - b.shape[0] < 10
-
-            def body_fn(c, a, b):
-                return c + 1, a + e, b + d
-
-            return torch._higher_order_ops.while_loop(
-                cond_fn,
-                body_fn,
-                [c, a, b],
-            )
-
 
 class WhileLoopTests(TestCase):
     def _run_test(
@@ -1156,23 +1139,6 @@ class WhileLoopTests(TestCase):
             dynamic=dynamic,
         )
 
-    @requires_gpu
-    @parametrize("device", ["cpu", GPU_TYPE])
-    @parametrize("dynamic", [True, False])
-    @torch._dynamo.config.patch(
-        {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
-    )
-    def test_while_loop_with_sym_expr_cond(self, device, dynamic):
-        self._run_test(
-            model=WhileLoopModels.SymExprCond(),
-            inputs=(
-                torch.randn(10, 20),
-                torch.randn(10, 20),
-            ),
-            device=device,
-            dynamic=dynamic,
-        )
-
 
 class AssociativeScanTests(TestCase):
     @requires_gpu


@@ -1594,14 +1594,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
             subgraph.graph.graph_outputs, outer_outputs
         ):
             src = inner_output.codegen_reference()
-            if not isinstance(inner_output, ir.ShapeAsConstantBuffer):
-                # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
-                # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
-                # constructor is deleted.
-                src = f"std::move({src})"
-                # in case the outer_output carried a value
-                # before (e.g., in the while_loop codegen)
-                self.writeline(f"{outer_output}.reset();")
+            # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
+            # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
+            # constructor is deleted.
+            src = f"std::move({src})"
+            # in case the outer_output carried a value
+            # before (e.g., in the while_loop codegen)
+            self.writeline(f"{outer_output}.reset();")
             self.writeline(f"{outer_output} = {src};")
 
     def codegen_invoke_subgraph(self, invoke_subgraph):
@@ -1663,9 +1662,6 @@ class CppWrapperCpu(PythonWrapperCodegen):
         self.pop_codegened_graph()
 
     def codegen_while_loop(self, while_loop):
-        is_bool_pred = isinstance(
-            while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer
-        )
         name = while_loop.get_name()
         outer_carried_inputs = [
             buf.codegen_reference() for buf in while_loop.carried_inputs
@@ -1674,10 +1670,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
             buf.codegen_reference() for buf in while_loop.additional_inputs
         ]
         cond_result_name = f"{name}_cond_result"
-        if is_bool_pred:
-            self.writeline(f"bool {cond_result_name};")
-        else:
-            self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")
+        self.writeline(f"RAIIAtenTensorHandle {cond_result_name};")
 
         cond_outer_inputs = []
         for inp, out in zip(outer_carried_inputs, while_loop.outputs):
@@ -1707,11 +1700,8 @@ class CppWrapperCpu(PythonWrapperCodegen):
         self.codegen_subgraph(
             while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
         )
-        if is_bool_pred:
-            cond_result = f"{cond_result_name}"
-        else:
-            cond_result = f"{cond_result_name}_scalar"
-            self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
+        cond_result = f"{cond_result_name}_scalar"
+        self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
         self.writeline(f"if (!{cond_result}) break;")
         self.writeline(ExitSubgraphLine(self))

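The hunks above restore the pre-#146222 behavior in which the condition buffer is always a RAIIAtenTensorHandle whose scalar value is extracted via codegen_tensor_item. A hedged Python sketch of the C++ loop shape this path emits (names simplified, and the shim call for the tensor-item extraction is paraphrased as a comment):

def sketch_while_loop_cpp(name: str) -> str:
    # Illustrative only: approximates the C++ emitted by
    # codegen_while_loop after the revert.
    cond = f"{name}_cond_result"
    return "\n".join([
        f"RAIIAtenTensorHandle {cond};",
        "while (1) {",
        "    // ... run the cond subgraph, writing into the handle ...",
        f"    bool {cond}_scalar;",
        f"    // (simplified) read the 0-dim bool tensor into {cond}_scalar",
        f"    if (!{cond}_scalar) break;",
        "    // ... run the body subgraph ...",
        "}",
    ])

print(sketch_while_loop_cpp("while_loop_0"))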

@@ -2551,7 +2551,7 @@ class PythonWrapperCodegen(CodeGen):
             while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
         )
         self.writeline(
-            f"if not {cond_outer_outputs[0]}: break"
+            f"if not {cond_outer_outputs[0]}.item(): break"
         )  # condition doesn't hold
         self.writeline(ExitSubgraphLine(self))
         self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
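
Since the cond subgraph output is again always a 0-dim bool tensor, the Python wrapper needs `.item()` to obtain a Python bool before testing it. A hedged sketch of the loop the generated code follows (function and buffer names are illustrative):

def generated_while_loop(carried, cond_subgraph, body_subgraph):
    while True:
        cond_out = cond_subgraph(*carried)  # 0-dim torch.bool tensor
        if not cond_out.item():             # condition doesn't hold
            break
        carried = body_subgraph(*carried)
    return carried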


@@ -7407,10 +7407,8 @@ class WhileLoop(ExternKernel):
         # make sure cond_fn returns a boolean scalar Tensor
         assert len(cond_outputs) == 1, cond_outputs
-        p = cond_outputs[0]
-        if not isinstance(p, ShapeAsConstantBuffer):
-            assert p.get_dtype() == torch.bool, p
-            assert len(p.get_size()) == 0, p
+        assert cond_outputs[0].get_dtype() == torch.bool, cond_outputs
+        assert len(cond_outputs[0].get_size()) == 0, cond_outputs
 
         assert (
             len(all_inputs) > 0
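
Restated as a standalone, hedged check (mirroring the buffer methods used in the hunk above), the restored invariant is that the cond subgraph's single output is a scalar boolean tensor, with no ShapeAsConstantBuffer escape hatch:

import torch

def check_cond_output(cond_outputs):
    # make sure cond_fn returns a boolean scalar Tensor
    assert len(cond_outputs) == 1, cond_outputs
    assert cond_outputs[0].get_dtype() == torch.bool, cond_outputs
    assert len(cond_outputs[0].get_size()) == 0, cond_outputs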