From 076717785c12682b3172f3f4bb327a00fa2afbd7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Feb 2025 16:19:41 +0000 Subject: [PATCH] Revert "[while_loop][inductor] support sym expression as cond_fn output (#146222)" This reverts commit 5ecdc428b230ab5ba44a90678f1c905e314f6ccb. Reverted https://github.com/pytorch/pytorch/pull/146222 on behalf of https://github.com/atalman due to Internal failure, please see associated diff ([comment](https://github.com/pytorch/pytorch/pull/146222#issuecomment-2643379933)) --- test/inductor/test_aot_inductor.py | 20 ------------- test/inductor/test_control_flow.py | 34 ---------------------- torch/_inductor/codegen/cpp_wrapper_cpu.py | 30 +++++++------------ torch/_inductor/codegen/wrapper.py | 2 +- torch/_inductor/ir.py | 6 ++-- 5 files changed, 13 insertions(+), 79 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index d5c3ee98394..f1c13adb64b 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1511,26 +1511,6 @@ class AOTInductorTestsTemplate: dynamic_shapes=dynamic_shapes, ) - @common_utils.parametrize("dynamic", [False, True]) - def test_while_loop_with_sym_expr_cond(self, dynamic): - inputs = ( - torch.randn(10, 20, device=self.device), - torch.randn(10, 20, device=self.device), - ) - dim0_ab = Dim("s0", min=2, max=1024) - dynamic_shapes = None - if dynamic: - dynamic_shapes = { - "c": {}, - "a": {0: dim0_ab, 1: None}, - "b": {0: dim0_ab, 1: None}, - } - self.check_model_with_multiple_inputs( - WhileLoopModels.SymExprCond(), - prepend_counters(inputs), - dynamic_shapes=dynamic_shapes, - ) - @config.patch({"is_predispatch": True}) def test_constant(self): class M(torch.nn.Module): diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 08016a55753..ab84ed3cf7d 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -876,23 +876,6 @@ class WhileLoopModels: [c, a, b], ) - class SymExprCond(torch.nn.Module): - def forward(self, c, a, b): - d = a.sum().to(torch.int64).item() - e = torch.nonzero(b).size(0) - - def cond_fn(c, a, b): - return d + e + a.shape[0] - b.shape[0] < 10 - - def body_fn(c, a, b): - return c + 1, a + e, b + d - - return torch._higher_order_ops.while_loop( - cond_fn, - body_fn, - [c, a, b], - ) - class WhileLoopTests(TestCase): def _run_test( @@ -1156,23 +1139,6 @@ class WhileLoopTests(TestCase): dynamic=dynamic, ) - @requires_gpu - @parametrize("device", ["cpu", GPU_TYPE]) - @parametrize("dynamic", [True, False]) - @torch._dynamo.config.patch( - {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True} - ) - def test_while_loop_with_sym_expr_cond(self, device, dynamic): - self._run_test( - model=WhileLoopModels.SymExprCond(), - inputs=( - torch.randn(10, 20), - torch.randn(10, 20), - ), - device=device, - dynamic=dynamic, - ) - class AssociativeScanTests(TestCase): @requires_gpu diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index e90953b1af4..e86034fc106 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1594,14 +1594,13 @@ class CppWrapperCpu(PythonWrapperCodegen): subgraph.graph.graph_outputs, outer_outputs ): src = inner_output.codegen_reference() - if not isinstance(inner_output, ir.ShapeAsConstantBuffer): - # in ABI-compatible mode, we need to std::move subgraph output (inner_output) - # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy - # constructor is deleted. - src = f"std::move({src})" - # in case the outer_output carried a value - # before (e.g., in the while_loop codegen) - self.writeline(f"{outer_output}.reset();") + # in ABI-compatible mode, we need to std::move subgraph output (inner_output) + # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy + # constructor is deleted. + src = f"std::move({src})" + # in case the outer_output carried a value + # before (e.g., in the while_loop codegen) + self.writeline(f"{outer_output}.reset();") self.writeline(f"{outer_output} = {src};") def codegen_invoke_subgraph(self, invoke_subgraph): @@ -1663,9 +1662,6 @@ class CppWrapperCpu(PythonWrapperCodegen): self.pop_codegened_graph() def codegen_while_loop(self, while_loop): - is_bool_pred = isinstance( - while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer - ) name = while_loop.get_name() outer_carried_inputs = [ buf.codegen_reference() for buf in while_loop.carried_inputs @@ -1674,10 +1670,7 @@ class CppWrapperCpu(PythonWrapperCodegen): buf.codegen_reference() for buf in while_loop.additional_inputs ] cond_result_name = f"{name}_cond_result" - if is_bool_pred: - self.writeline(f"bool {cond_result_name};") - else: - self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") + self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") cond_outer_inputs = [] for inp, out in zip(outer_carried_inputs, while_loop.outputs): @@ -1707,11 +1700,8 @@ class CppWrapperCpu(PythonWrapperCodegen): while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs ) - if is_bool_pred: - cond_result = f"{cond_result_name}" - else: - cond_result = f"{cond_result_name}_scalar" - self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) + cond_result = f"{cond_result_name}_scalar" + self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) self.writeline(f"if (!{cond_result}) break;") self.writeline(ExitSubgraphLine(self)) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e68e6dceb37..0822ddd1913 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2551,7 +2551,7 @@ class PythonWrapperCodegen(CodeGen): while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs ) self.writeline( - f"if not {cond_outer_outputs[0]}: break" + f"if not {cond_outer_outputs[0]}.item(): break" ) # condition doesn't hold self.writeline(ExitSubgraphLine(self)) self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 6dae9fe7226..41d3d001afd 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7407,10 +7407,8 @@ class WhileLoop(ExternKernel): # make sure cond_fn returns a boolean scalar Tensor assert len(cond_outputs) == 1, cond_outputs - p = cond_outputs[0] - if not isinstance(p, ShapeAsConstantBuffer): - assert p.get_dtype() == torch.bool, p - assert len(p.get_size()) == 0, p + assert cond_outputs[0].get_dtype() == torch.bool, cond_outputs + assert len(cond_outputs[0].get_size()) == 0, cond_outputs assert ( len(all_inputs) > 0