diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index b1309ac3ee9..5cd12925406 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1305,29 +1305,6 @@ class AOTInductorTestsTemplate: dynamic_shapes=dynamic_shapes, ) - @common_utils.parametrize("dynamic", [False, True]) - def test_cond_unbacked_symint_closure(self, dynamic): - inputs = ( - torch.randn((10, 20), device=self.device), - torch.randn((15, 20), device=self.device), - torch.randn((10, 20), device=self.device), - ) - dynamic_shapes = None - if dynamic: - dim0_a = Dim("s0", min=2, max=1024) - dim0_b = Dim("s1", min=2, max=1024) - dynamic_shapes = { - "p": {}, - "x": {0: dim0_a, 1: None}, - "y": {0: dim0_b, 1: None}, - "z": {0: dim0_a, 1: None}, - } - self.check_model_with_multiple_inputs( - CondModels.UnbackedSymIntClosure(), - prepend_predicates(inputs), - dynamic_shapes=dynamic_shapes, - ) - def test_cond_symint_input(self): class M(torch.nn.Module): def forward(self, x, y, z): @@ -1462,26 +1439,6 @@ class AOTInductorTestsTemplate: dynamic_shapes=None, ) - @common_utils.parametrize("dynamic", [False, True]) - def test_while_loop_with_unbacked_symint_closure(self, dynamic): - inputs = ( - torch.randn(10, 20, device=self.device), - torch.randn(10, 20, device=self.device), - ) - dim0_ab = Dim("s0", min=2, max=1024) - dynamic_shapes = None - if dynamic: - dynamic_shapes = { - "c": {}, - "a": {0: dim0_ab, 1: None}, - "b": {0: dim0_ab, 1: None}, - } - self.check_model_with_multiple_inputs( - WhileLoopModels.UnbackedSymIntClosure(), - prepend_counters(inputs), - dynamic_shapes=dynamic_shapes, - ) - @config.patch({"is_predispatch": True}) def test_constant(self): class M(torch.nn.Module): diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index ab84ed3cf7d..2ab097d504c 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -183,19 +183,6 @@ class CondModels: return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b]) - class UnbackedSymIntClosure(torch.nn.Module): - def forward(self, p, x, y, z): - a = y.shape[0] - b = z.sum().to(torch.int64).item() - - def true_fn(x): - return x + a - - def false_fn(x): - return x + b * z - - return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,)) - class CondTests(TestCase): def _run_test( @@ -261,22 +248,6 @@ class CondTests(TestCase): device=device, ) - @requires_gpu - @parametrize("device", ["cpu", GPU_TYPE]) - @parametrize("dynamic", [False, True]) - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_cond_unbacked_symint_closure(self, device, dynamic): - self._run_test( - model=CondModels.UnbackedSymIntClosure(), - inputs=( - torch.randn(10, 20), - torch.randn(10, 20), - torch.randn(10, 20), - ), - device=device, - dynamic=dynamic, - ) - @requires_gpu def test_cond_control_flow_with_precomputed_size(self): class TestModel(torch.nn.Module): @@ -859,23 +830,6 @@ class WhileLoopModels: ) return out1 + 1, out2 + 2 - class UnbackedSymIntClosure(torch.nn.Module): - def forward(self, c, a, b): - d = a.sum().to(torch.int64).item() - e = torch.nonzero(b).size(0) - - def cond_fn(c, a, b): - return c > d + e + a.shape[0] - b.shape[0] - - def body_fn(c, a, b): - return c - 1, a + e, b + d - - return torch._higher_order_ops.while_loop( - cond_fn, - body_fn, - [c, a, b], - ) - class WhileLoopTests(TestCase): def _run_test( @@ -1122,23 +1076,6 @@ class WhileLoopTests(TestCase): dynamic=dynamic, ) - @requires_gpu - @parametrize("device", ["cpu", GPU_TYPE]) - @parametrize("dynamic", [True, False]) - @torch._dynamo.config.patch( - {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True} - ) - def test_while_loop_with_unbacked_symint_closure(self, device, dynamic): - self._run_test( - model=WhileLoopModels.UnbackedSymIntClosure(), - inputs=( - torch.randn(10, 20), - torch.randn(10, 20), - ), - device=device, - dynamic=dynamic, - ) - class AssociativeScanTests(TestCase): @requires_gpu diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index c8d4d3b4b54..34231f0a7ed 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7209,7 +7209,7 @@ class InvokeSubgraph(ExternKernel): @ir_dataclass(frozen=False) class Conditional(ExternKernel): predicate: Optional[IRNode] = None - operands: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + operands: Optional[list[TensorBox]] = None true_subgraph: Optional[Subgraph] = None false_subgraph: Optional[Subgraph] = None outputs: Optional[list[MultiOutput]] = None @@ -7217,7 +7217,7 @@ class Conditional(ExternKernel): def __init__( self, predicate: IRNode, - operands: list[Union[TensorBox, ShapeAsConstantBuffer]], + operands: list[TensorBox], true_subgraph: Subgraph, false_subgraph: Subgraph, layout: MultiOutputLayout, @@ -7227,13 +7227,15 @@ class Conditional(ExternKernel): self.true_subgraph = true_subgraph self.false_subgraph = false_subgraph - sym_args, tensor_args = _split_by_sym_type([predicate] + operands) + inputs = [] + if not isinstance(predicate, ShapeAsConstantBuffer): + inputs.append(predicate) + inputs.extend(operands) super().__init__( name=None, layout=layout, - inputs=tensor_args, - constant_args=sym_args, + inputs=inputs, ) self.name = V.graph.register_buffer(self) @@ -7245,10 +7247,11 @@ class Conditional(ExternKernel): predicate: TensorBox, true_fn: Subgraph, false_fn: Subgraph, - operands: list[Union[TensorBox, ShapeAsConstantBuffer]], + operands: list[TensorBox], ): predicate = cls.realize_input(predicate) operands = [cls.realize_input(x) for x in operands] + fx_operands = V.graph.current_node.args[-1] fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] @@ -7282,12 +7285,16 @@ class Conditional(ExternKernel): assert to.get_dtype() == fo.get_dtype(), (i, to, fo) assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo) - device = next( - o.get_device() - for o in [predicate] + operands - if not isinstance(o, ShapeAsConstantBuffer) - ) - assert device is not None, "cannot determine device" + if not isinstance(predicate, ShapeAsConstantBuffer): + # use predicate device for consistent codegen-ing + device = predicate.get_device() + else: + # predicate is not a Tensor: use first operand's device + assert ( + len(operands) > 0 + ), "When predicate is not a Tensor, there must be at least one operand in torch.cond." + device = operands[0].get_device() + conditional = Conditional( predicate=predicate, operands=operands, @@ -7320,32 +7327,18 @@ class Conditional(ExternKernel): wrapper.codegen_conditional(self) -def _split_by_sym_type( - args: list[Any], -) -> tuple[list[ShapeAsConstantBuffer], list[Any]]: - non_sym_args = [] - sym_args = [] - for arg in args: - if isinstance(arg, ShapeAsConstantBuffer): - sym_args.append(arg.expr) - else: - non_sym_args.append(arg) - - return sym_args, non_sym_args - - @ir_dataclass(frozen=False) class WhileLoop(ExternKernel): - carried_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None - additional_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None + carried_inputs: Optional[list[TensorBox]] = None + additional_inputs: Optional[list[TensorBox]] = None cond_subgraph: Optional[Subgraph] = None body_subgraph: Optional[Subgraph] = None outputs: Optional[list[MultiOutput]] = None def __init__( self, - carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], - additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + carried_inputs: list[TensorBox], + additional_inputs: list[TensorBox], cond_subgraph: Subgraph, body_subgraph: Subgraph, layout: MultiOutputLayout, @@ -7355,12 +7348,10 @@ class WhileLoop(ExternKernel): self.cond_subgraph = cond_subgraph self.body_subgraph = body_subgraph - sym_args, tensor_args = _split_by_sym_type(carried_inputs + additional_inputs) super().__init__( name=None, layout=layout, - inputs=tensor_args, - constant_args=sym_args, + inputs=carried_inputs + additional_inputs, ) self.name = V.graph.register_buffer(self) @@ -7371,8 +7362,8 @@ class WhileLoop(ExternKernel): cls, cond_fn: Subgraph, body_fn: Subgraph, - carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], - additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]], + carried_inputs: list[TensorBox], + additional_inputs: list[TensorBox], ): carried_inputs = [cls.realize_input(x) for x in carried_inputs] additional_inputs = [cls.realize_input(x) for x in additional_inputs]