diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 9af0a57c530..fd711b66930 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -1335,6 +1335,29 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )
 
+    @common_utils.parametrize("dynamic", [False, True])
+    def test_cond_unbacked_symint_closure(self, dynamic):
+        inputs = (
+            torch.randn((10, 20), device=self.device),
+            torch.randn((15, 20), device=self.device),
+            torch.randn((10, 20), device=self.device),
+        )
+        dynamic_shapes = None
+        if dynamic:
+            dim0_a = Dim("s0", min=2, max=1024)
+            dim0_b = Dim("s1", min=2, max=1024)
+            dynamic_shapes = {
+                "p": {},
+                "x": {0: dim0_a, 1: None},
+                "y": {0: dim0_b, 1: None},
+                "z": {0: dim0_a, 1: None},
+            }
+        self.check_model_with_multiple_inputs(
+            CondModels.UnbackedSymIntClosure(),
+            prepend_predicates(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
     def test_cond_symint_input(self):
         class M(torch.nn.Module):
             def forward(self, x, y, z):
@@ -1469,6 +1492,26 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=None,
         )
 
+    @common_utils.parametrize("dynamic", [False, True])
+    def test_while_loop_with_unbacked_symint_closure(self, dynamic):
+        inputs = (
+            torch.randn(10, 20, device=self.device),
+            torch.randn(10, 20, device=self.device),
+        )
+        dim0_ab = Dim("s0", min=2, max=1024)
+        dynamic_shapes = None
+        if dynamic:
+            dynamic_shapes = {
+                "c": {},
+                "a": {0: dim0_ab, 1: None},
+                "b": {0: dim0_ab, 1: None},
+            }
+        self.check_model_with_multiple_inputs(
+            WhileLoopModels.UnbackedSymIntClosure(),
+            prepend_counters(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
     @config.patch({"is_predispatch": True})
     def test_constant(self):
         class M(torch.nn.Module):
diff --git a/test/inductor/test_aot_inductor_arrayref.py b/test/inductor/test_aot_inductor_arrayref.py
index 97a11026062..dd1be3afe45 100644
--- a/test/inductor/test_aot_inductor_arrayref.py
+++ b/test/inductor/test_aot_inductor_arrayref.py
@@ -71,6 +71,9 @@ CPU_TEST_FAILURES = {
    "test_cond_with_parameters": fail_minimal_arrayref_interface(),
    "test_cond_with_reinterpret_view_inputs_outputs": fail_minimal_arrayref_interface(),
    "test_cond_share_predicte": fail_stack_allocation(is_skip=True),
+    "test_cond_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(),
+    "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(),
+    "test_while_loop_with_unbacked_symint_closure_dynamic_False": fail_minimal_arrayref_interface(),
    "test_while_loop_with_parameters": fail_minimal_arrayref_interface(),
    "test_while_loop_with_pytree_inputs": fail_stack_allocation(),
    # FIXME: failed with Segfault while exiting the Python runtime
diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py
index 2ab097d504c..ab84ed3cf7d 100644
--- a/test/inductor/test_control_flow.py
+++ b/test/inductor/test_control_flow.py
@@ -183,6 +183,19 @@ class CondModels:
 
             return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b])
 
+    class UnbackedSymIntClosure(torch.nn.Module):
+        def forward(self, p, x, y, z):
+            a = y.shape[0]
+            b = z.sum().to(torch.int64).item()
+
+            def true_fn(x):
+                return x + a
+
+            def false_fn(x):
+                return x + b * z
+
+            return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
+
 
 class CondTests(TestCase):
     def _run_test(
@@ -248,6 +261,22 @@ class CondTests(TestCase):
             device=device,
         )
 
+    @requires_gpu
+    @parametrize("device", ["cpu", GPU_TYPE])
+    @parametrize("dynamic", [False, True])
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_cond_unbacked_symint_closure(self, device, dynamic):
+        self._run_test(
+            model=CondModels.UnbackedSymIntClosure(),
+            inputs=(
+                torch.randn(10, 20),
+                torch.randn(10, 20),
+                torch.randn(10, 20),
+            ),
+            device=device,
+            dynamic=dynamic,
+        )
+
     @requires_gpu
     def test_cond_control_flow_with_precomputed_size(self):
         class TestModel(torch.nn.Module):
@@ -830,6 +859,23 @@ class WhileLoopModels:
             )
             return out1 + 1, out2 + 2
 
+    class UnbackedSymIntClosure(torch.nn.Module):
+        def forward(self, c, a, b):
+            d = a.sum().to(torch.int64).item()
+            e = torch.nonzero(b).size(0)
+
+            def cond_fn(c, a, b):
+                return c > d + e + a.shape[0] - b.shape[0]
+
+            def body_fn(c, a, b):
+                return c - 1, a + e, b + d
+
+            return torch._higher_order_ops.while_loop(
+                cond_fn,
+                body_fn,
+                [c, a, b],
+            )
+
 
 class WhileLoopTests(TestCase):
     def _run_test(
@@ -1076,6 +1122,23 @@ class WhileLoopTests(TestCase):
             dynamic=dynamic,
         )
 
+    @requires_gpu
+    @parametrize("device", ["cpu", GPU_TYPE])
+    @parametrize("dynamic", [True, False])
+    @torch._dynamo.config.patch(
+        {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
+    )
+    def test_while_loop_with_unbacked_symint_closure(self, device, dynamic):
+        self._run_test(
+            model=WhileLoopModels.UnbackedSymIntClosure(),
+            inputs=(
+                torch.randn(10, 20),
+                torch.randn(10, 20),
+            ),
+            device=device,
+            dynamic=dynamic,
+        )
+
 
 class AssociativeScanTests(TestCase):
     @requires_gpu
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 34231f0a7ed..c8d4d3b4b54 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -7209,7 +7209,7 @@ class InvokeSubgraph(ExternKernel):
 @ir_dataclass(frozen=False)
 class Conditional(ExternKernel):
     predicate: Optional[IRNode] = None
-    operands: Optional[list[TensorBox]] = None
+    operands: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None
     true_subgraph: Optional[Subgraph] = None
     false_subgraph: Optional[Subgraph] = None
     outputs: Optional[list[MultiOutput]] = None
@@ -7217,7 +7217,7 @@
     def __init__(
         self,
         predicate: IRNode,
-        operands: list[TensorBox],
+        operands: list[Union[TensorBox, ShapeAsConstantBuffer]],
         true_subgraph: Subgraph,
         false_subgraph: Subgraph,
         layout: MultiOutputLayout,
@@ -7227,15 +7227,13 @@
         self.true_subgraph = true_subgraph
         self.false_subgraph = false_subgraph
 
-        inputs = []
-        if not isinstance(predicate, ShapeAsConstantBuffer):
-            inputs.append(predicate)
-        inputs.extend(operands)
+        sym_args, tensor_args = _split_by_sym_type([predicate] + operands)
 
         super().__init__(
             name=None,
             layout=layout,
-            inputs=inputs,
+            inputs=tensor_args,
+            constant_args=sym_args,
         )
 
         self.name = V.graph.register_buffer(self)
@@ -7247,11 +7245,10 @@
         predicate: TensorBox,
         true_fn: Subgraph,
         false_fn: Subgraph,
-        operands: list[TensorBox],
+        operands: list[Union[TensorBox, ShapeAsConstantBuffer]],
     ):
         predicate = cls.realize_input(predicate)
         operands = [cls.realize_input(x) for x in operands]
-
         fx_operands = V.graph.current_node.args[-1]
         fake_operands = [x.meta["val"] for x in fx_operands]  # type: ignore[union-attr]
 
@@ -7285,16 +7282,12 @@
             assert to.get_dtype() == fo.get_dtype(), (i, to, fo)
             assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo)
 
-        if not isinstance(predicate, ShapeAsConstantBuffer):
-            # use predicate device for consistent codegen-ing
-            device = predicate.get_device()
-        else:
-            # predicate is not a Tensor: use first operand's device
-            assert (
-                len(operands) > 0
-            ), "When predicate is not a Tensor, there must be at least one operand in torch.cond."
-            device = operands[0].get_device()
-
+        device = next(
+            o.get_device()
+            for o in [predicate] + operands
+            if not isinstance(o, ShapeAsConstantBuffer)
+        )
+        assert device is not None, "cannot determine device"
         conditional = Conditional(
             predicate=predicate,
             operands=operands,
@@ -7327,18 +7320,32 @@
         wrapper.codegen_conditional(self)
 
 
+def _split_by_sym_type(
+    args: list[Any],
+) -> tuple[list[ShapeAsConstantBuffer], list[Any]]:
+    non_sym_args = []
+    sym_args = []
+    for arg in args:
+        if isinstance(arg, ShapeAsConstantBuffer):
+            sym_args.append(arg.expr)
+        else:
+            non_sym_args.append(arg)
+
+    return sym_args, non_sym_args
+
+
 @ir_dataclass(frozen=False)
 class WhileLoop(ExternKernel):
-    carried_inputs: Optional[list[TensorBox]] = None
-    additional_inputs: Optional[list[TensorBox]] = None
+    carried_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None
+    additional_inputs: Optional[list[Union[TensorBox, ShapeAsConstantBuffer]]] = None
     cond_subgraph: Optional[Subgraph] = None
     body_subgraph: Optional[Subgraph] = None
     outputs: Optional[list[MultiOutput]] = None
 
     def __init__(
         self,
-        carried_inputs: list[TensorBox],
-        additional_inputs: list[TensorBox],
+        carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
         cond_subgraph: Subgraph,
         body_subgraph: Subgraph,
         layout: MultiOutputLayout,
@@ -7348,10 +7355,12 @@
         self.cond_subgraph = cond_subgraph
         self.body_subgraph = body_subgraph
 
+        sym_args, tensor_args = _split_by_sym_type(carried_inputs + additional_inputs)
         super().__init__(
             name=None,
             layout=layout,
-            inputs=carried_inputs + additional_inputs,
+            inputs=tensor_args,
+            constant_args=sym_args,
         )
 
         self.name = V.graph.register_buffer(self)
@@ -7362,8 +7371,8 @@
         cls,
         cond_fn: Subgraph,
         body_fn: Subgraph,
-        carried_inputs: list[TensorBox],
-        additional_inputs: list[TensorBox],
+        carried_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
+        additional_inputs: list[Union[TensorBox, ShapeAsConstantBuffer]],
     ):
         carried_inputs = [cls.realize_input(x) for x in carried_inputs]
         additional_inputs = [cls.realize_input(x) for x in additional_inputs]
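A minimal standalone sketch of the pattern the new cond tests exercise (not part of the patch; the module and variable names below are illustrative). A scalar read back with .item() becomes an unbacked SymInt under capture_scalar_outputs=True, and the branches consume it through their closures rather than as explicit operands, which is the case the new _split_by_sym_type routing to constant_args covers:

import torch


class CondClosure(torch.nn.Module):
    # Trimmed variant of CondModels.UnbackedSymIntClosure from the patch.
    def forward(self, p, x, z):
        # Under torch.compile with capture_scalar_outputs=True, `b` is an
        # unbacked SymInt; in eager mode it is a plain Python int.
        b = z.sum().to(torch.int64).item()

        def true_fn(x):
            return x + 1

        def false_fn(x):
            # `b` and `z` are captured by the closure, not passed as operands.
            return x + b * z

        return torch.cond(p, true_fn, false_fn, (x,))


with torch._dynamo.config.patch("capture_scalar_outputs", True):
    compiled = torch.compile(CondClosure(), fullgraph=True)
    out = compiled(torch.tensor(True), torch.randn(10, 20), torch.randn(10, 20))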
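The same closure pattern for while_loop, again as a sketch with illustrative names; torch._higher_order_ops.while_loop is the API the patch's test model uses. Both cond_fn and body_fn read the unbacked value from the enclosing frame instead of receiving it as a carried input:

import torch
from torch._higher_order_ops import while_loop


class WhileClosure(torch.nn.Module):
    # Trimmed variant of WhileLoopModels.UnbackedSymIntClosure from the patch.
    def forward(self, c, a):
        # `d` is an unbacked SymInt under capture_scalar_outputs=True and is
        # closed over by both subgraphs rather than threaded through the
        # carried inputs.
        d = a.sum().to(torch.int64).item()

        def cond_fn(c, a):
            return c > d

        def body_fn(c, a):
            return c - 1, a + d

        return while_loop(cond_fn, body_fn, (c, a))


with torch._dynamo.config.patch(
    {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True}
):
    compiled = torch.compile(WhileClosure(), fullgraph=True)
    out = compiled(torch.tensor(10), torch.randn(3))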